drift: Polish output (#52030)

This commit is contained in:
Eric Fritz 2023-05-16 17:29:05 -05:00 committed by GitHub
parent fe8e70d94d
commit 2211519fdf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
30 changed files with 853 additions and 452 deletions

View File

@ -86,7 +86,7 @@ func Start(logger log.Logger, registerEnterpriseMigrators registerMigratorsUsing
cliutil.DownTo(appName, newRunner, outputFactory, false),
cliutil.Validate(appName, newRunner, outputFactory),
cliutil.Describe(appName, newRunner, outputFactory),
cliutil.Drift(appName, newRunner, outputFactory, DefaultSchemaFactories...),
cliutil.Drift(appName, newRunner, outputFactory, false, DefaultSchemaFactories...),
cliutil.AddLog(appName, newRunner, outputFactory),
cliutil.Upgrade(appName, newRunnerWithSchemas, outputFactory, registerMigrators, DefaultSchemaFactories...),
cliutil.Downgrade(appName, newRunnerWithSchemas, outputFactory, registerMigrators, DefaultSchemaFactories...),

View File

@ -125,7 +125,7 @@ var (
downToCommand = cliutil.DownTo("sg migration", makeRunner, outputFactory, true)
validateCommand = cliutil.Validate("sg migration", makeRunner, outputFactory)
describeCommand = cliutil.Describe("sg migration", makeRunner, outputFactory)
driftCommand = cliutil.Drift("sg migration", makeRunner, outputFactory, schemaFactories...)
driftCommand = cliutil.Drift("sg migration", makeRunner, outputFactory, true, schemaFactories...)
addLogCommand = cliutil.AddLog("sg migration", makeRunner, outputFactory)
leavesCommand = &cli.Command{

View File

@ -835,12 +835,13 @@ Available schemas:
Flags:
* `--auto-fix, --autofix`: Database goes brrrr.
* `--feedback`: provide feedback about this command by opening up a GitHub discussion
* `--file="<value>"`: The target schema description file.
* `--ignore-migrator-update`: Ignore the running migrator not being the latest version. It is recommended to use the latest migrator version.
* `--schema, --db="<value>"`: The target `schema` to compare. Possible values are 'frontend', 'codeintel' and 'codeinsights'
* `--skip-version-check`: Skip validation of the instance's current version.
* `--version="<value>"`: The target schema version. Can be a version (e.g. 5.0.2) or resolvable as a git revlike on the Sourcegraph repository (e.g. a branch, tag or commit hash).
* `--version="<value>"`: The target schema version. Can be a version (e.g. 5.0.2) or resolvable as a git revlike on the Sourcegraph repository (e.g. a branch, tag or commit hash). (default: HEAD)
### sg migration add-log

View File

@ -44,6 +44,10 @@ func (s *memoryStore) Versions(ctx context.Context) (appliedVersions, pendingVer
return s.appliedVersions, s.pendingVersions, s.failedVersions, nil
}
func (s *memoryStore) RunDDLStatements(ctx context.Context, statements []string) error {
return nil
}
func (s *memoryStore) TryLock(ctx context.Context) (bool, func(err error) error, error) {
return true, func(err error) error { return err }, nil
}

View File

@ -9,13 +9,21 @@ import (
"cuelang.org/go/pkg/strings"
"github.com/urfave/cli/v2"
"github.com/sourcegraph/sourcegraph/internal/database/migration/drift"
descriptions "github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
"github.com/sourcegraph/sourcegraph/internal/oobmigration"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/output"
)
func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory, expectedSchemaFactories ...ExpectedSchemaFactory) *cli.Command {
const maxAutofixAttempts = 3
func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory, development bool, expectedSchemaFactories ...ExpectedSchemaFactory) *cli.Command {
defaultVersion := ""
if development {
defaultVersion = "HEAD"
}
schemaNameFlag := &cli.StringFlag{
Name: "schema",
Usage: "The target `schema` to compare. Possible values are 'frontend', 'codeintel' and 'codeinsights'",
@ -27,6 +35,7 @@ func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory,
Usage: "The target schema version. Can be a version (e.g. 5.0.2) or resolvable as a git revlike on the Sourcegraph repository " +
"(e.g. a branch, tag or commit hash).",
Required: false,
Value: defaultVersion,
}
fileFlag := &cli.StringFlag{
Name: "file",
@ -37,12 +46,20 @@ func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory,
Name: "skip-version-check",
Usage: "Skip validation of the instance's current version.",
Required: false,
Value: development,
}
ignoreMigratorUpdateCheckFlag := &cli.BoolFlag{
Name: "ignore-migrator-update",
Usage: "Ignore the running migrator not being the latest version. It is recommended to use the latest migrator version.",
Required: false,
}
// Only in available via `sg migration`` in development mode
autofixFlag := &cli.BoolFlag{
Name: "auto-fix",
Usage: "Database goes brrrr.",
Required: false,
Aliases: []string{"autofix"},
}
action := makeAction(outFactory, func(ctx context.Context, cmd *cli.Context, out *output.Output) error {
airgapped := isAirgapped(ctx)
@ -115,6 +132,7 @@ func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory,
NewExplicitFileSchemaFactory(file),
}
}
expectedSchema, err := fetchExpectedSchema(ctx, schemaName, version, out, expectedSchemaFactories)
if err != nil {
return err
@ -125,22 +143,59 @@ func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory,
return err
}
schema := schemas["public"]
summaries := drift.CompareSchemaDescriptions(schemaName, version, canonicalize(schema), canonicalize(expectedSchema))
return compareAndDisplaySchemaDescriptions(out, schemaName, version, canonicalize(schema), canonicalize(expectedSchema))
var autofixErr error
if autofixFlag.Get(cmd) {
for attempts := maxAutofixAttempts; attempts > 0 && len(summaries) > 0 && autofixErr == nil; attempts-- {
allStatements := []string{}
for _, summary := range summaries {
if statements, ok := summary.Statements(); ok {
allStatements = append(allStatements, statements...)
}
}
if len(allStatements) == 0 {
out.WriteLine(output.Linef(output.EmojiInfo, output.StyleReset, "No autofix to apply"))
break
}
autofixErr = store.RunDDLStatements(ctx, allStatements)
if autofixErr != nil {
out.WriteLine(output.Linef(output.EmojiFailure, output.StyleFailure, "Failed to apply autofix: %s", err))
} else {
out.WriteLine(output.Linef(output.EmojiSuccess, output.StyleSuccess, "Successfully applied autofix"))
out.WriteLine(output.Linef(output.EmojiInfo, output.StyleReset, "Re-checking drift"))
}
schemas, err := store.Describe(ctx)
if err != nil {
return err
}
schema := schemas["public"]
summaries = drift.CompareSchemaDescriptions(schemaName, version, canonicalize(schema), canonicalize(expectedSchema))
}
}
return displayDriftSummaries(out, summaries)
})
flags := []cli.Flag{
schemaNameFlag,
versionFlag,
fileFlag,
skipVersionCheckFlag,
ignoreMigratorUpdateCheckFlag,
}
if development {
flags = append(flags, autofixFlag)
}
return &cli.Command{
Name: "drift",
Usage: "Detect differences between the current database schema and the expected schema",
Description: ConstructLongHelp(),
Action: action,
Flags: []cli.Flag{
schemaNameFlag,
versionFlag,
fileFlag,
skipVersionCheckFlag,
ignoreMigratorUpdateCheckFlag,
},
Flags: flags,
}
}

View File

@ -7,17 +7,15 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/sourcegraph/sourcegraph/internal/database/migration/drift"
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/output"
)
var errOutOfSync = errors.Newf("database schema is out of sync")
func compareAndDisplaySchemaDescriptions(rawOut *output.Output, schemaName, version string, actual, expected schemas.SchemaDescription) (err error) {
func displayDriftSummaries(rawOut *output.Output, summaries []drift.Summary) (err error) {
out := &preambledOutput{rawOut, false}
for _, summary := range drift.CompareSchemaDescriptions(schemaName, version, actual, expected) {
for _, summary := range summaries {
displaySummary(out, summary)
err = errOutOfSync
}

View File

@ -19,6 +19,7 @@ type Store interface {
WithMigrationLog(ctx context.Context, definition definition.Definition, up bool, f func() error) error
Describe(ctx context.Context) (map[string]schemas.SchemaDescription, error)
Versions(ctx context.Context) (appliedVersions, pendingVersions, failedVersions []int, _ error)
RunDDLStatements(ctx context.Context, statements []string) error
}
// OutputFactory allows providing global output that might not be instantiated at compile time.

View File

@ -6,6 +6,7 @@ import (
"fmt"
"github.com/sourcegraph/sourcegraph/internal/database/migration/definition"
"github.com/sourcegraph/sourcegraph/internal/database/migration/drift"
"github.com/sourcegraph/sourcegraph/internal/database/migration/runner"
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
"github.com/sourcegraph/sourcegraph/internal/database/migration/shared"
@ -339,18 +340,19 @@ func CheckDrift(ctx context.Context, r Runner, version string, out *output.Outpu
}
schema := schemaDescriptions["public"]
var drift bytes.Buffer
driftOut := output.NewOutput(&drift, output.OutputOpts{})
var buf bytes.Buffer
driftOut := output.NewOutput(&buf, output.OutputOpts{})
expectedSchema, err := fetchExpectedSchema(ctx, schemaName, version, driftOut, expectedSchemaFactories)
if err != nil {
return err
}
if err := compareAndDisplaySchemaDescriptions(driftOut, schemaName, version, canonicalize(schema), canonicalize(expectedSchema)); err != nil {
if err := displayDriftSummaries(driftOut, drift.CompareSchemaDescriptions(schemaName, version, canonicalize(schema), canonicalize(expectedSchema))); err != nil {
schemasWithDrift = append(schemasWithDrift,
&schemaWithDrift{
name: schemaName,
drift: &drift,
drift: &buf,
},
)
}

View File

@ -14,17 +14,13 @@ go_library(
"compare_tables.go",
"compare_triggers.go",
"compare_views.go",
"formatter_console.go",
"summary.go",
"util_collections.go",
"util_search.go",
"util_strings.go",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/database/migration/drift",
visibility = ["//:__subpackages__"],
deps = [
"//internal/database/migration/schemas",
"//lib/output",
"@com_github_google_go_cmp//cmp",
],
)

View File

@ -1,6 +1,8 @@
package drift
import (
"sort"
"github.com/google/go-cmp/cmp"
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
@ -24,10 +26,22 @@ func CompareSchemaDescriptions(schemaName, version string, actual, expected sche
// compareNamedLists invokes the given primary callback with a pair of differing elements from slices
// `as` and `bs`, respectively, with the same name. If there is a missing element from `as`, there will
// be an invocation of this callback with a nil value for its first parameter. Elements for which there
// be an invocation of this callback with a nil value for its first parameter. If any invocation of the
// function returns true, the output of this function will be true.
func compareNamedLists[T schemas.Namer](
as []T,
bs []T,
primaryCallback func(a *T, b T) Summary,
) []Summary {
return compareNamedListsStrict(as, bs, primaryCallback, noopAdditionalCallback[T])
}
// compareNamedListsStrict invokes the given primary callback with a pair of differing elements from
// slices `as` and `bs`, respectively, with the same name. If there is a missing element from `as`, there
// will be an invocation of this callback with a nil value for its first parameter. Elements for which there
// is no analog in `bs` will be collected and sent to an invocation of the additions callback. If any
// invocation of either function returns true, the output of this function will be true.
func compareNamedLists[T schemas.Namer](
func compareNamedListsStrict[T schemas.Namer](
as []T,
bs []T,
primaryCallback func(a *T, b T) Summary,
@ -41,17 +55,31 @@ func compareNamedLists[T schemas.Namer](
return nil
}
return compareNamedListsMulti(as, bs, wrappedPrimaryCallback, additionsCallback)
return compareNamedListsMultiStrict(as, bs, wrappedPrimaryCallback, additionsCallback)
}
// compareNamedListsMulti invokes the given primary callback with a pair of differing elements from slices
// `as` and `bs`, respectively, with the same name. Similar `compareNamedLists`, but this version expects
// multiple `Summary` values from the callback.
func compareNamedListsMulti[T schemas.Namer](
as []T,
bs []T,
primaryCallback func(a *T, b T) []Summary,
) []Summary {
return compareNamedListsMultiStrict(as, bs, primaryCallback, noopAdditionalCallback[T])
}
// compareNamedListsMultiStrict invokes the given primary callback with a pair of differing elements from
// slices `as` and `bs`, respectively, with the same name. Similar `compareNamedListsStrict`, but
// this version expects multiple `Summary` values from the callback.
func compareNamedListsMultiStrict[T schemas.Namer](
as []T,
bs []T,
primaryCallback func(a *T, b T) []Summary,
additionsCallback func(additional []T) []Summary,
) []Summary {
am := groupByName(as)
bm := groupByName(bs)
am := schemas.GroupByName(as)
bm := schemas.GroupByName(bs)
additional := make([]T, 0, len(am))
summaries := []Summary(nil)
@ -86,3 +114,14 @@ func compareNamedListsMulti[T schemas.Namer](
func noopAdditionalCallback[T schemas.Namer](_ []T) []Summary {
return nil
}
// keys returns the ordered keys of the given map.
func keys[T any](m map[string]T) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}

View File

@ -3,83 +3,74 @@ package drift
import (
"fmt"
"github.com/google/go-cmp/cmp"
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
)
func compareColumns(schemaName, version string, actualTable, expectedTable schemas.TableDescription) []Summary {
return compareNamedLists(actualTable.Columns, expectedTable.Columns, func(column *schemas.ColumnDescription, expectedColumn schemas.ColumnDescription) Summary {
return compareNamedListsStrict(
actualTable.Columns,
expectedTable.Columns,
compareColumnsCallbackFor(schemaName, version, expectedTable),
compareColumnsAdditionalCallbackFor(expectedTable),
)
}
func compareColumnsCallbackFor(schemaName, version string, table schemas.TableDescription) func(_ *schemas.ColumnDescription, _ schemas.ColumnDescription) Summary {
return func(column *schemas.ColumnDescription, expectedColumn schemas.ColumnDescription) Summary {
if column == nil {
url := makeSearchURL(schemaName, version,
fmt.Sprintf("CREATE TABLE %s", expectedTable.Name),
fmt.Sprintf("ALTER TABLE ONLY %s", expectedTable.Name),
)
return newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, expectedColumn.Name),
fmt.Sprintf("Missing column %q.%q", expectedTable.Name, expectedColumn.Name),
fmt.Sprintf("%q.%q", table.GetName(), expectedColumn.GetName()),
fmt.Sprintf("Missing column %q.%q", table.GetName(), expectedColumn.GetName()),
"define the column",
).withURLHint(url)
).withStatements(
expectedColumn.CreateStatement(table),
).withURLHint(
makeSearchURL(schemaName, version,
fmt.Sprintf("CREATE TABLE %s", table.GetName()),
fmt.Sprintf("ALTER TABLE ONLY %s", table.GetName()),
),
)
}
equivIf := func(f func(*schemas.ColumnDescription)) bool {
c := *column
f(&c)
return cmp.Diff(c, expectedColumn) == ""
}
// TODO
// if equivIf(func(s *schemas.ColumnDescription) { s.TypeName = expectedColumn.TypeName }) {}
if equivIf(func(s *schemas.ColumnDescription) { s.IsNullable = expectedColumn.IsNullable }) {
var verb string
if expectedColumn.IsNullable {
verb = "DROP"
} else {
verb = "SET"
}
alterColumnStmt := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s %s NOT NULL;", expectedTable.Name, expectedColumn.Name, verb)
if alterStatements, ok := (*column).AlterToTarget(table, expectedColumn); ok {
return newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, expectedColumn.Name),
fmt.Sprintf("Unexpected properties of column %q.%q", expectedTable.Name, expectedColumn.Name),
"change the column nullability constraint",
).withDiff(expectedColumn, *column).withStatements(alterColumnStmt)
expectedColumn.GetName(),
fmt.Sprintf("Unexpected properties of column %s.%q", table.GetName(), expectedColumn.GetName()),
"alter the column",
).withStatements(
alterStatements...,
)
}
if equivIf(func(s *schemas.ColumnDescription) { s.Default = expectedColumn.Default }) {
alterColumnStmt := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s;", expectedTable.Name, expectedColumn.Name, expectedColumn.Default)
return newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, expectedColumn.Name),
fmt.Sprintf("Unexpected properties of column %q.%q", expectedTable.Name, expectedColumn.Name),
"change the column default",
).withDiff(expectedColumn, *column).withStatements(alterColumnStmt)
}
url := makeSearchURL(schemaName, version,
fmt.Sprintf("CREATE TABLE %s", expectedTable.Name),
fmt.Sprintf("ALTER TABLE ONLY %s", expectedTable.Name),
)
return newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, expectedColumn.Name),
fmt.Sprintf("Unexpected properties of column %q.%q", expectedTable.Name, expectedColumn.Name),
fmt.Sprintf("%q.%q", table.GetName(), expectedColumn.GetName()),
fmt.Sprintf("Unexpected properties of column %q.%q", table.GetName(), expectedColumn.GetName()),
"redefine the column",
).withDiff(expectedColumn, *column).withURLHint(url)
}, func(additional []schemas.ColumnDescription) []Summary {
).withDiff(
expectedColumn,
*column,
).withURLHint(
makeSearchURL(schemaName, version,
fmt.Sprintf("CREATE TABLE %s", table.GetName()),
fmt.Sprintf("ALTER TABLE ONLY %s", table.GetName()),
),
)
}
}
func compareColumnsAdditionalCallbackFor(table schemas.TableDescription) func(_ []schemas.ColumnDescription) []Summary {
return func(additional []schemas.ColumnDescription) []Summary {
summaries := []Summary{}
for _, column := range additional {
alterColumnStmt := fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", expectedTable.Name, column.Name)
summary := newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, column.Name),
fmt.Sprintf("Unexpected column %q.%q", expectedTable.Name, column.Name),
summaries = append(summaries, newDriftSummary(
fmt.Sprintf("%q.%q", table.GetName(), column.GetName()),
fmt.Sprintf("Unexpected column %q.%q", table.GetName(), column.GetName()),
"drop the column",
).withStatements(alterColumnStmt)
summaries = append(summaries, summary)
).withStatements(
column.DropStatement(table),
))
}
return summaries
})
}
}

