diff --git a/internal/database/connections/live/migration_test.go b/internal/database/connections/live/migration_test.go index e162352cfba..e2c7094b048 100644 --- a/internal/database/connections/live/migration_test.go +++ b/internal/database/connections/live/migration_test.go @@ -53,6 +53,26 @@ func getSchema(name string) (*schemas.Schema, bool) { return nil, false } +// 🚨 SECURITY: These tables are NOT governed by Postgres RLS protection to isolate +// tenant data. +// This list should only ever contain tables that are system critical, and NOT tenant-specific. +var tablesWithoutTenant = map[string]map[string]struct{}{ + "frontend": { + "tenants": {}, // The tenant table itself, it cannot link to itself. + "migration_logs": {}, // Maintained by migrator and not part of Sourcegraph proper. + "versions": {}, // Maintained by migrator and not part of Sourcegraph proper. + "critical_and_site_config": {}, // Site config is global to the instance so it does not have a tenant. + }, + "codeintel": { + "tenants": {}, // The tenant table itself, it cannot link to itself. + "migration_logs": {}, // Maintained by migrator and not part of Sourcegraph proper. + }, + "codeinsights": { + "tenants": {}, // The tenant table itself, it cannot link to itself. + "migration_logs": {}, // Maintained by migrator and not part of Sourcegraph proper. + }, +} + func testMigrations(t *testing.T, name string, schema *schemas.Schema) { t.Helper() @@ -75,6 +95,27 @@ func testMigrations(t *testing.T, name string, schema *schemas.Schema) { if err := migrationRunner.Run(ctx, options); err != nil { t.Fatalf("failed to perform initial upgrade: %s", err) } + + t.Run("verify tenant isolation config", func(t *testing.T) { + // Get the list of all tables + tables, err := getAllTables(db) + if err != nil { + t.Fatalf("Failed to retrieve tables: %v", err) + } + + for _, table := range tables { + if _, ok := tablesWithoutTenant[name][table]; ok { + continue + } + hasTenantID, err := tableHasTenantIDColumn(db, table) + if err != nil { + t.Errorf("Failed to check tenant_id column for table %s: %v", table, err) + } + if !hasTenantID { + t.Errorf("Table %s does not have a tenant_id column. In the migration that adds it, make sure to include \n\ntenant_id integer REFERENCES tenants(id) ON UPDATE CASCADE ON DELETE CASCADE;\n\n", table) + } + } + }) }) t.Run("down", func(t *testing.T) { // Run down to the root "squashed commits" migration. For this, we need to select @@ -301,3 +342,52 @@ func applyMigration(db *sql.DB, definition definition.Definition, up bool) (err return nil } + +func getAllTables(db *sql.DB) ([]string, error) { + query := ` + SELECT table_name + FROM information_schema.tables + WHERE table_schema='public' + AND table_type='BASE TABLE' + ` + + rows, err := db.Query(query) + if err != nil { + return nil, err + } + defer rows.Close() + + var tables []string + for rows.Next() { + var table string + if err := rows.Scan(&table); err != nil { + return nil, err + } + tables = append(tables, table) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return tables, nil +} + +func tableHasTenantIDColumn(db *sql.DB, tableName string) (bool, error) { + q := sqlf.Sprintf(` + SELECT column_name + FROM information_schema.columns + WHERE table_name=%s AND column_name='tenant_id' + `, tableName) + + var columnName string + err := db.QueryRow(q.Query(sqlf.PostgresBindVar), q.Args()...).Scan(&columnName) + if err != nil { + if err == sql.ErrNoRows { + return false, nil + } + return false, err + } + + return columnName == "tenant_id", nil +}