mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 16:51:55 +00:00
tenant: Add test to verify that we don't regress on tables without tenant (#64368)
This test ensures that newly added tables also have the tenant_id column. We will later extend this test to also check that RLS policies exist, once we created them. Test plan: Test passes, when I modify the new migration that adds tenant_id everywhere to skip a table it fails with a nice error message.
This commit is contained in:
parent
1b1229c867
commit
0aeb6fd0a0
@ -53,6 +53,26 @@ func getSchema(name string) (*schemas.Schema, bool) {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
// 🚨 SECURITY: These tables are NOT governed by Postgres RLS protection to isolate
|
||||
// tenant data.
|
||||
// This list should only ever contain tables that are system critical, and NOT tenant-specific.
|
||||
var tablesWithoutTenant = map[string]map[string]struct{}{
|
||||
"frontend": {
|
||||
"tenants": {}, // The tenant table itself, it cannot link to itself.
|
||||
"migration_logs": {}, // Maintained by migrator and not part of Sourcegraph proper.
|
||||
"versions": {}, // Maintained by migrator and not part of Sourcegraph proper.
|
||||
"critical_and_site_config": {}, // Site config is global to the instance so it does not have a tenant.
|
||||
},
|
||||
"codeintel": {
|
||||
"tenants": {}, // The tenant table itself, it cannot link to itself.
|
||||
"migration_logs": {}, // Maintained by migrator and not part of Sourcegraph proper.
|
||||
},
|
||||
"codeinsights": {
|
||||
"tenants": {}, // The tenant table itself, it cannot link to itself.
|
||||
"migration_logs": {}, // Maintained by migrator and not part of Sourcegraph proper.
|
||||
},
|
||||
}
|
||||
|
||||
func testMigrations(t *testing.T, name string, schema *schemas.Schema) {
|
||||
t.Helper()
|
||||
|
||||
@ -75,6 +95,27 @@ func testMigrations(t *testing.T, name string, schema *schemas.Schema) {
|
||||
if err := migrationRunner.Run(ctx, options); err != nil {
|
||||
t.Fatalf("failed to perform initial upgrade: %s", err)
|
||||
}
|
||||
|
||||
t.Run("verify tenant isolation config", func(t *testing.T) {
|
||||
// Get the list of all tables
|
||||
tables, err := getAllTables(db)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to retrieve tables: %v", err)
|
||||
}
|
||||
|
||||
for _, table := range tables {
|
||||
if _, ok := tablesWithoutTenant[name][table]; ok {
|
||||
continue
|
||||
}
|
||||
hasTenantID, err := tableHasTenantIDColumn(db, table)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to check tenant_id column for table %s: %v", table, err)
|
||||
}
|
||||
if !hasTenantID {
|
||||
t.Errorf("Table %s does not have a tenant_id column. In the migration that adds it, make sure to include \n\ntenant_id integer REFERENCES tenants(id) ON UPDATE CASCADE ON DELETE CASCADE;\n\n", table)
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
t.Run("down", func(t *testing.T) {
|
||||
// Run down to the root "squashed commits" migration. For this, we need to select
|
||||
@ -301,3 +342,52 @@ func applyMigration(db *sql.DB, definition definition.Definition, up bool) (err
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getAllTables(db *sql.DB) ([]string, error) {
|
||||
query := `
|
||||
SELECT table_name
|
||||
FROM information_schema.tables
|
||||
WHERE table_schema='public'
|
||||
AND table_type='BASE TABLE'
|
||||
`
|
||||
|
||||
rows, err := db.Query(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var tables []string
|
||||
for rows.Next() {
|
||||
var table string
|
||||
if err := rows.Scan(&table); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tables = append(tables, table)
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return tables, nil
|
||||
}
|
||||
|
||||
func tableHasTenantIDColumn(db *sql.DB, tableName string) (bool, error) {
|
||||
q := sqlf.Sprintf(`
|
||||
SELECT column_name
|
||||
FROM information_schema.columns
|
||||
WHERE table_name=%s AND column_name='tenant_id'
|
||||
`, tableName)
|
||||
|
||||
var columnName string
|
||||
err := db.QueryRow(q.Query(sqlf.PostgresBindVar), q.Args()...).Scan(&columnName)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
return columnName == "tenant_id", nil
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user