View File

@ -7,36 +7,53 @@ import (
)
func compareConstraints(actualTable, expectedTable schemas.TableDescription) []Summary {
return compareNamedLists(actualTable.Constraints, expectedTable.Constraints, func(constraint *schemas.ConstraintDescription, expectedConstraint schemas.ConstraintDescription) Summary {
createConstraintStmt := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", expectedTable.Name, expectedConstraint.Name, expectedConstraint.ConstraintDefinition)
dropConstraintStmt := fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s;", expectedTable.Name, expectedConstraint.Name)
return compareNamedListsStrict(
actualTable.Constraints,
expectedTable.Constraints,
compareConstraintsCallbackFor(expectedTable),
compareConstraintsAdditionalCallbackFor(expectedTable),
)
}
func compareConstraintsCallbackFor(table schemas.TableDescription) func(_ *schemas.ConstraintDescription, _ schemas.ConstraintDescription) Summary {
return func(constraint *schemas.ConstraintDescription, expectedConstraint schemas.ConstraintDescription) Summary {
if constraint == nil {
return newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, expectedConstraint.Name),
fmt.Sprintf("Missing constraint %q.%q", expectedTable.Name, expectedConstraint.Name),
fmt.Sprintf("%q.%q", table.GetName(), expectedConstraint.GetName()),
fmt.Sprintf("Missing constraint %q.%q", table.GetName(), expectedConstraint.GetName()),
"define the constraint",
).withStatements(createConstraintStmt)
).withStatements(
expectedConstraint.CreateStatement(table),
)
}
return newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, expectedConstraint.Name),
fmt.Sprintf("Unexpected properties of constraint %q.%q", expectedTable.Name, expectedConstraint.Name),
fmt.Sprintf("%q.%q", table.GetName(), expectedConstraint.GetName()),
fmt.Sprintf("Unexpected properties of constraint %q.%q", table.GetName(), expectedConstraint.GetName()),
"redefine the constraint",
).withDiff(expectedConstraint, *constraint).withStatements(dropConstraintStmt, createConstraintStmt)
}, func(additional []schemas.ConstraintDescription) []Summary {
).withDiff(
expectedConstraint,
*constraint,
).withStatements(
expectedConstraint.DropStatement(table),
expectedConstraint.CreateStatement(table),
)
}
}
func compareConstraintsAdditionalCallbackFor(table schemas.TableDescription) func(_ []schemas.ConstraintDescription) []Summary {
return func(additional []schemas.ConstraintDescription) []Summary {
summaries := []Summary{}
for _, constraint := range additional {
alterTableStmt := fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s;", expectedTable.Name, constraint.Name)
summary := newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, constraint.Name),
fmt.Sprintf("Unexpected constraint %q.%q", expectedTable.Name, constraint.Name),
summaries = append(summaries, newDriftSummary(
fmt.Sprintf("%q.%q", table.GetName(), constraint.GetName()),
fmt.Sprintf("Unexpected constraint %q.%q", table.GetName(), constraint.GetName()),
"drop the constraint",
).withStatements(alterTableStmt)
summaries = append(summaries, summary)
).withStatements(
constraint.DropStatement(table),
))
}
return summaries
})
}
}

