sentinel: Add store methods to insert and query vulnerability matches (#48486)

This commit is contained in:
Eric Fritz 2023-03-02 12:50:29 -06:00 committed by GitHub
parent c088168d97
commit 314c3b0128
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 582 additions and 1 deletions

View File

@ -0,0 +1,271 @@
package store
import (
"context"
"sort"
"strings"
"github.com/hashicorp/go-version"
"github.com/keegancsmith/sqlf"
"github.com/lib/pq"
"github.com/sourcegraph/sourcegraph/enterprise/internal/codeintel/sentinel/shared"
"github.com/sourcegraph/sourcegraph/internal/database/basestore"
"github.com/sourcegraph/sourcegraph/internal/database/batch"
"github.com/sourcegraph/sourcegraph/internal/database/dbutil"
"github.com/sourcegraph/sourcegraph/internal/observation"
)
func (s *store) VulnerabilityMatchByID(ctx context.Context, id int) (_ shared.VulnerabilityMatch, _ bool, err error) {
ctx, _, endObservation := s.operations.vulnerabilityMatchByID.With(ctx, &err, observation.Args{})
defer endObservation(1, observation.Args{})
matches, _, err := scanVulnerabilityMatchesAndCount(s.db.Query(ctx, sqlf.Sprintf(vulnerabilityMatchByIDQuery, id)))
if err != nil || len(matches) == 0 {
return shared.VulnerabilityMatch{}, false, err
}
return matches[0], true, nil
}
const vulnerabilityMatchByIDQuery = `
SELECT
m.id,
m.upload_id,
vap.vulnerability_id,
` + vulnerabilityAffectedPackageFields + `,
` + vulnerabilityAffectedSymbolFields + `,
0 AS count
FROM vulnerability_matches m
LEFT JOIN vulnerability_affected_packages vap ON vap.id = m.vulnerability_affected_package_id
LEFT JOIN vulnerability_affected_symbols vas ON vas.vulnerability_affected_package_id = vap.id
WHERE m.id = %s
`
func (s *store) GetVulnerabilityMatches(ctx context.Context, args shared.GetVulnerabilityMatchesArgs) (_ []shared.VulnerabilityMatch, _ int, err error) {
ctx, _, endObservation := s.operations.getVulnerabilityMatches.With(ctx, &err, observation.Args{})
defer endObservation(1, observation.Args{})
return scanVulnerabilityMatchesAndCount(s.db.Query(ctx, sqlf.Sprintf(getVulnerabilityMatchesQuery, args.Limit, args.Offset)))
}
const getVulnerabilityMatchesQuery = `
WITH limited_matches AS (
SELECT
m.id,
m.upload_id,
m.vulnerability_affected_package_id,
COUNT(*) OVER() AS count
FROM vulnerability_matches m
ORDER BY id
LIMIT %s OFFSET %s
)
SELECT
m.id,
m.upload_id,
vap.vulnerability_id,
` + vulnerabilityAffectedPackageFields + `,
` + vulnerabilityAffectedSymbolFields + `,
m.count
FROM limited_matches m
LEFT JOIN vulnerability_affected_packages vap ON vap.id = m.vulnerability_affected_package_id
LEFT JOIN vulnerability_affected_symbols vas ON vas.vulnerability_affected_package_id = vap.id
ORDER BY m.id, vap.id, vas.id
`
var flattenMatches = func(ms []shared.VulnerabilityMatch) []shared.VulnerabilityMatch {
flattened := []shared.VulnerabilityMatch{}
for _, m := range ms {
i := len(flattened) - 1
if len(flattened) == 0 || flattened[i].ID != m.ID {
flattened = append(flattened, m)
} else {
if flattened[i].AffectedPackage.PackageName == "" {
flattened[i].AffectedPackage = m.AffectedPackage
} else {
symbols := flattened[i].AffectedPackage.AffectedSymbols
symbols = append(symbols, m.AffectedPackage.AffectedSymbols...)
flattened[i].AffectedPackage.AffectedSymbols = symbols
}
}
}
return flattened
}
var scanVulnerabilityMatchesAndCount = func(rows basestore.Rows, queryErr error) ([]shared.VulnerabilityMatch, int, error) {
matches, totalCount, err := basestore.NewSliceWithCountScanner(func(s dbutil.Scanner) (match shared.VulnerabilityMatch, count int, _ error) {
var (
vap shared.AffectedPackage
vas shared.AffectedSymbol
fixedIn string
)
if err := s.Scan(
&match.ID,
&match.UploadID,
&match.VulnerabilityID,
// RHS(s) of left join (may be null)
&dbutil.NullString{S: &vap.PackageName},
&dbutil.NullString{S: &vap.Language},
&dbutil.NullString{S: &vap.Namespace},
pq.Array(&vap.VersionConstraint),
&dbutil.NullBool{B: &vap.Fixed},
&dbutil.NullString{S: &fixedIn},
&dbutil.NullString{S: &vas.Path},
pq.Array(vas.Symbols),
&count,
); err != nil {
return shared.VulnerabilityMatch{}, 0, err
}
if fixedIn != "" {
vap.FixedIn = &fixedIn
}
if vas.Path != "" {
vap.AffectedSymbols = append(vap.AffectedSymbols, vas)
}
if vap.PackageName != "" {
match.AffectedPackage = vap
}
return match, count, nil
})(rows, queryErr)
if err != nil {
return nil, 0, err
}
return flattenMatches(matches), totalCount, nil
}
func (s *store) ScanMatches(ctx context.Context) (err error) {
ctx, _, endObservation := s.operations.scanMatches.With(ctx, &err, observation.Args{})
defer endObservation(1, observation.Args{})
tx, err := s.db.Transact(ctx)
if err != nil {
return err
}
defer func() { err = tx.Done(err) }()
scipSchemeToVulnerabilityLanguage := map[string]string{
"gomod": "go",
"npm": "Javascript",
// TODO - java mapping
}
schemes := make([]string, 0, len(scipSchemeToVulnerabilityLanguage))
for scheme := range scipSchemeToVulnerabilityLanguage {
schemes = append(schemes, scheme)
}
sort.Strings(schemes)
mappings := make([]*sqlf.Query, 0, len(schemes))
for _, scheme := range schemes {
mappings = append(mappings, sqlf.Sprintf("(r.scheme = %s AND vap.language = %s)", scheme, scipSchemeToVulnerabilityLanguage[scheme]))
}
matches, err := scanFilteredVulnerabilityMatches(tx.Query(ctx, sqlf.Sprintf(
scanMatchesQuery,
sqlf.Join(mappings, " OR "),
)))
if err != nil {
return err
}
if err := tx.Exec(ctx, sqlf.Sprintf(scanMatchesTemporaryTableQuery)); err != nil {
return err
}
if err := batch.WithInserter(
ctx,
tx.Handle(),
"t_vulnerability_affected_packages",
batch.MaxNumPostgresParameters,
[]string{
"upload_id",
"vulnerability_affected_package_id",
},
func(inserter *batch.Inserter) error {
for _, match := range matches {
if err := inserter.Insert(
ctx,
match.UploadID,
match.VulnerabilityAffectedPackageID,
); err != nil {
return err
}
}
return nil
},
); err != nil {
return err
}
if err := tx.Exec(ctx, sqlf.Sprintf(scanMatchesUpdateQuery)); err != nil {
return err
}
return nil
}
const scanMatchesQuery = `
SELECT
r.dump_id,
vap.id,
r.version,
vap.version_constraint
FROM vulnerability_affected_packages vap
-- TODO - do we need the inverse? need to refine? the resulting match?
JOIN lsif_references r ON r.name LIKE '%%' || vap.package_name || '%%'
WHERE %s
`
const scanMatchesTemporaryTableQuery = `
CREATE TEMPORARY TABLE t_vulnerability_affected_packages (
upload_id INT NOT NULL,
vulnerability_affected_package_id INT NOT NULL
) ON COMMIT DROP
`
const scanMatchesUpdateQuery = `
INSERT INTO vulnerability_matches (upload_id, vulnerability_affected_package_id)
SELECT upload_id, vulnerability_affected_package_id FROM t_vulnerability_affected_packages
ON CONFLICT DO NOTHING
`
type VulnerabilityMatch struct {
UploadID int
VulnerabilityAffectedPackageID int
}
var scanFilteredVulnerabilityMatches = basestore.NewFilteredSliceScanner(func(s dbutil.Scanner) (m VulnerabilityMatch, _ bool, _ error) {
var (
version string
versionConstraints []string
)
if err := s.Scan(&m.UploadID, &m.VulnerabilityAffectedPackageID, &version, pq.Array(&versionConstraints)); err != nil {
return VulnerabilityMatch{}, false, err
}
matches, valid := versionMatchesConstraints(version, versionConstraints)
_ = valid // TODO - log un-parseable versions
return m, matches, nil
})
func versionMatchesConstraints(versionString string, constraints []string) (matches, valid bool) {
v, err := version.NewVersion(versionString)
if err != nil {
return false, false
}
constraint, err := version.NewConstraint(strings.Join(constraints, ","))
if err != nil {
return false, false
}
return constraint.Check(v), true
}

View File

@ -0,0 +1,279 @@
package store
import (
"context"
"fmt"
"math"
"strings"
"testing"
"time"
"github.com/google/go-cmp/cmp"
"github.com/keegancsmith/sqlf"
"github.com/lib/pq"
"github.com/sourcegraph/log/logtest"
"github.com/sourcegraph/sourcegraph/enterprise/internal/codeintel/sentinel/shared"
"github.com/sourcegraph/sourcegraph/enterprise/internal/codeintel/shared/types"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/database/basestore"
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
"github.com/sourcegraph/sourcegraph/internal/observation"
)
func TestVulnerabilityMatchByID(t *testing.T) {
ctx := context.Background()
logger := logtest.Scoped(t)
db := database.NewDB(logger, dbtest.NewDB(logger, t))
store := New(&observation.TestContext, db)
setupReferences(t, db)
if err := store.InsertVulnerabilities(ctx, testVulnerabilities); err != nil {
t.Fatalf("unexpected error inserting vulnerabilities: %s", err)
}
if err := store.ScanMatches(ctx); err != nil {
t.Fatalf("unexpected error inserting vulnerabilities: %s", err)
}
match, ok, err := store.VulnerabilityMatchByID(ctx, 3)
if err != nil {
t.Fatalf("unexpected error getting vulnerability match: %s", err)
}
if !ok {
t.Fatalf("expected match to exist")
}
expectedMatch := shared.VulnerabilityMatch{
ID: 3,
UploadID: 52,
VulnerabilityID: 1,
AffectedPackage: badConfig,
}
if diff := cmp.Diff(expectedMatch, match); diff != "" {
t.Errorf("unexpected vulnerability match (-want +got):\n%s", diff)
}
}
func TestGetVulnerabilityMatches(t *testing.T) {
ctx := context.Background()
logger := logtest.Scoped(t)
db := database.NewDB(logger, dbtest.NewDB(logger, t))
store := New(&observation.TestContext, db)
setupReferences(t, db)
if err := store.InsertVulnerabilities(ctx, testVulnerabilities); err != nil {
t.Fatalf("unexpected error inserting vulnerabilities: %s", err)
}
if err := store.ScanMatches(ctx); err != nil {
t.Fatalf("unexpected error inserting vulnerabilities: %s", err)
}
type testCase struct {
name string
expectedMatches []shared.VulnerabilityMatch
}
testCases := []testCase{
{
name: "all",
expectedMatches: []shared.VulnerabilityMatch{
{
ID: 1,
UploadID: 50,
VulnerabilityID: 1,
AffectedPackage: badConfig,
}, {
ID: 2,
UploadID: 51,
VulnerabilityID: 1,
AffectedPackage: badConfig,
}, {
ID: 3,
UploadID: 52,
VulnerabilityID: 1,
AffectedPackage: badConfig,
},
},
},
}
runTest := func(testCase testCase, lo, hi int) (errors int) {
t.Run(testCase.name, func(t *testing.T) {
matches, totalCount, err := store.GetVulnerabilityMatches(ctx, shared.GetVulnerabilityMatchesArgs{
Limit: 3,
Offset: lo,
})
if err != nil {
t.Fatalf("unexpected error getting vulnerability matches: %s", err)
}
if totalCount != len(testCase.expectedMatches) {
t.Errorf("unexpected total count. want=%d have=%d", len(testCase.expectedMatches), totalCount)
}
if totalCount != 0 {
if diff := cmp.Diff(testCase.expectedMatches[lo:hi], matches); diff != "" {
t.Errorf("unexpected vulnerability matches at offset %d-%d (-want +got):\n%s", lo, hi, diff)
errors++
}
}
})
return
}
for _, testCase := range testCases {
if n := len(testCase.expectedMatches); n == 0 {
runTest(testCase, 0, 0)
} else {
for lo := 0; lo < n; lo++ {
if numErrors := runTest(testCase, lo, int(math.Min(float64(lo)+3, float64(n)))); numErrors > 0 {
break
}
}
}
}
}
func setupReferences(t *testing.T, db database.DB) {
store := basestore.NewWithHandle(db.Handle())
insertUploads(t, db,
types.Upload{ID: 50},
types.Upload{ID: 51},
types.Upload{ID: 52},
types.Upload{ID: 53},
types.Upload{ID: 54},
types.Upload{ID: 55},
)
if err := store.Exec(context.Background(), sqlf.Sprintf(`
INSERT INTO lsif_references (scheme, name, version, dump_id)
VALUES
('gomod', 'github.com/go-nacelle/config', 'v1.2.3', 50),
('gomod', 'github.com/go-nacelle/config', 'v1.2.4', 51),
('gomod', 'github.com/go-nacelle/config', 'v1.2.5', 52),
('gomod', 'github.com/go-nacelle/config', 'v1.2.6', 53)
`)); err != nil {
t.Fatalf("failed to insert references: %s", err)
}
}
// insertUploads populates the lsif_uploads table with the given upload models.
func insertUploads(t testing.TB, db database.DB, uploads ...types.Upload) {
for _, upload := range uploads {
if upload.Commit == "" {
upload.Commit = makeCommit(upload.ID)
}
if upload.State == "" {
upload.State = "completed"
}
if upload.RepositoryID == 0 {
upload.RepositoryID = 50
}
if upload.Indexer == "" {
upload.Indexer = "lsif-go"
}
if upload.IndexerVersion == "" {
upload.IndexerVersion = "latest"
}
if upload.UploadedParts == nil {
upload.UploadedParts = []int{}
}
// Ensure we have a repo for the inner join in select queries
insertRepo(t, db, upload.RepositoryID, upload.RepositoryName)
query := sqlf.Sprintf(`
INSERT INTO lsif_uploads (
id,
commit,
root,
uploaded_at,
state,
failure_message,
started_at,
finished_at,
process_after,
num_resets,
num_failures,
repository_id,
indexer,
indexer_version,
num_parts,
uploaded_parts,
upload_size,
associated_index_id,
content_type,
should_reindex
) VALUES (%s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s, %s)
`,
upload.ID,
upload.Commit,
upload.Root,
upload.UploadedAt,
upload.State,
upload.FailureMessage,
upload.StartedAt,
upload.FinishedAt,
upload.ProcessAfter,
upload.NumResets,
upload.NumFailures,
upload.RepositoryID,
upload.Indexer,
upload.IndexerVersion,
upload.NumParts,
pq.Array(upload.UploadedParts),
upload.UploadSize,
upload.AssociatedIndexID,
upload.ContentType,
upload.ShouldReindex,
)
if _, err := db.ExecContext(context.Background(), query.Query(sqlf.PostgresBindVar), query.Args()...); err != nil {
t.Fatalf("unexpected error while inserting upload: %s", err)
}
}
}
// makeCommit formats an integer as a 40-character git commit hash.
func makeCommit(i int) string {
return fmt.Sprintf("%040d", i)
}
// insertRepo creates a repository record with the given id and name. If there is already a repository
// with the given identifier, nothing happens
func insertRepo(t testing.TB, db database.DB, id int, name string) {
if name == "" {
name = fmt.Sprintf("n-%d", id)
}
deletedAt := sqlf.Sprintf("NULL")
if strings.HasPrefix(name, "DELETED-") {
deletedAt = sqlf.Sprintf("%s", time.Unix(1587396557, 0).UTC())
}
insertRepoQuery := sqlf.Sprintf(
`INSERT INTO repo (id, name, deleted_at) VALUES (%s, %s, %s) ON CONFLICT (id) DO NOTHING`,
id,
name,
deletedAt,
)
if _, err := db.ExecContext(context.Background(), insertRepoQuery.Query(sqlf.PostgresBindVar), insertRepoQuery.Args()...); err != nil {
t.Fatalf("unexpected error while upserting repository: %s", err)
}
status := "cloned"
if strings.HasPrefix(name, "DELETED-") {
status = "not_cloned"
}
updateGitserverRepoQuery := sqlf.Sprintf(
`UPDATE gitserver_repos SET clone_status = %s WHERE repo_id = %s`,
status,
id,
)
if _, err := db.ExecContext(context.Background(), updateGitserverRepoQuery.Query(sqlf.PostgresBindVar), updateGitserverRepoQuery.Args()...); err != nil {
t.Fatalf("unexpected error while upserting gitserver repository: %s", err)
}
}

View File

@ -12,6 +12,9 @@ type operations struct {
getVulnerabilitiesByIDs *observation.Operation
getVulnerabilities *observation.Operation
insertVulnerabilities *observation.Operation
vulnerabilityMatchByID *observation.Operation
getVulnerabilityMatches *observation.Operation
scanMatches *observation.Operation
}
var m = new(metrics.SingletonREDMetrics)
@ -39,5 +42,8 @@ func newOperations(observationCtx *observation.Context) *operations {
getVulnerabilitiesByIDs: op("GetVulnerabilitiesByIDs"),
getVulnerabilities: op("GetVulnerabilities"),
insertVulnerabilities: op("InsertVulnerabilities"),
vulnerabilityMatchByID: op("VulnerabilityMatchByID"),
getVulnerabilityMatches: op("GetVulnerabilityMatches"),
scanMatches: op("ScanMatches"),
}
}

