mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 18:51:59 +00:00
search jobs: Enforce passing in a non-zero userID (#56825)
This adjusts the search API for exhaustive to enforce explicit passing in of a userID to ensure we always run searches as the correct user. This is to prevent accidently running as another user and thus exposing results from repos a user is not allowed to see. Test Plan: CI and new unit test Co-authored-by: Stefan Hengl <stefan@sourcegraph.com>
This commit is contained in:
parent
b0e8f22f5e
commit
ca04f6db28
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/goroutine"
|
||||
"github.com/sourcegraph/sourcegraph/internal/observation"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/exhaustive/service"
|
||||
@ -54,7 +55,10 @@ var _ workerutil.Handler[*types.ExhaustiveSearchJob] = &exhaustiveSearchHandler{
|
||||
func (h *exhaustiveSearchHandler) Handle(ctx context.Context, logger log.Logger, record *types.ExhaustiveSearchJob) (err error) {
|
||||
// TODO observability? read other handlers to see if we are missing stuff
|
||||
|
||||
q, err := h.newSearcher.NewSearch(ctx, record.Query)
|
||||
userID := record.InitiatorID
|
||||
ctx = actor.WithActor(ctx, actor.FromUser(userID))
|
||||
|
||||
q, err := h.newSearcher.NewSearch(ctx, userID, record.Query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/goroutine"
|
||||
"github.com/sourcegraph/sourcegraph/internal/observation"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/exhaustive/service"
|
||||
@ -62,7 +63,10 @@ func (h *exhaustiveSearchRepoHandler) Handle(ctx context.Context, logger log.Log
|
||||
return err
|
||||
}
|
||||
|
||||
q, err := h.newSearcher.NewSearch(ctx, parent.Query)
|
||||
userID := parent.InitiatorID
|
||||
ctx = actor.WithActor(ctx, actor.FromUser(userID))
|
||||
|
||||
q, err := h.newSearcher.NewSearch(ctx, userID, parent.Query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/goroutine"
|
||||
"github.com/sourcegraph/sourcegraph/internal/observation"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/exhaustive/service"
|
||||
@ -58,12 +59,14 @@ type exhaustiveSearchRepoRevHandler struct {
|
||||
var _ workerutil.Handler[*types.ExhaustiveSearchRepoRevisionJob] = &exhaustiveSearchRepoRevHandler{}
|
||||
|
||||
func (h *exhaustiveSearchRepoRevHandler) Handle(ctx context.Context, logger log.Logger, record *types.ExhaustiveSearchRepoRevisionJob) error {
|
||||
jobID, query, repoRev, err := h.store.GetQueryRepoRev(ctx, record)
|
||||
jobID, query, repoRev, initiatorID, err := h.store.GetQueryRepoRev(ctx, record)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
q, err := h.newSearcher.NewSearch(ctx, query)
|
||||
ctx = actor.WithActor(ctx, actor.FromUser(initiatorID))
|
||||
|
||||
q, err := h.newSearcher.NewSearch(ctx, initiatorID, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -43,6 +43,7 @@ go_test(
|
||||
],
|
||||
embed = [":service"],
|
||||
deps = [
|
||||
"//internal/actor",
|
||||
"//internal/api",
|
||||
"//internal/database",
|
||||
"//internal/database/dbmocks",
|
||||
|
||||
@ -8,6 +8,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/exhaustive/types"
|
||||
"github.com/sourcegraph/sourcegraph/internal/uploadstore"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
@ -19,6 +20,10 @@ type NewSearcher interface {
|
||||
// that calling this again in the future should return the same Searcher. IE
|
||||
// it can speak to the DB, but maybe not gitserver.
|
||||
//
|
||||
// userID is explicitly passed in and must match the actor for ctx. This
|
||||
// is done to prevent accidental bugs where we do a search on behalf of a
|
||||
// user as an internal user/etc.
|
||||
//
|
||||
// I expect this to be roughly equivalent to creation of a search plan in
|
||||
// our search codes job creator.
|
||||
//
|
||||
@ -26,7 +31,7 @@ type NewSearcher interface {
|
||||
// affect what is returned. Alternatively as we release new versions of
|
||||
// Sourcegraph what is returned could change. This means we are not exactly
|
||||
// safe across repeated calls.
|
||||
NewSearch(ctx context.Context, q string) (SearchQuery, error)
|
||||
NewSearch(ctx context.Context, userID int32, q string) (SearchQuery, error)
|
||||
}
|
||||
|
||||
// SearchQuery represents a search in a way we can break up the work. The flow is
|
||||
@ -236,7 +241,11 @@ func NewSearcherFake() NewSearcher {
|
||||
|
||||
type backendFake struct{}
|
||||
|
||||
func (backendFake) NewSearch(ctx context.Context, q string) (SearchQuery, error) {
|
||||
func (backendFake) NewSearch(ctx context.Context, userID int32, q string) (SearchQuery, error) {
|
||||
if err := isSameUser(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var repoRevs []types.RepositoryRevision
|
||||
for _, part := range strings.Fields(q) {
|
||||
var r types.RepositoryRevision
|
||||
@ -250,14 +259,22 @@ func (backendFake) NewSearch(ctx context.Context, q string) (SearchQuery, error)
|
||||
if len(repoRevs) == 0 {
|
||||
return nil, errors.Errorf("no repository revisions found in %q", q)
|
||||
}
|
||||
return searcherFake{repoRevs: repoRevs}, nil
|
||||
return searcherFake{
|
||||
userID: userID,
|
||||
repoRevs: repoRevs,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type searcherFake struct {
|
||||
userID int32
|
||||
repoRevs []types.RepositoryRevision
|
||||
}
|
||||
|
||||
func (s searcherFake) RepositoryRevSpecs(context.Context) ([]types.RepositoryRevSpecs, error) {
|
||||
func (s searcherFake) RepositoryRevSpecs(ctx context.Context) ([]types.RepositoryRevSpecs, error) {
|
||||
if err := isSameUser(ctx, s.userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
seen := map[types.RepositoryRevSpecs]bool{}
|
||||
var repoRevSpecs []types.RepositoryRevSpecs
|
||||
for _, r := range s.repoRevs {
|
||||
@ -270,7 +287,11 @@ func (s searcherFake) RepositoryRevSpecs(context.Context) ([]types.RepositoryRev
|
||||
return repoRevSpecs, nil
|
||||
}
|
||||
|
||||
func (s searcherFake) ResolveRepositoryRevSpec(_ context.Context, repoRevSpec types.RepositoryRevSpecs) ([]types.RepositoryRevision, error) {
|
||||
func (s searcherFake) ResolveRepositoryRevSpec(ctx context.Context, repoRevSpec types.RepositoryRevSpecs) ([]types.RepositoryRevision, error) {
|
||||
if err := isSameUser(ctx, s.userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var repoRevs []types.RepositoryRevision
|
||||
for _, r := range s.repoRevs {
|
||||
if r.RepositoryRevSpecs == repoRevSpec {
|
||||
@ -280,9 +301,24 @@ func (s searcherFake) ResolveRepositoryRevSpec(_ context.Context, repoRevSpec ty
|
||||
return repoRevs, nil
|
||||
}
|
||||
|
||||
func (s searcherFake) Search(_ context.Context, r types.RepositoryRevision, w CSVWriter) error {
|
||||
func (s searcherFake) Search(ctx context.Context, r types.RepositoryRevision, w CSVWriter) error {
|
||||
if err := isSameUser(ctx, s.userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := w.WriteHeader("repo", "revspec", "revision"); err != nil {
|
||||
return err
|
||||
}
|
||||
return w.WriteRow(strconv.Itoa(int(r.Repository)), string(r.RevisionSpecifiers), string(r.Revision))
|
||||
}
|
||||
|
||||
func isSameUser(ctx context.Context, userID int32) error {
|
||||
if userID == 0 {
|
||||
return errors.New("exhaustive search must be done on behalf of an authenticated user")
|
||||
}
|
||||
a := actor.FromContext(ctx)
|
||||
if a == nil || a.UID != userID {
|
||||
return errors.Errorf("exhaustive search must be run as user %d", userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -12,6 +12,8 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/exp/slices"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/client"
|
||||
"github.com/sourcegraph/sourcegraph/internal/search/exhaustive/types"
|
||||
"github.com/sourcegraph/sourcegraph/internal/uploadstore/mocks"
|
||||
)
|
||||
@ -39,7 +41,10 @@ type newSearcherTestCase struct {
|
||||
func testNewSearcher(t *testing.T, ctx context.Context, newSearcher NewSearcher, tc newSearcherTestCase) {
|
||||
assert := require.New(t)
|
||||
|
||||
searcher, err := newSearcher.NewSearch(ctx, tc.Query)
|
||||
userID := int32(1)
|
||||
ctx = actor.WithActor(ctx, actor.FromMockUser(userID))
|
||||
|
||||
searcher, err := newSearcher.NewSearch(ctx, userID, tc.Query)
|
||||
assert.NoError(err)
|
||||
|
||||
// Test RepositoryRevSpecs
|
||||
@ -65,6 +70,19 @@ func testNewSearcher(t *testing.T, ctx context.Context, newSearcher NewSearcher,
|
||||
assert.Equal(tc.WantCSV, csv.buf.String())
|
||||
}
|
||||
|
||||
func TestWrongUser(t *testing.T) {
|
||||
assert := require.New(t)
|
||||
|
||||
userID1 := int32(1)
|
||||
userID2 := int32(2)
|
||||
|
||||
ctx := actor.WithActor(context.Background(), actor.FromMockUser(userID1))
|
||||
|
||||
newSearcher := FromSearchClient(client.NewStrictMockSearchClient())
|
||||
_, err := newSearcher.NewSearch(ctx, userID2, "foo")
|
||||
assert.Error(err)
|
||||
}
|
||||
|
||||
func joinStringer[T fmt.Stringer](xs []T) string {
|
||||
var parts []string
|
||||
for _, x := range xs {
|
||||
|
||||
@ -20,9 +20,10 @@ import (
|
||||
)
|
||||
|
||||
func FromSearchClient(client client.SearchClient) NewSearcher {
|
||||
return newSearcherFunc(func(ctx context.Context, q string) (SearchQuery, error) {
|
||||
// TODO adjust NewSearch API to enforce the user passing in a user id.
|
||||
// IE do not rely on ctx actor since that could easily lead to a bug.
|
||||
return newSearcherFunc(func(ctx context.Context, userID int32, q string) (SearchQuery, error) {
|
||||
if err := isSameUser(ctx, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// TODO this hack is an ugly workaround to get the plan and jobs to
|
||||
// get into a shape we like. it will break in bad ways but works for
|
||||
@ -50,6 +51,7 @@ func FromSearchClient(client client.SearchClient) NewSearcher {
|
||||
}
|
||||
|
||||
return searchQuery{
|
||||
userID: userID,
|
||||
exhaustive: exhaustive,
|
||||
clients: client.JobClients(),
|
||||
}, nil
|
||||
@ -57,13 +59,14 @@ func FromSearchClient(client client.SearchClient) NewSearcher {
|
||||
}
|
||||
|
||||
// TODO maybe reuse for the fake
|
||||
type newSearcherFunc func(context.Context, string) (SearchQuery, error)
|
||||
type newSearcherFunc func(context.Context, int32, string) (SearchQuery, error)
|
||||
|
||||
func (f newSearcherFunc) NewSearch(ctx context.Context, q string) (SearchQuery, error) {
|
||||
return f(ctx, q)
|
||||
func (f newSearcherFunc) NewSearch(ctx context.Context, userID int32, q string) (SearchQuery, error) {
|
||||
return f(ctx, userID, q)
|
||||
}
|
||||
|
||||
type searchQuery struct {
|
||||
userID int32
|
||||
exhaustive jobutil.Exhaustive
|
||||
clients job.RuntimeClients
|
||||
}
|
||||
@ -71,6 +74,10 @@ type searchQuery struct {
|
||||
// TODO make this an iterator return since the result could be large and the
|
||||
// underlying infra already relies on iterators
|
||||
func (s searchQuery) RepositoryRevSpecs(ctx context.Context) ([]types.RepositoryRevSpecs, error) {
|
||||
if err := isSameUser(ctx, s.userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var repoRevSpecs []types.RepositoryRevSpecs
|
||||
it := s.exhaustive.RepositoryRevSpecs(ctx, s.clients)
|
||||
for it.Next() {
|
||||
@ -102,6 +109,10 @@ func (s searchQuery) RepositoryRevSpecs(ctx context.Context) ([]types.Repository
|
||||
}
|
||||
|
||||
func (s searchQuery) ResolveRepositoryRevSpec(ctx context.Context, repoRevSpec types.RepositoryRevSpecs) ([]types.RepositoryRevision, error) {
|
||||
if err := isSameUser(ctx, s.userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
repoPagerRepoRevSpec, err := s.toRepoRevSpecs(ctx, repoRevSpec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -152,6 +163,10 @@ func (s searchQuery) toRepoRevSpecs(ctx context.Context, repoRevSpec types.Repos
|
||||
}
|
||||
|
||||
func (s searchQuery) Search(ctx context.Context, repoRev types.RepositoryRevision, w CSVWriter) error {
|
||||
if err := isSameUser(ctx, s.userID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
repo, err := s.minimalRepo(ctx, repoRev.Repository)
|
||||
if err != nil {
|
||||
return err
|
||||
|
||||
@ -89,8 +89,8 @@ VALUES (%s, %s)
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
const getRepoRevSpecFmtStr = `
|
||||
SELECT sj.id, sj.query, srj.repo_id, srj.ref_spec
|
||||
const getQueryRepoRevFmtStr = `
|
||||
SELECT sj.id, sj.initiator_id, sj.query, srj.repo_id, srj.ref_spec
|
||||
FROM exhaustive_search_repo_jobs srj
|
||||
JOIN exhaustive_search_jobs sj ON srj.search_job_id = sj.id
|
||||
WHERE srj.id = %s
|
||||
@ -100,15 +100,16 @@ func (s *Store) GetQueryRepoRev(ctx context.Context, job *types.ExhaustiveSearch
|
||||
id int64,
|
||||
query string,
|
||||
repoRev types.RepositoryRevision,
|
||||
initiatorID int32,
|
||||
err error,
|
||||
) {
|
||||
row := s.QueryRow(ctx, sqlf.Sprintf(getRepoRevSpecFmtStr, job.SearchRepoJobID))
|
||||
err = row.Scan(&id, &query, &repoRev.Repository, &repoRev.RevisionSpecifiers)
|
||||
row := s.QueryRow(ctx, sqlf.Sprintf(getQueryRepoRevFmtStr, job.SearchRepoJobID))
|
||||
err = row.Scan(&id, &initiatorID, &query, &repoRev.Repository, &repoRev.RevisionSpecifiers)
|
||||
if err != nil {
|
||||
return 0, "", types.RepositoryRevision{}, err
|
||||
return 0, "", types.RepositoryRevision{}, -1, err
|
||||
}
|
||||
repoRev.Revision = job.Revision
|
||||
return id, query, repoRev, nil
|
||||
return id, query, repoRev, initiatorID, nil
|
||||
}
|
||||
|
||||
func scanRevSearchJob(sc dbutil.Scanner) (*types.ExhaustiveSearchRepoRevisionJob, error) {
|
||||
|
||||
Loading…
Reference in New Issue
Block a user