View File

@ -2,120 +2,44 @@ package drift
import (
"fmt"
"strings"
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
)
func compareEnums(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
return compareNamedLists(actual.Enums, expected.Enums, func(enum *schemas.EnumDescription, expectedEnum schemas.EnumDescription) Summary {
quotedLabels := make([]string, 0, len(expectedEnum.Labels))
for _, label := range expectedEnum.Labels {
quotedLabels = append(quotedLabels, fmt.Sprintf("'%s'", label))
}
createEnumStmt := fmt.Sprintf("CREATE TYPE %s AS ENUM (%s);", expectedEnum.Name, strings.Join(quotedLabels, ", "))
dropEnumStmt := fmt.Sprintf("DROP TYPE %s;", expectedEnum.Name)
if enum == nil {
return newDriftSummary(
expectedEnum.Name,
fmt.Sprintf("Missing enum %q", expectedEnum.Name),
"create the type",
).withStatements(createEnumStmt)
}
if ordered, ok := constructEnumRepairStatements(*enum, expectedEnum); ok {
return newDriftSummary(
expectedEnum.Name,
fmt.Sprintf("Missing %d labels for enum %q", len(ordered), expectedEnum.Name),
"add the missing enum labels",
).withStatements(ordered...)
}
return compareNamedLists(actual.Enums, expected.Enums, compareEnumsCallback)
}
func compareEnumsCallback(enum *schemas.EnumDescription, expectedEnum schemas.EnumDescription) Summary {
if enum == nil {
return newDriftSummary(
expectedEnum.Name,
fmt.Sprintf("Unexpected labels for enum %q", expectedEnum.Name),
"drop and re-create the type",
).withDiff(enum.Labels, expectedEnum.Labels).withStatements(dropEnumStmt, createEnumStmt)
}, noopAdditionalCallback[schemas.EnumDescription])
}
// constructEnumRepairStatements returns a set of `ALTER ENUM ADD VALUE` statements to make
// the given enum equivalent to the given expected enum. If the given enum is not a subset of
// the expected enum, then additive statements cannot bring the enum to the expected state and
// we return a false-valued flag. In this case the existing type must be dropped and re-created
// as there's currently no way to *remove* values from an enum type.
func constructEnumRepairStatements(enum, expectedEnum schemas.EnumDescription) ([]string, bool) {
labels := groupByName(wrapStrings(enum.Labels))
expectedLabels := groupByName(wrapStrings(expectedEnum.Labels))
for label := range labels {
if _, ok := expectedLabels[label]; !ok {
return nil, false
}
expectedEnum.GetName(),
fmt.Sprintf("Missing enum %q", expectedEnum.GetName()),
"define the type",
).withStatements(
expectedEnum.CreateStatement(),
)
}
// If we're here then we're strictly missing labels and can add them in-place.
// Try to reconstruct the data we need to make the proper create type statement.
type missingLabel struct {
label string
neighbor string
before bool
}
missingLabels := make([]missingLabel, 0, len(expectedEnum.Labels))
after := ""
for _, label := range expectedEnum.Labels {
if _, ok := labels[label]; !ok && after != "" {
missingLabels = append(missingLabels, missingLabel{label: label, neighbor: after, before: false})
}
after = label
if alterStatements, ok := (*enum).AlterToTarget(expectedEnum); ok {
return newDriftSummary(
expectedEnum.GetName(),
fmt.Sprintf("Unexpected properties of enum %q", expectedEnum.GetName()),
"alter the type",
).withStatements(
alterStatements...,
)
}
before := ""
for i := len(expectedEnum.Labels) - 1; i >= 0; i-- {
label := expectedEnum.Labels[i]
if _, ok := labels[label]; !ok && before != "" {
missingLabels = append(missingLabels, missingLabel{label: label, neighbor: before, before: true})
}
before = label
}
var (
ordered []string
reachable = groupByName(wrapStrings(enum.Labels))
return newDriftSummary(
expectedEnum.GetName(),
fmt.Sprintf("Unexpected properties of enum %q", expectedEnum.GetName()),
"redefine the type",
).withDiff(
expectedEnum.Labels,
enum.Labels,
).withStatements(
expectedEnum.DropStatement(),
expectedEnum.CreateStatement(),
)
outer:
for len(missingLabels) > 0 {
for _, s := range missingLabels {
// Neighbor doesn't exist yet, blocked from creating
if _, ok := reachable[s.neighbor]; !ok {
continue
}
rel := "AFTER"
if s.before {
rel = "BEFORE"
}
filtered := missingLabels[:0]
for _, l := range missingLabels {
if l.label != s.label {
filtered = append(filtered, l)
}
}
missingLabels = filtered
reachable[s.label] = stringNamer(s.label)
ordered = append(ordered, fmt.Sprintf("ALTER TYPE %s ADD VALUE '%s' %s '%s';", expectedEnum.Name, s.label, rel, s.neighbor))
continue outer
}
panic("Infinite loop")
}
return ordered, true
}

View File

@ -7,17 +7,19 @@ import (
)
func compareExtensions(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
return compareNamedLists(wrapStrings(actual.Extensions), wrapStrings(expected.Extensions), func(extension *stringNamer, expectedExtension stringNamer) Summary {
if extension == nil {
createExtensionStmt := fmt.Sprintf("CREATE EXTENSION %s;", expectedExtension)
return newDriftSummary(
expectedExtension.GetName(),
fmt.Sprintf("Missing extension %q", expectedExtension),
"install the extension",
).withStatements(createExtensionStmt)
}
return nil
}, noopAdditionalCallback[stringNamer])
return compareNamedLists(actual.WrappedExtensions(), expected.WrappedExtensions(), compareExtensionsCallback)
}
func compareExtensionsCallback(extension *schemas.ExtensionDescription, expectedExtension schemas.ExtensionDescription) Summary {
if extension == nil {
return newDriftSummary(
expectedExtension.GetName(),
fmt.Sprintf("Missing extension %q", expectedExtension.GetName()),
"define the extension",
).withStatements(
expectedExtension.CreateStatement(),
)
}
return nil
}

View File

@ -2,27 +2,33 @@ package drift
import (
"fmt"
"strings"
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
)
func compareFunctions(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
return compareNamedLists(actual.Functions, expected.Functions, func(function *schemas.FunctionDescription, expectedFunction schemas.FunctionDescription) Summary {
definitionStmt := fmt.Sprintf("%s;", strings.TrimSpace(expectedFunction.Definition))
if function == nil {
return newDriftSummary(
expectedFunction.Name,
fmt.Sprintf("Missing function %q", expectedFunction.Name),
"define the function",
).withStatements(definitionStmt)
}
return newDriftSummary(
expectedFunction.Name,
fmt.Sprintf("Unexpected definition of function %q", expectedFunction.Name),
"replace the function definition",
).withDiff(expectedFunction.Definition, function.Definition).withStatements(definitionStmt)
}, noopAdditionalCallback[schemas.FunctionDescription])
return compareNamedLists(actual.Functions, expected.Functions, compareFunctionsCallback)
}
func compareFunctionsCallback(function *schemas.FunctionDescription, expectedFunction schemas.FunctionDescription) Summary {
if function == nil {
return newDriftSummary(
expectedFunction.GetName(),
fmt.Sprintf("Missing function %q", expectedFunction.GetName()),
"define the function",
).withStatements(
expectedFunction.CreateOrReplaceStatement(),
)
}
return newDriftSummary(
expectedFunction.GetName(),
fmt.Sprintf("Unexpected definition of function %q", expectedFunction.GetName()),
"redefine the function",
).withDiff(
expectedFunction.Definition,
function.Definition,
).withStatements(
expectedFunction.CreateOrReplaceStatement(),
)
}

View File

@ -7,45 +7,53 @@ import (
)
func compareIndexes(actualTable, expectedTable schemas.TableDescription) []Summary {
return compareNamedLists(actualTable.Indexes, expectedTable.Indexes, func(index *schemas.IndexDescription, expectedIndex schemas.IndexDescription) Summary {
var createIndexStmt string
switch expectedIndex.ConstraintType {
case "u":
fallthrough
case "p":
createIndexStmt = fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", actualTable.Name, expectedIndex.Name, expectedIndex.ConstraintDefinition)
default:
createIndexStmt = fmt.Sprintf("%s;", expectedIndex.IndexDefinition)
}
return compareNamedListsStrict(
actualTable.Indexes,
expectedTable.Indexes,
compareIndexesCallbackFor(expectedTable),
compareIndexesCallbackAdditionalFor(expectedTable),
)
}
func compareIndexesCallbackFor(table schemas.TableDescription) func(_ *schemas.IndexDescription, _ schemas.IndexDescription) Summary {
return func(index *schemas.IndexDescription, expectedIndex schemas.IndexDescription) Summary {
if index == nil {
return newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, expectedIndex.Name),
fmt.Sprintf("Missing index %q.%q", expectedTable.Name, expectedIndex.Name),
fmt.Sprintf("%q.%q", table.GetName(), expectedIndex.GetName()),
fmt.Sprintf("Missing index %q.%q", table.GetName(), expectedIndex.GetName()),
"define the index",
).withStatements(createIndexStmt)
).withStatements(
expectedIndex.CreateStatement(table),
)
}
dropIndexStmt := fmt.Sprintf("DROP INDEX %s;", expectedIndex.Name)
return newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, expectedIndex.Name),
fmt.Sprintf("Unexpected properties of index %q.%q", expectedTable.Name, expectedIndex.Name),
fmt.Sprintf("%q.%q", table.GetName(), expectedIndex.GetName()),
fmt.Sprintf("Unexpected properties of index %q.%q", table.GetName(), expectedIndex.GetName()),
"redefine the index",
).withDiff(expectedIndex, *index).withStatements(dropIndexStmt, createIndexStmt)
}, func(additional []schemas.IndexDescription) []Summary {
).withDiff(
expectedIndex,
*index,
).withStatements(
expectedIndex.DropStatement(),
expectedIndex.CreateStatement(table),
)
}
}
func compareIndexesCallbackAdditionalFor(table schemas.TableDescription) func(_ []schemas.IndexDescription) []Summary {
return func(additional []schemas.IndexDescription) []Summary {
summaries := []Summary{}
for _, index := range additional {
dropIndexStmt := fmt.Sprintf("DROP INDEX %s;", index.Name)
summary := newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, index.Name),
fmt.Sprintf("Unexpected index %q.%q", expectedTable.Name, index.Name),
summaries = append(summaries, newDriftSummary(
fmt.Sprintf("%q.%q", table.GetName(), index.GetName()),
fmt.Sprintf("Unexpected index %q.%q", table.GetName(), index.GetName()),
"drop the index",
).withStatements(dropIndexStmt)
summaries = append(summaries, summary)
).withStatements(
index.DropStatement(),
))
}
return summaries
})
}
}