View File

@ -16,6 +16,10 @@ type Store interface {
GetVulnerabilitiesByIDs(ctx context.Context, ids ...int) (_ []shared.Vulnerability, err error)
GetVulnerabilities(ctx context.Context, args shared.GetVulnerabilitiesArgs) (_ []shared.Vulnerability, _ int, err error)
InsertVulnerabilities(ctx context.Context, vulnerabilities []shared.Vulnerability) (err error)
VulnerabilityMatchByID(ctx context.Context, id int) (shared.VulnerabilityMatch, bool, error)
GetVulnerabilityMatches(ctx context.Context, args shared.GetVulnerabilityMatchesArgs) ([]shared.VulnerabilityMatch, int, error)
ScanMatches(ctx context.Context) error
}
type store struct {

View File

@ -14,9 +14,15 @@ import (
"github.com/sourcegraph/sourcegraph/internal/observation"
)
var badConfig = shared.AffectedPackage{
Language: "go",
PackageName: "go-nacelle/config",
VersionConstraint: []string{"<= v1.2.5"},
}
var testVulnerabilities = []shared.Vulnerability{
// IDs assumed by insertion order
{ID: 1, SourceID: "CVE-ABC"},
{ID: 1, SourceID: "CVE-ABC", AffectedPackages: []shared.AffectedPackage{badConfig}},
{ID: 2, SourceID: "CVE-DEF"},
{ID: 3, SourceID: "CVE-GHI"},
{ID: 4, SourceID: "CVE-JKL"},

View File

@ -44,3 +44,15 @@ type AffectedSymbol struct {
Path string `json:"path"`
Symbols []string `json:"symbols"`
}
type GetVulnerabilityMatchesArgs struct {
Limit int
Offset int
}
type VulnerabilityMatch struct {
ID int
UploadID int
VulnerabilityID int
AffectedPackage AffectedPackage
}

1
go.mod
View File

@ -256,6 +256,7 @@ require (
github.com/golang-jwt/jwt v3.2.2+incompatible
github.com/google/go-github/v47 v47.1.0
github.com/grpc-ecosystem/go-grpc-middleware/providers/openmetrics/v2 v2.0.0-rc.3
github.com/hashicorp/go-version v1.6.0
github.com/hexops/autogold/v2 v2.1.0
github.com/k3a/html2text v1.1.0
github.com/opsgenie/opsgenie-go-sdk-v2 v1.2.13

2
go.sum
View File

@ -1338,6 +1338,8 @@ github.com/hashicorp/go-uuid v1.0.0/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/b
github.com/hashicorp/go-uuid v1.0.1/go.mod h1:6SBZvOh/SIDV7/2o3Jml5SYk/TvGqwFJ/bN7x4byOro=
github.com/hashicorp/go-version v1.1.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/go-version v1.2.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/go-version v1.6.0 h1:feTTfFNnjP967rlCxM/I9g701jU+RN74YKx2mOkIeek=
github.com/hashicorp/go-version v1.6.0/go.mod h1:fltr4n8CU8Ke44wwGCBoEymUuxUHl09ZGVZPK5anwXA=
github.com/hashicorp/go.net v0.0.1/go.mod h1:hjKkEWcCURg++eb33jQU7oqQcI9XDCnUzHA0oac0k90=
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
github.com/hashicorp/golang-lru v0.5.1/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=