search jobs: Enforce passing in a non-zero userID (#56825)

This adjusts the exhaustive search API to require explicitly passing in
a userID, ensuring we always run searches as the correct user. This
prevents accidentally running a search as another user and thus exposing
results from repos the user is not allowed to see.

Test Plan: CI and new unit test

Co-authored-by: Stefan Hengl <stefan@sourcegraph.com>
This commit is contained in:
Keegan Carruthers-Smith 2023-09-20 16:02:47 +02:00 committed by GitHub
parent b0e8f22f5e
commit ca04f6db28
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 105 additions and 23 deletions

View File

@ -6,6 +6,7 @@ import (
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/goroutine"
"github.com/sourcegraph/sourcegraph/internal/observation"
"github.com/sourcegraph/sourcegraph/internal/search/exhaustive/service"
@ -54,7 +55,10 @@ var _ workerutil.Handler[*types.ExhaustiveSearchJob] = &exhaustiveSearchHandler{
func (h *exhaustiveSearchHandler) Handle(ctx context.Context, logger log.Logger, record *types.ExhaustiveSearchJob) (err error) {
// TODO observability? read other handlers to see if we are missing stuff
q, err := h.newSearcher.NewSearch(ctx, record.Query)
userID := record.InitiatorID
ctx = actor.WithActor(ctx, actor.FromUser(userID))
q, err := h.newSearcher.NewSearch(ctx, userID, record.Query)
if err != nil {
return err
}

View File

@ -6,6 +6,7 @@ import (
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/goroutine"
"github.com/sourcegraph/sourcegraph/internal/observation"
"github.com/sourcegraph/sourcegraph/internal/search/exhaustive/service"
@ -62,7 +63,10 @@ func (h *exhaustiveSearchRepoHandler) Handle(ctx context.Context, logger log.Log
return err
}
q, err := h.newSearcher.NewSearch(ctx, parent.Query)
userID := parent.InitiatorID
ctx = actor.WithActor(ctx, actor.FromUser(userID))
q, err := h.newSearcher.NewSearch(ctx, userID, parent.Query)
if err != nil {
return err
}

View File

@ -7,6 +7,7 @@ import (
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/goroutine"
"github.com/sourcegraph/sourcegraph/internal/observation"
"github.com/sourcegraph/sourcegraph/internal/search/exhaustive/service"
@ -58,12 +59,14 @@ type exhaustiveSearchRepoRevHandler struct {
var _ workerutil.Handler[*types.ExhaustiveSearchRepoRevisionJob] = &exhaustiveSearchRepoRevHandler{}
func (h *exhaustiveSearchRepoRevHandler) Handle(ctx context.Context, logger log.Logger, record *types.ExhaustiveSearchRepoRevisionJob) error {
jobID, query, repoRev, err := h.store.GetQueryRepoRev(ctx, record)
jobID, query, repoRev, initiatorID, err := h.store.GetQueryRepoRev(ctx, record)
if err != nil {
return err
}
q, err := h.newSearcher.NewSearch(ctx, query)
ctx = actor.WithActor(ctx, actor.FromUser(initiatorID))
q, err := h.newSearcher.NewSearch(ctx, initiatorID, query)
if err != nil {
return err
}

View File

@ -43,6 +43,7 @@ go_test(
],
embed = [":service"],
deps = [
"//internal/actor",
"//internal/api",
"//internal/database",
"//internal/database/dbmocks",

View File

@ -8,6 +8,7 @@ import (
"strconv"
"strings"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/search/exhaustive/types"
"github.com/sourcegraph/sourcegraph/internal/uploadstore"
"github.com/sourcegraph/sourcegraph/lib/errors"
@ -19,6 +20,10 @@ type NewSearcher interface {
// that calling this again in the future should return the same Searcher. IE
// it can speak to the DB, but maybe not gitserver.
//
// userID is explicitly passed in and must match the actor for ctx. This
// guards against bugs where a search on behalf of a user is accidentally
// run as a different actor (for example, an internal user).
//
// I expect this to be roughly equivalent to creation of a search plan in
// our search codes job creator.
//
@ -26,7 +31,7 @@ type NewSearcher interface {
// affect what is returned. Alternatively as we release new versions of
// Sourcegraph what is returned could change. This means we are not exactly
// safe across repeated calls.
NewSearch(ctx context.Context, q string) (SearchQuery, error)
NewSearch(ctx context.Context, userID int32, q string) (SearchQuery, error)
}
// SearchQuery represents a search in a way we can break up the work. The flow is
@ -236,7 +241,11 @@ func NewSearcherFake() NewSearcher {
// backendFake is a stateless fake whose NewSearch method builds a
// searcherFake from the query string instead of talking to a real search
// backend.
type backendFake struct{}
func (backendFake) NewSearch(ctx context.Context, q string) (SearchQuery, error) {
func (backendFake) NewSearch(ctx context.Context, userID int32, q string) (SearchQuery, error) {
if err := isSameUser(ctx, userID); err != nil {
return nil, err
}
var repoRevs []types.RepositoryRevision
for _, part := range strings.Fields(q) {
var r types.RepositoryRevision
@ -250,14 +259,22 @@ func (backendFake) NewSearch(ctx context.Context, q string) (SearchQuery, error)
if len(repoRevs) == 0 {
return nil, errors.Errorf("no repository revisions found in %q", q)
}
return searcherFake{repoRevs: repoRevs}, nil
return searcherFake{
userID: userID,
repoRevs: repoRevs,
}, nil
}
// searcherFake is the value returned by backendFake.NewSearch. Its methods
// call isSameUser with userID before doing any work.
type searcherFake struct {
// userID is the user the search was created for.
userID int32
// repoRevs are the repository revisions collected by backendFake.NewSearch.
repoRevs []types.RepositoryRevision
}
func (s searcherFake) RepositoryRevSpecs(context.Context) ([]types.RepositoryRevSpecs, error) {
func (s searcherFake) RepositoryRevSpecs(ctx context.Context) ([]types.RepositoryRevSpecs, error) {
if err := isSameUser(ctx, s.userID); err != nil {
return nil, err
}
seen := map[types.RepositoryRevSpecs]bool{}
var repoRevSpecs []types.RepositoryRevSpecs
for _, r := range s.repoRevs {
@ -270,7 +287,11 @@ func (s searcherFake) RepositoryRevSpecs(context.Context) ([]types.RepositoryRev
return repoRevSpecs, nil
}
func (s searcherFake) ResolveRepositoryRevSpec(_ context.Context, repoRevSpec types.RepositoryRevSpecs) ([]types.RepositoryRevision, error) {
func (s searcherFake) ResolveRepositoryRevSpec(ctx context.Context, repoRevSpec types.RepositoryRevSpecs) ([]types.RepositoryRevision, error) {
if err := isSameUser(ctx, s.userID); err != nil {
return nil, err
}
var repoRevs []types.RepositoryRevision
for _, r := range s.repoRevs {
if r.RepositoryRevSpecs == repoRevSpec {
@ -280,9 +301,24 @@ func (s searcherFake) ResolveRepositoryRevSpec(_ context.Context, repoRevSpec ty
return repoRevs, nil
}
func (s searcherFake) Search(_ context.Context, r types.RepositoryRevision, w CSVWriter) error {
// Search writes a CSV header followed by a single row describing r (its
// repository ID, revision specifiers and revision) to w. It returns an error
// without writing anything if ctx is not running as the user the search was
// created for.
func (s searcherFake) Search(ctx context.Context, r types.RepositoryRevision, w CSVWriter) error {
	err := isSameUser(ctx, s.userID)
	if err != nil {
		return err
	}

	err = w.WriteHeader("repo", "revspec", "revision")
	if err != nil {
		return err
	}

	repoID := strconv.Itoa(int(r.Repository))
	revSpec := string(r.RevisionSpecifiers)
	revision := string(r.Revision)
	return w.WriteRow(repoID, revSpec, revision)
}
// isSameUser reports whether the actor attached to ctx is the user identified
// by userID. It returns an error when userID is zero (no authenticated user)
// or when the actor in ctx is missing or has a different UID; otherwise it
// returns nil.
func isSameUser(ctx context.Context, userID int32) error {
	// A zero userID means the caller did not supply an authenticated user.
	if userID == 0 {
		return errors.New("exhaustive search must be done on behalf of an authenticated user")
	}

	if current := actor.FromContext(ctx); current != nil && current.UID == userID {
		return nil
	}
	return errors.Errorf("exhaustive search must be run as user %d", userID)
}

View File

@ -12,6 +12,8 @@ import (
"github.com/stretchr/testify/require"
"golang.org/x/exp/slices"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/search/client"
"github.com/sourcegraph/sourcegraph/internal/search/exhaustive/types"
"github.com/sourcegraph/sourcegraph/internal/uploadstore/mocks"
)
@ -39,7 +41,10 @@ type newSearcherTestCase struct {
func testNewSearcher(t *testing.T, ctx context.Context, newSearcher NewSearcher, tc newSearcherTestCase) {
assert := require.New(t)
searcher, err := newSearcher.NewSearch(ctx, tc.Query)
userID := int32(1)
ctx = actor.WithActor(ctx, actor.FromMockUser(userID))
searcher, err := newSearcher.NewSearch(ctx, userID, tc.Query)
assert.NoError(err)
// Test RepositoryRevSpecs
@ -65,6 +70,19 @@ func testNewSearcher(t *testing.T, ctx context.Context, newSearcher NewSearcher,
assert.Equal(tc.WantCSV, csv.buf.String())
}
func TestWrongUser(t *testing.T) {
assert := require.New(t)
userID1 := int32(1)
userID2 := int32(2)
ctx := actor.WithActor(context.Background(), actor.FromMockUser(userID1))
newSearcher := FromSearchClient(client.NewStrictMockSearchClient())
_, err := newSearcher.NewSearch(ctx, userID2, "foo")
assert.Error(err)
}
func joinStringer[T fmt.Stringer](xs []T) string {
var parts []string
for _, x := range xs {

View File

@ -20,9 +20,10 @@ import (
)
func FromSearchClient(client client.SearchClient) NewSearcher {
return newSearcherFunc(func(ctx context.Context, q string) (SearchQuery, error) {
// TODO adjust NewSearch API to enforce the user passing in a user id.
// IE do not rely on ctx actor since that could easily lead to a bug.
return newSearcherFunc(func(ctx context.Context, userID int32, q string) (SearchQuery, error) {
if err := isSameUser(ctx, userID); err != nil {
return nil, err
}
// TODO this hack is an ugly workaround to get the plan and jobs to
// get into a shape we like. it will break in bad ways but works for
@ -50,6 +51,7 @@ func FromSearchClient(client client.SearchClient) NewSearcher {
}
return searchQuery{
userID: userID,
exhaustive: exhaustive,
clients: client.JobClients(),
}, nil
@ -57,13 +59,14 @@ func FromSearchClient(client client.SearchClient) NewSearcher {
}
// TODO maybe reuse for the fake
type newSearcherFunc func(context.Context, string) (SearchQuery, error)
// newSearcherFunc is a function type with the same parameters and results as
// NewSearcher.NewSearch: a context, the user ID to search as, and the query
// string.
type newSearcherFunc func(context.Context, int32, string) (SearchQuery, error)
func (f newSearcherFunc) NewSearch(ctx context.Context, q string) (SearchQuery, error) {
return f(ctx, q)
// NewSearch calls f with the given context, user ID and query, and returns
// its result unchanged.
func (f newSearcherFunc) NewSearch(ctx context.Context, userID int32, q string) (SearchQuery, error) {
return f(ctx, userID, q)
}
// searchQuery is the value returned by the NewSearcher that FromSearchClient
// builds. Its methods call isSameUser with userID before doing any work.
type searchQuery struct {
// userID is the user ID that was passed to NewSearch.
userID int32
// exhaustive is the search job; RepositoryRevSpecs iterates over it.
exhaustive jobutil.Exhaustive
// clients are the runtime clients obtained from the search client's
// JobClients method; they are passed to exhaustive when it runs.
clients job.RuntimeClients
}
@ -71,6 +74,10 @@ type searchQuery struct {
// TODO make this an iterator return since the result could be large and the
// underlying infra already relies on iterators
func (s searchQuery) RepositoryRevSpecs(ctx context.Context) ([]types.RepositoryRevSpecs, error) {
if err := isSameUser(ctx, s.userID); err != nil {
return nil, err
}
var repoRevSpecs []types.RepositoryRevSpecs
it := s.exhaustive.RepositoryRevSpecs(ctx, s.clients)
for it.Next() {
@ -102,6 +109,10 @@ func (s searchQuery) RepositoryRevSpecs(ctx context.Context) ([]types.Repository
}
func (s searchQuery) ResolveRepositoryRevSpec(ctx context.Context, repoRevSpec types.RepositoryRevSpecs) ([]types.RepositoryRevision, error) {
if err := isSameUser(ctx, s.userID); err != nil {
return nil, err
}
repoPagerRepoRevSpec, err := s.toRepoRevSpecs(ctx, repoRevSpec)
if err != nil {
return nil, err
@ -152,6 +163,10 @@ func (s searchQuery) toRepoRevSpecs(ctx context.Context, repoRevSpec types.Repos
}
func (s searchQuery) Search(ctx context.Context, repoRev types.RepositoryRevision, w CSVWriter) error {
if err := isSameUser(ctx, s.userID); err != nil {
return err
}
repo, err := s.minimalRepo(ctx, repoRev.Repository)
if err != nil {
return err

View File

@ -89,8 +89,8 @@ VALUES (%s, %s)
RETURNING id
`
const getRepoRevSpecFmtStr = `
SELECT sj.id, sj.query, srj.repo_id, srj.ref_spec
const getQueryRepoRevFmtStr = `
SELECT sj.id, sj.initiator_id, sj.query, srj.repo_id, srj.ref_spec
FROM exhaustive_search_repo_jobs srj
JOIN exhaustive_search_jobs sj ON srj.search_job_id = sj.id
WHERE srj.id = %s
@ -100,15 +100,16 @@ func (s *Store) GetQueryRepoRev(ctx context.Context, job *types.ExhaustiveSearch
id int64,
query string,
repoRev types.RepositoryRevision,
initiatorID int32,
err error,
) {
row := s.QueryRow(ctx, sqlf.Sprintf(getRepoRevSpecFmtStr, job.SearchRepoJobID))
err = row.Scan(&id, &query, &repoRev.Repository, &repoRev.RevisionSpecifiers)
row := s.QueryRow(ctx, sqlf.Sprintf(getQueryRepoRevFmtStr, job.SearchRepoJobID))
err = row.Scan(&id, &initiatorID, &query, &repoRev.Repository, &repoRev.RevisionSpecifiers)
if err != nil {
return 0, "", types.RepositoryRevision{}, err
return 0, "", types.RepositoryRevision{}, -1, err
}
repoRev.Revision = job.Revision
return id, query, repoRev, nil
return id, query, repoRev, initiatorID, nil
}
func scanRevSearchJob(sc dbutil.Scanner) (*types.ExhaustiveSearchRepoRevisionJob, error) {