View File

@ -7,24 +7,43 @@ import (
)
func compareSequences(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
return compareNamedLists(actual.Sequences, expected.Sequences, func(sequence *schemas.SequenceDescription, expectedSequence schemas.SequenceDescription) Summary {
definitionStmt := makeSearchURL(schemaName, version,
fmt.Sprintf("CREATE SEQUENCE %s", expectedSequence.Name),
fmt.Sprintf("nextval('%s'::regclass);", expectedSequence.Name),
)
return compareNamedLists(actual.Sequences, expected.Sequences, compareSequencesCallbackFor(schemaName, version))
}
func compareSequencesCallbackFor(schemaName, version string) func(_ *schemas.SequenceDescription, _ schemas.SequenceDescription) Summary {
return func(sequence *schemas.SequenceDescription, expectedSequence schemas.SequenceDescription) Summary {
if sequence == nil {
return newDriftSummary(
expectedSequence.Name,
fmt.Sprintf("Missing sequence %q", expectedSequence.Name),
expectedSequence.GetName(),
fmt.Sprintf("Missing sequence %q", expectedSequence.GetName()),
"define the sequence",
).withStatements(definitionStmt)
).withStatements(
expectedSequence.CreateStatement(),
)
}
if alterStatements, ok := (*sequence).AlterToTarget(expectedSequence); ok {
return newDriftSummary(
expectedSequence.GetName(),
fmt.Sprintf("Unexpected properties of sequence %q", expectedSequence.GetName()),
"alter the sequence",
).withStatements(
alterStatements...,
)
}
return newDriftSummary(
expectedSequence.Name,
fmt.Sprintf("Unexpected properties of sequence %q", expectedSequence.Name),
expectedSequence.GetName(),
fmt.Sprintf("Unexpected properties of sequence %q", expectedSequence.GetName()),
"redefine the sequence",
).withDiff(expectedSequence, *sequence).withStatements(definitionStmt)
}, noopAdditionalCallback[schemas.SequenceDescription])
).withDiff(
expectedSequence,
*sequence,
).withURLHint(
makeSearchURL(schemaName, version,
fmt.Sprintf("CREATE SEQUENCE %s", expectedSequence.GetName()),
fmt.Sprintf("nextval('%s'::regclass);", expectedSequence.GetName()),
),
)
}
}

View File

@ -7,19 +7,23 @@ import (
)
func compareTables(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
return compareNamedListsMulti(actual.Tables, expected.Tables, func(table *schemas.TableDescription, expectedTable schemas.TableDescription) []Summary {
if table == nil {
url := makeSearchURL(schemaName, version,
fmt.Sprintf("CREATE TABLE %s", expectedTable.Name),
fmt.Sprintf("ALTER TABLE ONLY %s", expectedTable.Name),
fmt.Sprintf("CREATE .*(INDEX|TRIGGER).* ON %s", expectedTable.Name),
)
return compareNamedListsMulti(actual.Tables, expected.Tables, compareTablesCallbackFor(schemaName, version))
}
func compareTablesCallbackFor(schemaName, version string) func(_ *schemas.TableDescription, _ schemas.TableDescription) []Summary {
return func(table *schemas.TableDescription, expectedTable schemas.TableDescription) []Summary {
if table == nil {
return singleton(newDriftSummary(
expectedTable.Name,
fmt.Sprintf("Missing table %q", expectedTable.Name),
expectedTable.GetName(),
fmt.Sprintf("Missing table %q", expectedTable.GetName()),
"define the table",
).withURLHint(url))
).withURLHint(
makeSearchURL(schemaName, version,
fmt.Sprintf("CREATE TABLE %s", expectedTable.GetName()),
fmt.Sprintf("ALTER TABLE ONLY %s", expectedTable.GetName()),
fmt.Sprintf("CREATE .*(INDEX|TRIGGER).* ON %s", expectedTable.GetName()),
),
))
}
summaries := []Summary(nil)
@ -28,5 +32,5 @@ func compareTables(schemaName, version string, actual, expected schemas.SchemaDe
summaries = append(summaries, compareIndexes(*table, expectedTable)...)
summaries = append(summaries, compareTriggers(*table, expectedTable)...)
return summaries
}, noopAdditionalCallback[schemas.TableDescription])
}
}

View File

@ -7,36 +7,53 @@ import (
)
func compareTriggers(actualTable, expectedTable schemas.TableDescription) []Summary {
return compareNamedLists(actualTable.Triggers, expectedTable.Triggers, func(trigger *schemas.TriggerDescription, expectedTrigger schemas.TriggerDescription) Summary {
createTriggerStmt := fmt.Sprintf("%s;", expectedTrigger.Definition)
dropTriggerStmt := fmt.Sprintf("DROP TRIGGER %s ON %s;", expectedTrigger.Name, expectedTable.Name)
return compareNamedListsStrict(
actualTable.Triggers,
expectedTable.Triggers,
compareNamedListsCallbackFor(expectedTable),
compareNamedListsAdditionalCallbackFor(expectedTable),
)
}
func compareNamedListsCallbackFor(table schemas.TableDescription) func(_ *schemas.TriggerDescription, _ schemas.TriggerDescription) Summary {
return func(trigger *schemas.TriggerDescription, expectedTrigger schemas.TriggerDescription) Summary {
if trigger == nil {
return newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, expectedTrigger.Name),
fmt.Sprintf("Missing trigger %q.%q", expectedTable.Name, expectedTrigger.Name),
fmt.Sprintf("%q.%q", table.GetName(), expectedTrigger.GetName()),
fmt.Sprintf("Missing trigger %q.%q", table.GetName(), expectedTrigger.GetName()),
"define the trigger",
).withStatements(createTriggerStmt)
).withStatements(
expectedTrigger.CreateStatement(),
)
}
return newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, expectedTrigger.Name),
fmt.Sprintf("Unexpected properties of trigger %q.%q", expectedTable.Name, expectedTrigger.Name),
fmt.Sprintf("%q.%q", table.GetName(), expectedTrigger.GetName()),
fmt.Sprintf("Unexpected properties of trigger %q.%q", table.GetName(), expectedTrigger.GetName()),
"redefine the trigger",
).withDiff(expectedTrigger, *trigger).withStatements(dropTriggerStmt, createTriggerStmt)
}, func(additional []schemas.TriggerDescription) []Summary {
).withDiff(
expectedTrigger,
*trigger,
).withStatements(
expectedTrigger.DropStatement(table),
expectedTrigger.CreateStatement(),
)
}
}
func compareNamedListsAdditionalCallbackFor(table schemas.TableDescription) func(_ []schemas.TriggerDescription) []Summary {
return func(additional []schemas.TriggerDescription) []Summary {
summaries := []Summary{}
for _, trigger := range additional {
dropTriggerStmt := fmt.Sprintf("DROP TRIGGER %s ON %s;", trigger.Name, expectedTable.Name)
summary := newDriftSummary(
fmt.Sprintf("%q.%q", expectedTable.Name, trigger.Name),
fmt.Sprintf("Unexpected trigger %q.%q", expectedTable.Name, trigger.Name),
summaries = append(summaries, newDriftSummary(
fmt.Sprintf("%q.%q", table.GetName(), trigger.GetName()),
fmt.Sprintf("Unexpected trigger %q.%q", table.GetName(), trigger.GetName()),
"drop the trigger",
).withStatements(dropTriggerStmt)
summaries = append(summaries, summary)
).withStatements(
trigger.DropStatement(table),
))
}
return summaries
})
}
}

View File

@ -2,30 +2,34 @@ package drift
import (
"fmt"
"strings"
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
)
func compareViews(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
return compareNamedLists(actual.Views, expected.Views, func(view *schemas.ViewDescription, expectedView schemas.ViewDescription) Summary {
// pgsql has weird indents here
viewDefinition := strings.TrimSpace(stripIndent(" " + expectedView.Definition))
createViewStmt := fmt.Sprintf("CREATE VIEW %s AS %s", expectedView.Name, viewDefinition)
dropViewStmt := fmt.Sprintf("DROP VIEW %s;", expectedView.Name)
if view == nil {
return newDriftSummary(
expectedView.Name,
fmt.Sprintf("Missing view %q", expectedView.Name),
"define the view",
).withStatements(createViewStmt)
}
return newDriftSummary(
expectedView.Name,
fmt.Sprintf("Unexpected definition of view %q", expectedView.Name),
"redefine the view",
).withDiff(expectedView.Definition, view.Definition).withStatements(dropViewStmt, createViewStmt)
}, noopAdditionalCallback[schemas.ViewDescription])
return compareNamedLists(actual.Views, expected.Views, compareViewsCallback)
}
func compareViewsCallback(view *schemas.ViewDescription, expectedView schemas.ViewDescription) Summary {
if view == nil {
return newDriftSummary(
expectedView.GetName(),
fmt.Sprintf("Missing view %q", expectedView.GetName()),
"define the view",
).withStatements(
expectedView.CreateStatement(),
)
}
return newDriftSummary(
expectedView.GetName(),
fmt.Sprintf("Unexpected definition of view %q", expectedView.GetName()),
"redefine the view",
).withDiff(
expectedView.Definition,
view.Definition,
).withStatements(
expectedView.DropStatement(),
expectedView.CreateStatement(),
)
}

View File

@ -1,46 +0,0 @@
package drift
import (
"fmt"
"strings"
"github.com/google/go-cmp/cmp"
"github.com/sourcegraph/sourcegraph/lib/output"
)
type ConsoleFormatter struct {
out OutputWriter
}
type OutputWriter interface {
Write(s string)
Writef(format string, args ...any)
WriteLine(line output.FancyLine)
WriteCode(languageName, str string) error
}
func NewConsoleFormatter(out OutputWriter) *ConsoleFormatter {
return &ConsoleFormatter{out: out}
}
func (f *ConsoleFormatter) Display(s Summary) {
f.out.WriteLine(output.Line(output.EmojiFailure, output.StyleBold, s.Problem()))
if a, b, ok := s.Diff(); ok {
_ = f.out.WriteCode("diff", strings.TrimSpace(cmp.Diff(a, b)))
}
f.out.WriteLine(output.Line(output.EmojiLightbulb, output.StyleItalic, fmt.Sprintf("Suggested action: %s.", s.Solution())))
if statements, ok := s.Statements(); ok {
_ = f.out.WriteCode("sql", strings.Join(statements, "\n"))
}
if urlHint, ok := s.URLHint(); ok {
f.out.WriteLine(output.Line(output.EmojiLightbulb, output.StyleItalic, fmt.Sprintf("Hint: Reproduce %s as defined at the following URL:", s.Name())))
f.out.Write("")
f.out.WriteLine(output.Line(output.EmojiFingerPointRight, output.StyleUnderline, urlHint))
f.out.Write("")
}
}

View File

@ -1,29 +0,0 @@
package drift
import (
"sort"
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
)
// keys returns the ordered keys of the given map.
func keys[T any](m map[string]T) []string {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
}
sort.Strings(keys)
return keys
}
// groupByName converts the given element slice into a map indexed by
// each element's name.
func groupByName[T schemas.Namer](ts []T) map[string]T {
m := make(map[string]T, len(ts))
for _, t := range ts {
m[t.GetName()] = t
}
return m
}

View File

@ -7,16 +7,6 @@ import (
"strings"
)
// quoteTerm converts the given literal search term into a regular expression.
func quoteTerm(searchTerm string) string {
terms := strings.Split(searchTerm, " ")
for i, term := range terms {
terms[i] = regexp.QuoteMeta(term)
}
return "(^|\\b)" + strings.Join(terms, "\\s") + "($|\\b)"
}
// makeSearchURL returns a URL to a sourcegraph.com search query within the squashed
// definition of the given schema.
func makeSearchURL(schemaName, version string, searchTerms ...string) string {
@ -39,3 +29,13 @@ func makeSearchURL(schemaName, version string, searchTerms ...string) string {
searchUrl.RawQuery = qs.Encode()
return searchUrl.String()
}
// quoteTerm converts the given literal search term into a regular expression.
func quoteTerm(searchTerm string) string {
terms := strings.Split(searchTerm, " ")
for i, term := range terms {
terms[i] = regexp.QuoteMeta(term)
}
return "(^|\\b)" + strings.Join(terms, "\\s") + "($|\\b)"
}

View File

@ -1,37 +0,0 @@
package drift
import (
"strings"
)
type stringNamer string
func (s stringNamer) GetName() string { return string(s) }
// wrapStrings converts a string slice into a string slice with GetName
// on each element.
func wrapStrings(ss []string) []stringNamer {
sn := make([]stringNamer, 0, len(ss))
for _, s := range ss {
sn = append(sn, stringNamer(s))
}
return sn
}
// stripIndent removes the largest common indent from the given text.
func stripIndent(s string) string {
lines := strings.Split(strings.TrimRight(s, "\n"), "\n")
min := len(lines[0])
for _, line := range lines {
if indent := len(line) - len(strings.TrimLeft(line, " ")); indent < min {
min = indent
}
}
for i, line := range lines {
lines[i] = line[min:]
}
return strings.Join(lines, "\n")
}

View File

@ -13,6 +13,7 @@ type Store interface {
Done(err error) error
Versions(ctx context.Context) (appliedVersions, pendingVersions, failedVersions []int, _ error)
RunDDLStatements(ctx context.Context, statements []string) error
TryLock(ctx context.Context) (bool, func(err error) error, error)
Up(ctx context.Context, migration definition.Definition) error
Down(ctx context.Context, migration definition.Definition) error

View File

@ -32,6 +32,9 @@ type MockStore struct {
// IndexStatusFunc is an instance of a mock function object controlling
// the behavior of the method IndexStatus.
IndexStatusFunc *StoreIndexStatusFunc
// RunDDLStatementsFunc is an instance of a mock function object
// controlling the behavior of the method RunDDLStatements.
RunDDLStatementsFunc *StoreRunDDLStatementsFunc
// TransactFunc is an instance of a mock function object controlling the
// behavior of the method Transact.
TransactFunc *StoreTransactFunc
@ -73,6 +76,11 @@ func NewMockStore() *MockStore {
return
},
},
RunDDLStatementsFunc: &StoreRunDDLStatementsFunc{
defaultHook: func(context.Context, []string) (r0 error) {
return
},
},
TransactFunc: &StoreTransactFunc{
defaultHook: func(context.Context) (r0 Store, r1 error) {
return
@ -125,6 +133,11 @@ func NewStrictMockStore() *MockStore {
panic("unexpected invocation of MockStore.IndexStatus")
},
},
RunDDLStatementsFunc: &StoreRunDDLStatementsFunc{
defaultHook: func(context.Context, []string) error {
panic("unexpected invocation of MockStore.RunDDLStatements")
},
},
TransactFunc: &StoreTransactFunc{
defaultHook: func(context.Context) (Store, error) {
panic("unexpected invocation of MockStore.Transact")
@ -169,6 +182,9 @@ func NewMockStoreFrom(i Store) *MockStore {
IndexStatusFunc: &StoreIndexStatusFunc{
defaultHook: i.IndexStatus,
},
RunDDLStatementsFunc: &StoreRunDDLStatementsFunc{
defaultHook: i.RunDDLStatements,
},
TransactFunc: &StoreTransactFunc{
defaultHook: i.Transact,
},
@ -609,6 +625,111 @@ func (c StoreIndexStatusFuncCall) Results() []interface{} {
return []interface{}{c.Result0, c.Result1, c.Result2}
}
// StoreRunDDLStatementsFunc describes the behavior when the
// RunDDLStatements method of the parent MockStore instance is invoked.
type StoreRunDDLStatementsFunc struct {
defaultHook func(context.Context, []string) error
hooks []func(context.Context, []string) error
history []StoreRunDDLStatementsFuncCall
mutex sync.Mutex
}
// RunDDLStatements delegates to the next hook function in the queue and
// stores the parameter and result values of this invocation.
func (m *MockStore) RunDDLStatements(v0 context.Context, v1 []string) error {
r0 := m.RunDDLStatementsFunc.nextHook()(v0, v1)
m.RunDDLStatementsFunc.appendCall(StoreRunDDLStatementsFuncCall{v0, v1, r0})
return r0
}
// SetDefaultHook sets function that is called when the RunDDLStatements
// method of the parent MockStore instance is invoked and the hook queue is
// empty.
func (f *StoreRunDDLStatementsFunc) SetDefaultHook(hook func(context.Context, []string) error) {
f.defaultHook = hook
}
// PushHook adds a function to the end of hook queue. Each invocation of the
// RunDDLStatements method of the parent MockStore instance invokes the hook
// at the front of the queue and discards it. After the queue is empty, the
// default hook function is invoked for any future action.
func (f *StoreRunDDLStatementsFunc) PushHook(hook func(context.Context, []string) error) {
f.mutex.Lock()
f.hooks = append(f.hooks, hook)
f.mutex.Unlock()
}
// SetDefaultReturn calls SetDefaultHook with a function that returns the
// given values.
func (f *StoreRunDDLStatementsFunc) SetDefaultReturn(r0 error) {
f.SetDefaultHook(func(context.Context, []string) error {
return r0
})
}
// PushReturn calls PushHook with a function that returns the given values.
func (f *StoreRunDDLStatementsFunc) PushReturn(r0 error) {
f.PushHook(func(context.Context, []string) error {
return r0
})
}
func (f *StoreRunDDLStatementsFunc) nextHook() func(context.Context, []string) error {
f.mutex.Lock()
defer f.mutex.Unlock()
if len(f.hooks) == 0 {
return f.defaultHook
}
hook := f.hooks[0]
f.hooks = f.hooks[1:]
return hook
}
func (f *StoreRunDDLStatementsFunc) appendCall(r0 StoreRunDDLStatementsFuncCall) {
f.mutex.Lock()
f.history = append(f.history, r0)
f.mutex.Unlock()
}
// History returns a sequence of StoreRunDDLStatementsFuncCall objects
// describing the invocations of this function.
func (f *StoreRunDDLStatementsFunc) History() []StoreRunDDLStatementsFuncCall {
f.mutex.Lock()
history := make([]StoreRunDDLStatementsFuncCall, len(f.history))
copy(history, f.history)
f.mutex.Unlock()
return history
}
// StoreRunDDLStatementsFuncCall is an object that describes an invocation
// of method RunDDLStatements on an instance of MockStore.
type StoreRunDDLStatementsFuncCall struct {
// Arg0 is the value of the 1st argument passed to this method
// invocation.
Arg0 context.Context
// Arg1 is the value of the 2nd argument passed to this method
// invocation.
Arg1 []string
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 error
}
// Args returns an interface slice containing the arguments of this
// invocation.
func (c StoreRunDDLStatementsFuncCall) Args() []interface{} {
return []interface{}{c.Arg0, c.Arg1}
}
// Results returns an interface slice containing the results of this
// invocation.
func (c StoreRunDDLStatementsFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}
// StoreTransactFunc describes the behavior when the Transact method of the
// parent MockStore instance is invoked.
type StoreTransactFunc struct {

View File

@ -19,6 +19,7 @@ go_library(
"//internal/lazyregexp",
"//lib/errors",
"//migrations",
"@com_github_google_go_cmp//cmp",
],
)

View File

@ -1,8 +1,11 @@
package schemas
import (
"fmt"
"strings"
"github.com/google/go-cmp/cmp"
"github.com/sourcegraph/sourcegraph/internal/lazyregexp"
)
@ -15,16 +18,129 @@ type SchemaDescription struct {
Views []ViewDescription
}
func (d SchemaDescription) WrappedExtensions() []ExtensionDescription {
extensions := make([]ExtensionDescription, 0, len(d.Extensions))
for _, name := range d.Extensions {
extensions = append(extensions, ExtensionDescription{Name: name})
}
return extensions
}
type ExtensionDescription struct {
Name string
}
func (d ExtensionDescription) CreateStatement() string {
return fmt.Sprintf("CREATE EXTENSION %s;", d.Name)
}
type EnumDescription struct {
Name string
Labels []string
}
func (d EnumDescription) CreateStatement() string {
quotedLabels := make([]string, 0, len(d.Labels))
for _, label := range d.Labels {
quotedLabels = append(quotedLabels, fmt.Sprintf("'%s'", label))
}
return fmt.Sprintf("CREATE TYPE %s AS ENUM (%s);", d.Name, strings.Join(quotedLabels, ", "))
}
func (d EnumDescription) DropStatement() string {
return fmt.Sprintf("DROP TYPE IF EXISTS %s;", d.Name)
}
// AlterToTarget returns a set of `ALTER ENUM ADD VALUE` statements to make the given enum equivalent to
// the expected enum, then additive statements cannot bring the enum to the expected state and we return
// a false-valued flag. In this case the existing type must be dropped and re-created as there's currently
// no way to *remove* values from an enum type.
func (d EnumDescription) AlterToTarget(target EnumDescription) ([]string, bool) {
labels := GroupByName(wrapStrings(d.Labels))
expectedLabels := GroupByName(wrapStrings(target.Labels))
for label := range labels {
if _, ok := expectedLabels[label]; !ok {
return nil, false
}
}
// If we're here then we're strictly missing labels and can add them in-place.
// Try to reconstruct the data we need to make the proper create type statement.
type missingLabel struct {
label string
neighbor string
before bool
}
missingLabels := make([]missingLabel, 0, len(target.Labels))
after := ""
for _, label := range target.Labels {
if _, ok := labels[label]; !ok && after != "" {
missingLabels = append(missingLabels, missingLabel{label: label, neighbor: after, before: false})
}
after = label
}
before := ""
for i := len(target.Labels) - 1; i >= 0; i-- {
label := target.Labels[i]
if _, ok := labels[label]; !ok && before != "" {
missingLabels = append(missingLabels, missingLabel{label: label, neighbor: before, before: true})
}
before = label
}
var (
ordered []string
reachable = GroupByName(wrapStrings(d.Labels))
)
outer:
for len(missingLabels) > 0 {
for _, s := range missingLabels {
// Neighbor doesn't exist yet, blocked from creating
if _, ok := reachable[s.neighbor]; !ok {
continue
}
rel := "AFTER"
if s.before {
rel = "BEFORE"
}
filtered := missingLabels[:0]
for _, l := range missingLabels {
if l.label != s.label {
filtered = append(filtered, l)
}
}
missingLabels = filtered
reachable[s.label] = stringNamer(s.label)
ordered = append(ordered, fmt.Sprintf("ALTER TYPE %s ADD VALUE '%s' %s '%s';", target.GetName(), s.label, rel, s.neighbor))
continue outer
}
panic("Infinite loop")
}
return ordered, true
}
type FunctionDescription struct {
Name string
Definition string
}
func (d FunctionDescription) CreateOrReplaceStatement() string {
return fmt.Sprintf("%s;", d.Definition)
}
type SequenceDescription struct {
Name string
TypeName string
@ -35,6 +151,44 @@ type SequenceDescription struct {
CycleOption string
}
func (d SequenceDescription) CreateStatement() string {
minValue := "NO MINVALUE"
if d.MinimumValue != 0 {
minValue = fmt.Sprintf("MINVALUE %d", d.MinimumValue)
}
maxValue := "NO MAXVALUE"
if d.MaximumValue != 0 {
maxValue = fmt.Sprintf("MAXVALUE %d", d.MaximumValue)
}
return fmt.Sprintf(
"CREATE SEQUENCE %s AS %s INCREMENT BY %d %s %s START WITH %d %s CYCLE;",
d.Name,
d.TypeName,
d.Increment,
minValue,
maxValue,
d.StartValue,
d.CycleOption,
)
}
func (d SequenceDescription) AlterToTarget(target SequenceDescription) ([]string, bool) {
statements := []string{}
if d.TypeName != target.TypeName {
statements = append(statements, fmt.Sprintf("ALTER SEQUENCE %s AS %s MAXVALUE %d;", d.Name, target.TypeName, target.MaximumValue))
// Remove from diff below
d.TypeName = target.TypeName
d.MaximumValue = target.MaximumValue
}
// Abort if there are other fields we haven't addressed
hasAdditionalDiff := cmp.Diff(d, target) != ""
return statements, !hasAdditionalDiff
}
type TableDescription struct {
Name string
Comment string
@ -58,6 +212,57 @@ type ColumnDescription struct {
Comment string
}
func (d ColumnDescription) CreateStatement(table TableDescription) string {
nullableExpr := ""
if !d.IsNullable {
nullableExpr = " NOT NULL"
}
defaultExpr := ""
if d.Default != "" {
defaultExpr = fmt.Sprintf(" DEFAULT %s", d.Default)
}
return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s%s%s;", table.Name, d.Name, d.TypeName, nullableExpr, defaultExpr)
}
func (d ColumnDescription) DropStatement(table TableDescription) string {
return fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s;", table.Name, d.Name)
}
func (d ColumnDescription) AlterToTarget(table TableDescription, target ColumnDescription) ([]string, bool) {
statements := []string{}
if d.TypeName != target.TypeName {
statements = append(statements, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DATA TYPE %s;", table.Name, target.Name, target.TypeName))
// Remove from diff below
d.TypeName = target.TypeName
}
if d.IsNullable != target.IsNullable {
var verb string
if target.IsNullable {
verb = "DROP"
} else {
verb = "SET"
}
statements = append(statements, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s %s NOT NULL;", table.Name, target.Name, verb))
// Remove from diff below
d.IsNullable = target.IsNullable
}
if d.Default != target.Default {
statements = append(statements, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s;", table.Name, target.Name, target.Default))
// Remove from diff below
d.Default = target.Default
}
// Abort if there are other fields we haven't addressed
hasAdditionalDiff := cmp.Diff(d, target) != ""
return statements, !hasAdditionalDiff
}
type IndexDescription struct {
Name string
IsPrimaryKey bool
@ -69,6 +274,18 @@ type IndexDescription struct {
ConstraintDefinition string
}
func (d IndexDescription) CreateStatement(table TableDescription) string {
if d.ConstraintType == "u" || d.ConstraintType == "p" {
return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", table.Name, d.Name, d.ConstraintDefinition)
}
return fmt.Sprintf("%s;", d.IndexDefinition)
}
func (d IndexDescription) DropStatement() string {
return fmt.Sprintf("DROP INDEX IF EXISTS %s;", d.GetName())
}
type ConstraintDescription struct {
Name string
ConstraintType string
@ -77,16 +294,58 @@ type ConstraintDescription struct {
ConstraintDefinition string
}
func (d ConstraintDescription) CreateStatement(table TableDescription) string {
return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", table.Name, d.Name, d.ConstraintDefinition)
}
func (d ConstraintDescription) DropStatement(table TableDescription) string {
return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s;", table.Name, d.Name)
}
type TriggerDescription struct {
Name string
Definition string
}
func (d TriggerDescription) CreateStatement() string {
return fmt.Sprintf("%s;", d.Definition)
}
func (d TriggerDescription) DropStatement(table TableDescription) string {
return fmt.Sprintf("DROP TRIGGER IF EXISTS %s ON %s;", d.Name, table.Name)
}
type ViewDescription struct {
Name string
Definition string
}
func (d ViewDescription) CreateStatement() string {
// pgsql indents definitions strangely; we copy that
return fmt.Sprintf("CREATE VIEW %s AS %s", d.Name, strings.TrimSpace(stripIndent(" "+d.Definition)))
}
func (d ViewDescription) DropStatement() string {
return fmt.Sprintf("DROP VIEW IF EXISTS %s;", d.Name)
}
// stripIndent removes the largest common indent from the given text.
func stripIndent(s string) string {
lines := strings.Split(strings.TrimRight(s, "\n"), "\n")
min := len(lines[0])
for _, line := range lines {
if indent := len(line) - len(strings.TrimLeft(line, " ")); indent < min {
min = indent
}
}
for i, line := range lines {
lines[i] = line[min:]
}
return strings.Join(lines, "\n")
}
func Canonicalize(schemaDescription SchemaDescription) {
for i := range schemaDescription.Tables {
sortColumnsByName(schemaDescription.Tables[i].Columns)
@ -104,6 +363,28 @@ func Canonicalize(schemaDescription SchemaDescription) {
type Namer interface{ GetName() string }
func GroupByName[T Namer](ts []T) map[string]T {
m := make(map[string]T, len(ts))
for _, t := range ts {
m[t.GetName()] = t
}
return m
}
type stringNamer string
func wrapStrings(ss []string) []Namer {
sn := make([]Namer, 0, len(ss))
for _, s := range ss {
sn = append(sn, stringNamer(s))
}
return sn
}
func (n stringNamer) GetName() string { return string(n) }
func (d ExtensionDescription) GetName() string { return d.Name }
func (d EnumDescription) GetName() string { return d.Name }
func (d FunctionDescription) GetName() string { return d.Name }
func (d SequenceDescription) GetName() string { return d.Name }

View File

@ -16,6 +16,7 @@ type Operations struct {
tryLock *observation.Operation
up *observation.Operation
versions *observation.Operation
runDDLStatements *observation.Operation
withMigrationLog *observation.Operation
}
@ -49,6 +50,7 @@ func NewOperations(observationCtx *observation.Context) *Operations {
tryLock: op("TryLock"),
up: op("Up"),
versions: op("Versions"),
runDDLStatements: op("RunDDLStatements"),
withMigrationLog: op("WithMigrationLog"),
}
})

View File

@ -282,6 +282,25 @@ WHERE row_number = 1
ORDER BY version
`
func (s *Store) RunDDLStatements(ctx context.Context, statements []string) (err error) {
ctx, _, endObservation := s.operations.runDDLStatements.With(ctx, &err, observation.Args{})
defer endObservation(1, observation.Args{})
tx, err := s.Transact(ctx)
if err != nil {
return err
}
defer func() { err = tx.Done(err) }()
for _, statement := range statements {
if err := tx.Exec(ctx, sqlf.Sprintf(statement)); err != nil {
return err
}
}
return nil
}
// TryLock attempts to create hold an advisory lock. This method returns a function that should be
// called once the lock should be released. This method accepts the current function's error output
// and wraps any additional errors that occur on close. Calling this method when the lock was not