mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 16:51:55 +00:00
drift: Polish output (#52030)
This commit is contained in:
parent
fe8e70d94d
commit
2211519fdf
@ -86,7 +86,7 @@ func Start(logger log.Logger, registerEnterpriseMigrators registerMigratorsUsing
|
||||
cliutil.DownTo(appName, newRunner, outputFactory, false),
|
||||
cliutil.Validate(appName, newRunner, outputFactory),
|
||||
cliutil.Describe(appName, newRunner, outputFactory),
|
||||
cliutil.Drift(appName, newRunner, outputFactory, DefaultSchemaFactories...),
|
||||
cliutil.Drift(appName, newRunner, outputFactory, false, DefaultSchemaFactories...),
|
||||
cliutil.AddLog(appName, newRunner, outputFactory),
|
||||
cliutil.Upgrade(appName, newRunnerWithSchemas, outputFactory, registerMigrators, DefaultSchemaFactories...),
|
||||
cliutil.Downgrade(appName, newRunnerWithSchemas, outputFactory, registerMigrators, DefaultSchemaFactories...),
|
||||
|
||||
@ -125,7 +125,7 @@ var (
|
||||
downToCommand = cliutil.DownTo("sg migration", makeRunner, outputFactory, true)
|
||||
validateCommand = cliutil.Validate("sg migration", makeRunner, outputFactory)
|
||||
describeCommand = cliutil.Describe("sg migration", makeRunner, outputFactory)
|
||||
driftCommand = cliutil.Drift("sg migration", makeRunner, outputFactory, schemaFactories...)
|
||||
driftCommand = cliutil.Drift("sg migration", makeRunner, outputFactory, true, schemaFactories...)
|
||||
addLogCommand = cliutil.AddLog("sg migration", makeRunner, outputFactory)
|
||||
|
||||
leavesCommand = &cli.Command{
|
||||
|
||||
@ -835,12 +835,13 @@ Available schemas:
|
||||
|
||||
Flags:
|
||||
|
||||
* `--auto-fix, --autofix`: Database goes brrrr.
|
||||
* `--feedback`: provide feedback about this command by opening up a GitHub discussion
|
||||
* `--file="<value>"`: The target schema description file.
|
||||
* `--ignore-migrator-update`: Ignore the running migrator not being the latest version. It is recommended to use the latest migrator version.
|
||||
* `--schema, --db="<value>"`: The target `schema` to compare. Possible values are 'frontend', 'codeintel' and 'codeinsights'
|
||||
* `--skip-version-check`: Skip validation of the instance's current version.
|
||||
* `--version="<value>"`: The target schema version. Can be a version (e.g. 5.0.2) or resolvable as a git revlike on the Sourcegraph repository (e.g. a branch, tag or commit hash).
|
||||
* `--version="<value>"`: The target schema version. Can be a version (e.g. 5.0.2) or resolvable as a git revlike on the Sourcegraph repository (e.g. a branch, tag or commit hash). (default: HEAD)
|
||||
|
||||
### sg migration add-log
|
||||
|
||||
|
||||
@ -44,6 +44,10 @@ func (s *memoryStore) Versions(ctx context.Context) (appliedVersions, pendingVer
|
||||
return s.appliedVersions, s.pendingVersions, s.failedVersions, nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) RunDDLStatements(ctx context.Context, statements []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *memoryStore) TryLock(ctx context.Context) (bool, func(err error) error, error) {
|
||||
return true, func(err error) error { return err }, nil
|
||||
}
|
||||
|
||||
@ -9,13 +9,21 @@ import (
|
||||
"cuelang.org/go/pkg/strings"
|
||||
"github.com/urfave/cli/v2"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/drift"
|
||||
descriptions "github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
|
||||
"github.com/sourcegraph/sourcegraph/internal/oobmigration"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/output"
|
||||
)
|
||||
|
||||
func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory, expectedSchemaFactories ...ExpectedSchemaFactory) *cli.Command {
|
||||
const maxAutofixAttempts = 3
|
||||
|
||||
func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory, development bool, expectedSchemaFactories ...ExpectedSchemaFactory) *cli.Command {
|
||||
defaultVersion := ""
|
||||
if development {
|
||||
defaultVersion = "HEAD"
|
||||
}
|
||||
|
||||
schemaNameFlag := &cli.StringFlag{
|
||||
Name: "schema",
|
||||
Usage: "The target `schema` to compare. Possible values are 'frontend', 'codeintel' and 'codeinsights'",
|
||||
@ -27,6 +35,7 @@ func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory,
|
||||
Usage: "The target schema version. Can be a version (e.g. 5.0.2) or resolvable as a git revlike on the Sourcegraph repository " +
|
||||
"(e.g. a branch, tag or commit hash).",
|
||||
Required: false,
|
||||
Value: defaultVersion,
|
||||
}
|
||||
fileFlag := &cli.StringFlag{
|
||||
Name: "file",
|
||||
@ -37,12 +46,20 @@ func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory,
|
||||
Name: "skip-version-check",
|
||||
Usage: "Skip validation of the instance's current version.",
|
||||
Required: false,
|
||||
Value: development,
|
||||
}
|
||||
ignoreMigratorUpdateCheckFlag := &cli.BoolFlag{
|
||||
Name: "ignore-migrator-update",
|
||||
Usage: "Ignore the running migrator not being the latest version. It is recommended to use the latest migrator version.",
|
||||
Required: false,
|
||||
}
|
||||
// Only in available via `sg migration`` in development mode
|
||||
autofixFlag := &cli.BoolFlag{
|
||||
Name: "auto-fix",
|
||||
Usage: "Database goes brrrr.",
|
||||
Required: false,
|
||||
Aliases: []string{"autofix"},
|
||||
}
|
||||
|
||||
action := makeAction(outFactory, func(ctx context.Context, cmd *cli.Context, out *output.Output) error {
|
||||
airgapped := isAirgapped(ctx)
|
||||
@ -115,6 +132,7 @@ func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory,
|
||||
NewExplicitFileSchemaFactory(file),
|
||||
}
|
||||
}
|
||||
|
||||
expectedSchema, err := fetchExpectedSchema(ctx, schemaName, version, out, expectedSchemaFactories)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -125,22 +143,59 @@ func Drift(commandName string, factory RunnerFactory, outFactory OutputFactory,
|
||||
return err
|
||||
}
|
||||
schema := schemas["public"]
|
||||
summaries := drift.CompareSchemaDescriptions(schemaName, version, canonicalize(schema), canonicalize(expectedSchema))
|
||||
|
||||
return compareAndDisplaySchemaDescriptions(out, schemaName, version, canonicalize(schema), canonicalize(expectedSchema))
|
||||
var autofixErr error
|
||||
if autofixFlag.Get(cmd) {
|
||||
for attempts := maxAutofixAttempts; attempts > 0 && len(summaries) > 0 && autofixErr == nil; attempts-- {
|
||||
allStatements := []string{}
|
||||
for _, summary := range summaries {
|
||||
if statements, ok := summary.Statements(); ok {
|
||||
allStatements = append(allStatements, statements...)
|
||||
}
|
||||
}
|
||||
if len(allStatements) == 0 {
|
||||
out.WriteLine(output.Linef(output.EmojiInfo, output.StyleReset, "No autofix to apply"))
|
||||
break
|
||||
}
|
||||
|
||||
autofixErr = store.RunDDLStatements(ctx, allStatements)
|
||||
if autofixErr != nil {
|
||||
out.WriteLine(output.Linef(output.EmojiFailure, output.StyleFailure, "Failed to apply autofix: %s", err))
|
||||
} else {
|
||||
out.WriteLine(output.Linef(output.EmojiSuccess, output.StyleSuccess, "Successfully applied autofix"))
|
||||
out.WriteLine(output.Linef(output.EmojiInfo, output.StyleReset, "Re-checking drift"))
|
||||
}
|
||||
|
||||
schemas, err := store.Describe(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
schema := schemas["public"]
|
||||
summaries = drift.CompareSchemaDescriptions(schemaName, version, canonicalize(schema), canonicalize(expectedSchema))
|
||||
}
|
||||
}
|
||||
|
||||
return displayDriftSummaries(out, summaries)
|
||||
})
|
||||
|
||||
flags := []cli.Flag{
|
||||
schemaNameFlag,
|
||||
versionFlag,
|
||||
fileFlag,
|
||||
skipVersionCheckFlag,
|
||||
ignoreMigratorUpdateCheckFlag,
|
||||
}
|
||||
if development {
|
||||
flags = append(flags, autofixFlag)
|
||||
}
|
||||
|
||||
return &cli.Command{
|
||||
Name: "drift",
|
||||
Usage: "Detect differences between the current database schema and the expected schema",
|
||||
Description: ConstructLongHelp(),
|
||||
Action: action,
|
||||
Flags: []cli.Flag{
|
||||
schemaNameFlag,
|
||||
versionFlag,
|
||||
fileFlag,
|
||||
skipVersionCheckFlag,
|
||||
ignoreMigratorUpdateCheckFlag,
|
||||
},
|
||||
Flags: flags,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -7,17 +7,15 @@ import (
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/drift"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/output"
|
||||
)
|
||||
|
||||
var errOutOfSync = errors.Newf("database schema is out of sync")
|
||||
|
||||
func compareAndDisplaySchemaDescriptions(rawOut *output.Output, schemaName, version string, actual, expected schemas.SchemaDescription) (err error) {
|
||||
func displayDriftSummaries(rawOut *output.Output, summaries []drift.Summary) (err error) {
|
||||
out := &preambledOutput{rawOut, false}
|
||||
|
||||
for _, summary := range drift.CompareSchemaDescriptions(schemaName, version, actual, expected) {
|
||||
for _, summary := range summaries {
|
||||
displaySummary(out, summary)
|
||||
err = errOutOfSync
|
||||
}
|
||||
|
||||
@ -19,6 +19,7 @@ type Store interface {
|
||||
WithMigrationLog(ctx context.Context, definition definition.Definition, up bool, f func() error) error
|
||||
Describe(ctx context.Context) (map[string]schemas.SchemaDescription, error)
|
||||
Versions(ctx context.Context) (appliedVersions, pendingVersions, failedVersions []int, _ error)
|
||||
RunDDLStatements(ctx context.Context, statements []string) error
|
||||
}
|
||||
|
||||
// OutputFactory allows providing global output that might not be instantiated at compile time.
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/definition"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/drift"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/runner"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/shared"
|
||||
@ -339,18 +340,19 @@ func CheckDrift(ctx context.Context, r Runner, version string, out *output.Outpu
|
||||
}
|
||||
schema := schemaDescriptions["public"]
|
||||
|
||||
var drift bytes.Buffer
|
||||
driftOut := output.NewOutput(&drift, output.OutputOpts{})
|
||||
var buf bytes.Buffer
|
||||
driftOut := output.NewOutput(&buf, output.OutputOpts{})
|
||||
|
||||
expectedSchema, err := fetchExpectedSchema(ctx, schemaName, version, driftOut, expectedSchemaFactories)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if err := compareAndDisplaySchemaDescriptions(driftOut, schemaName, version, canonicalize(schema), canonicalize(expectedSchema)); err != nil {
|
||||
|
||||
if err := displayDriftSummaries(driftOut, drift.CompareSchemaDescriptions(schemaName, version, canonicalize(schema), canonicalize(expectedSchema))); err != nil {
|
||||
schemasWithDrift = append(schemasWithDrift,
|
||||
&schemaWithDrift{
|
||||
name: schemaName,
|
||||
drift: &drift,
|
||||
drift: &buf,
|
||||
},
|
||||
)
|
||||
}
|
||||
|
||||
4
internal/database/migration/drift/BUILD.bazel
generated
4
internal/database/migration/drift/BUILD.bazel
generated
@ -14,17 +14,13 @@ go_library(
|
||||
"compare_tables.go",
|
||||
"compare_triggers.go",
|
||||
"compare_views.go",
|
||||
"formatter_console.go",
|
||||
"summary.go",
|
||||
"util_collections.go",
|
||||
"util_search.go",
|
||||
"util_strings.go",
|
||||
],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/internal/database/migration/drift",
|
||||
visibility = ["//:__subpackages__"],
|
||||
deps = [
|
||||
"//internal/database/migration/schemas",
|
||||
"//lib/output",
|
||||
"@com_github_google_go_cmp//cmp",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
package drift
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
|
||||
@ -24,10 +26,22 @@ func CompareSchemaDescriptions(schemaName, version string, actual, expected sche
|
||||
|
||||
// compareNamedLists invokes the given primary callback with a pair of differing elements from slices
|
||||
// `as` and `bs`, respectively, with the same name. If there is a missing element from `as`, there will
|
||||
// be an invocation of this callback with a nil value for its first parameter. Elements for which there
|
||||
// be an invocation of this callback with a nil value for its first parameter. If any invocation of the
|
||||
// function returns true, the output of this function will be true.
|
||||
func compareNamedLists[T schemas.Namer](
|
||||
as []T,
|
||||
bs []T,
|
||||
primaryCallback func(a *T, b T) Summary,
|
||||
) []Summary {
|
||||
return compareNamedListsStrict(as, bs, primaryCallback, noopAdditionalCallback[T])
|
||||
}
|
||||
|
||||
// compareNamedListsStrict invokes the given primary callback with a pair of differing elements from
|
||||
// slices `as` and `bs`, respectively, with the same name. If there is a missing element from `as`, there
|
||||
// will be an invocation of this callback with a nil value for its first parameter. Elements for which there
|
||||
// is no analog in `bs` will be collected and sent to an invocation of the additions callback. If any
|
||||
// invocation of either function returns true, the output of this function will be true.
|
||||
func compareNamedLists[T schemas.Namer](
|
||||
func compareNamedListsStrict[T schemas.Namer](
|
||||
as []T,
|
||||
bs []T,
|
||||
primaryCallback func(a *T, b T) Summary,
|
||||
@ -41,17 +55,31 @@ func compareNamedLists[T schemas.Namer](
|
||||
return nil
|
||||
}
|
||||
|
||||
return compareNamedListsMulti(as, bs, wrappedPrimaryCallback, additionsCallback)
|
||||
return compareNamedListsMultiStrict(as, bs, wrappedPrimaryCallback, additionsCallback)
|
||||
}
|
||||
|
||||
// compareNamedListsMulti invokes the given primary callback with a pair of differing elements from slices
|
||||
// `as` and `bs`, respectively, with the same name. Similar `compareNamedLists`, but this version expects
|
||||
// multiple `Summary` values from the callback.
|
||||
func compareNamedListsMulti[T schemas.Namer](
|
||||
as []T,
|
||||
bs []T,
|
||||
primaryCallback func(a *T, b T) []Summary,
|
||||
) []Summary {
|
||||
return compareNamedListsMultiStrict(as, bs, primaryCallback, noopAdditionalCallback[T])
|
||||
}
|
||||
|
||||
// compareNamedListsMultiStrict invokes the given primary callback with a pair of differing elements from
|
||||
// slices `as` and `bs`, respectively, with the same name. Similar `compareNamedListsStrict`, but
|
||||
// this version expects multiple `Summary` values from the callback.
|
||||
func compareNamedListsMultiStrict[T schemas.Namer](
|
||||
as []T,
|
||||
bs []T,
|
||||
primaryCallback func(a *T, b T) []Summary,
|
||||
additionsCallback func(additional []T) []Summary,
|
||||
) []Summary {
|
||||
am := groupByName(as)
|
||||
bm := groupByName(bs)
|
||||
am := schemas.GroupByName(as)
|
||||
bm := schemas.GroupByName(bs)
|
||||
additional := make([]T, 0, len(am))
|
||||
summaries := []Summary(nil)
|
||||
|
||||
@ -86,3 +114,14 @@ func compareNamedListsMulti[T schemas.Namer](
|
||||
func noopAdditionalCallback[T schemas.Namer](_ []T) []Summary {
|
||||
return nil
|
||||
}
|
||||
|
||||
// keys returns the ordered keys of the given map.
|
||||
func keys[T any](m map[string]T) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
@ -3,83 +3,74 @@ package drift
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
|
||||
)
|
||||
|
||||
func compareColumns(schemaName, version string, actualTable, expectedTable schemas.TableDescription) []Summary {
|
||||
return compareNamedLists(actualTable.Columns, expectedTable.Columns, func(column *schemas.ColumnDescription, expectedColumn schemas.ColumnDescription) Summary {
|
||||
return compareNamedListsStrict(
|
||||
actualTable.Columns,
|
||||
expectedTable.Columns,
|
||||
compareColumnsCallbackFor(schemaName, version, expectedTable),
|
||||
compareColumnsAdditionalCallbackFor(expectedTable),
|
||||
)
|
||||
}
|
||||
|
||||
func compareColumnsCallbackFor(schemaName, version string, table schemas.TableDescription) func(_ *schemas.ColumnDescription, _ schemas.ColumnDescription) Summary {
|
||||
return func(column *schemas.ColumnDescription, expectedColumn schemas.ColumnDescription) Summary {
|
||||
if column == nil {
|
||||
url := makeSearchURL(schemaName, version,
|
||||
fmt.Sprintf("CREATE TABLE %s", expectedTable.Name),
|
||||
fmt.Sprintf("ALTER TABLE ONLY %s", expectedTable.Name),
|
||||
)
|
||||
|
||||
return newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, expectedColumn.Name),
|
||||
fmt.Sprintf("Missing column %q.%q", expectedTable.Name, expectedColumn.Name),
|
||||
fmt.Sprintf("%q.%q", table.GetName(), expectedColumn.GetName()),
|
||||
fmt.Sprintf("Missing column %q.%q", table.GetName(), expectedColumn.GetName()),
|
||||
"define the column",
|
||||
).withURLHint(url)
|
||||
).withStatements(
|
||||
expectedColumn.CreateStatement(table),
|
||||
).withURLHint(
|
||||
makeSearchURL(schemaName, version,
|
||||
fmt.Sprintf("CREATE TABLE %s", table.GetName()),
|
||||
fmt.Sprintf("ALTER TABLE ONLY %s", table.GetName()),
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
equivIf := func(f func(*schemas.ColumnDescription)) bool {
|
||||
c := *column
|
||||
f(&c)
|
||||
return cmp.Diff(c, expectedColumn) == ""
|
||||
}
|
||||
|
||||
// TODO
|
||||
// if equivIf(func(s *schemas.ColumnDescription) { s.TypeName = expectedColumn.TypeName }) {}
|
||||
if equivIf(func(s *schemas.ColumnDescription) { s.IsNullable = expectedColumn.IsNullable }) {
|
||||
var verb string
|
||||
if expectedColumn.IsNullable {
|
||||
verb = "DROP"
|
||||
} else {
|
||||
verb = "SET"
|
||||
}
|
||||
|
||||
alterColumnStmt := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s %s NOT NULL;", expectedTable.Name, expectedColumn.Name, verb)
|
||||
|
||||
if alterStatements, ok := (*column).AlterToTarget(table, expectedColumn); ok {
|
||||
return newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, expectedColumn.Name),
|
||||
fmt.Sprintf("Unexpected properties of column %q.%q", expectedTable.Name, expectedColumn.Name),
|
||||
"change the column nullability constraint",
|
||||
).withDiff(expectedColumn, *column).withStatements(alterColumnStmt)
|
||||
expectedColumn.GetName(),
|
||||
fmt.Sprintf("Unexpected properties of column %s.%q", table.GetName(), expectedColumn.GetName()),
|
||||
"alter the column",
|
||||
).withStatements(
|
||||
alterStatements...,
|
||||
)
|
||||
}
|
||||
if equivIf(func(s *schemas.ColumnDescription) { s.Default = expectedColumn.Default }) {
|
||||
alterColumnStmt := fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s;", expectedTable.Name, expectedColumn.Name, expectedColumn.Default)
|
||||
|
||||
return newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, expectedColumn.Name),
|
||||
fmt.Sprintf("Unexpected properties of column %q.%q", expectedTable.Name, expectedColumn.Name),
|
||||
"change the column default",
|
||||
).withDiff(expectedColumn, *column).withStatements(alterColumnStmt)
|
||||
}
|
||||
|
||||
url := makeSearchURL(schemaName, version,
|
||||
fmt.Sprintf("CREATE TABLE %s", expectedTable.Name),
|
||||
fmt.Sprintf("ALTER TABLE ONLY %s", expectedTable.Name),
|
||||
)
|
||||
|
||||
return newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, expectedColumn.Name),
|
||||
fmt.Sprintf("Unexpected properties of column %q.%q", expectedTable.Name, expectedColumn.Name),
|
||||
fmt.Sprintf("%q.%q", table.GetName(), expectedColumn.GetName()),
|
||||
fmt.Sprintf("Unexpected properties of column %q.%q", table.GetName(), expectedColumn.GetName()),
|
||||
"redefine the column",
|
||||
).withDiff(expectedColumn, *column).withURLHint(url)
|
||||
}, func(additional []schemas.ColumnDescription) []Summary {
|
||||
).withDiff(
|
||||
expectedColumn,
|
||||
*column,
|
||||
).withURLHint(
|
||||
makeSearchURL(schemaName, version,
|
||||
fmt.Sprintf("CREATE TABLE %s", table.GetName()),
|
||||
fmt.Sprintf("ALTER TABLE ONLY %s", table.GetName()),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func compareColumnsAdditionalCallbackFor(table schemas.TableDescription) func(_ []schemas.ColumnDescription) []Summary {
|
||||
return func(additional []schemas.ColumnDescription) []Summary {
|
||||
summaries := []Summary{}
|
||||
for _, column := range additional {
|
||||
alterColumnStmt := fmt.Sprintf("ALTER TABLE %s DROP COLUMN %s;", expectedTable.Name, column.Name)
|
||||
|
||||
summary := newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, column.Name),
|
||||
fmt.Sprintf("Unexpected column %q.%q", expectedTable.Name, column.Name),
|
||||
summaries = append(summaries, newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", table.GetName(), column.GetName()),
|
||||
fmt.Sprintf("Unexpected column %q.%q", table.GetName(), column.GetName()),
|
||||
"drop the column",
|
||||
).withStatements(alterColumnStmt)
|
||||
summaries = append(summaries, summary)
|
||||
).withStatements(
|
||||
column.DropStatement(table),
|
||||
))
|
||||
}
|
||||
|
||||
return summaries
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,36 +7,53 @@ import (
|
||||
)
|
||||
|
||||
func compareConstraints(actualTable, expectedTable schemas.TableDescription) []Summary {
|
||||
return compareNamedLists(actualTable.Constraints, expectedTable.Constraints, func(constraint *schemas.ConstraintDescription, expectedConstraint schemas.ConstraintDescription) Summary {
|
||||
createConstraintStmt := fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", expectedTable.Name, expectedConstraint.Name, expectedConstraint.ConstraintDefinition)
|
||||
dropConstraintStmt := fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s;", expectedTable.Name, expectedConstraint.Name)
|
||||
return compareNamedListsStrict(
|
||||
actualTable.Constraints,
|
||||
expectedTable.Constraints,
|
||||
compareConstraintsCallbackFor(expectedTable),
|
||||
compareConstraintsAdditionalCallbackFor(expectedTable),
|
||||
)
|
||||
}
|
||||
|
||||
func compareConstraintsCallbackFor(table schemas.TableDescription) func(_ *schemas.ConstraintDescription, _ schemas.ConstraintDescription) Summary {
|
||||
return func(constraint *schemas.ConstraintDescription, expectedConstraint schemas.ConstraintDescription) Summary {
|
||||
if constraint == nil {
|
||||
return newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, expectedConstraint.Name),
|
||||
fmt.Sprintf("Missing constraint %q.%q", expectedTable.Name, expectedConstraint.Name),
|
||||
fmt.Sprintf("%q.%q", table.GetName(), expectedConstraint.GetName()),
|
||||
fmt.Sprintf("Missing constraint %q.%q", table.GetName(), expectedConstraint.GetName()),
|
||||
"define the constraint",
|
||||
).withStatements(createConstraintStmt)
|
||||
).withStatements(
|
||||
expectedConstraint.CreateStatement(table),
|
||||
)
|
||||
}
|
||||
|
||||
return newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, expectedConstraint.Name),
|
||||
fmt.Sprintf("Unexpected properties of constraint %q.%q", expectedTable.Name, expectedConstraint.Name),
|
||||
fmt.Sprintf("%q.%q", table.GetName(), expectedConstraint.GetName()),
|
||||
fmt.Sprintf("Unexpected properties of constraint %q.%q", table.GetName(), expectedConstraint.GetName()),
|
||||
"redefine the constraint",
|
||||
).withDiff(expectedConstraint, *constraint).withStatements(dropConstraintStmt, createConstraintStmt)
|
||||
}, func(additional []schemas.ConstraintDescription) []Summary {
|
||||
).withDiff(
|
||||
expectedConstraint,
|
||||
*constraint,
|
||||
).withStatements(
|
||||
expectedConstraint.DropStatement(table),
|
||||
expectedConstraint.CreateStatement(table),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func compareConstraintsAdditionalCallbackFor(table schemas.TableDescription) func(_ []schemas.ConstraintDescription) []Summary {
|
||||
return func(additional []schemas.ConstraintDescription) []Summary {
|
||||
summaries := []Summary{}
|
||||
for _, constraint := range additional {
|
||||
alterTableStmt := fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s;", expectedTable.Name, constraint.Name)
|
||||
|
||||
summary := newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, constraint.Name),
|
||||
fmt.Sprintf("Unexpected constraint %q.%q", expectedTable.Name, constraint.Name),
|
||||
summaries = append(summaries, newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", table.GetName(), constraint.GetName()),
|
||||
fmt.Sprintf("Unexpected constraint %q.%q", table.GetName(), constraint.GetName()),
|
||||
"drop the constraint",
|
||||
).withStatements(alterTableStmt)
|
||||
summaries = append(summaries, summary)
|
||||
).withStatements(
|
||||
constraint.DropStatement(table),
|
||||
))
|
||||
}
|
||||
|
||||
return summaries
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,120 +2,44 @@ package drift
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
|
||||
)
|
||||
|
||||
func compareEnums(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
|
||||
return compareNamedLists(actual.Enums, expected.Enums, func(enum *schemas.EnumDescription, expectedEnum schemas.EnumDescription) Summary {
|
||||
quotedLabels := make([]string, 0, len(expectedEnum.Labels))
|
||||
for _, label := range expectedEnum.Labels {
|
||||
quotedLabels = append(quotedLabels, fmt.Sprintf("'%s'", label))
|
||||
}
|
||||
createEnumStmt := fmt.Sprintf("CREATE TYPE %s AS ENUM (%s);", expectedEnum.Name, strings.Join(quotedLabels, ", "))
|
||||
dropEnumStmt := fmt.Sprintf("DROP TYPE %s;", expectedEnum.Name)
|
||||
|
||||
if enum == nil {
|
||||
return newDriftSummary(
|
||||
expectedEnum.Name,
|
||||
fmt.Sprintf("Missing enum %q", expectedEnum.Name),
|
||||
"create the type",
|
||||
).withStatements(createEnumStmt)
|
||||
}
|
||||
|
||||
if ordered, ok := constructEnumRepairStatements(*enum, expectedEnum); ok {
|
||||
return newDriftSummary(
|
||||
expectedEnum.Name,
|
||||
fmt.Sprintf("Missing %d labels for enum %q", len(ordered), expectedEnum.Name),
|
||||
"add the missing enum labels",
|
||||
).withStatements(ordered...)
|
||||
}
|
||||
return compareNamedLists(actual.Enums, expected.Enums, compareEnumsCallback)
|
||||
}
|
||||
|
||||
func compareEnumsCallback(enum *schemas.EnumDescription, expectedEnum schemas.EnumDescription) Summary {
|
||||
if enum == nil {
|
||||
return newDriftSummary(
|
||||
expectedEnum.Name,
|
||||
fmt.Sprintf("Unexpected labels for enum %q", expectedEnum.Name),
|
||||
"drop and re-create the type",
|
||||
).withDiff(enum.Labels, expectedEnum.Labels).withStatements(dropEnumStmt, createEnumStmt)
|
||||
}, noopAdditionalCallback[schemas.EnumDescription])
|
||||
}
|
||||
|
||||
// constructEnumRepairStatements returns a set of `ALTER ENUM ADD VALUE` statements to make
|
||||
// the given enum equivalent to the given expected enum. If the given enum is not a subset of
|
||||
// the expected enum, then additive statements cannot bring the enum to the expected state and
|
||||
// we return a false-valued flag. In this case the existing type must be dropped and re-created
|
||||
// as there's currently no way to *remove* values from an enum type.
|
||||
func constructEnumRepairStatements(enum, expectedEnum schemas.EnumDescription) ([]string, bool) {
|
||||
labels := groupByName(wrapStrings(enum.Labels))
|
||||
expectedLabels := groupByName(wrapStrings(expectedEnum.Labels))
|
||||
|
||||
for label := range labels {
|
||||
if _, ok := expectedLabels[label]; !ok {
|
||||
return nil, false
|
||||
}
|
||||
expectedEnum.GetName(),
|
||||
fmt.Sprintf("Missing enum %q", expectedEnum.GetName()),
|
||||
"define the type",
|
||||
).withStatements(
|
||||
expectedEnum.CreateStatement(),
|
||||
)
|
||||
}
|
||||
|
||||
// If we're here then we're strictly missing labels and can add them in-place.
|
||||
// Try to reconstruct the data we need to make the proper create type statement.
|
||||
|
||||
type missingLabel struct {
|
||||
label string
|
||||
neighbor string
|
||||
before bool
|
||||
}
|
||||
missingLabels := make([]missingLabel, 0, len(expectedEnum.Labels))
|
||||
|
||||
after := ""
|
||||
for _, label := range expectedEnum.Labels {
|
||||
if _, ok := labels[label]; !ok && after != "" {
|
||||
missingLabels = append(missingLabels, missingLabel{label: label, neighbor: after, before: false})
|
||||
}
|
||||
after = label
|
||||
if alterStatements, ok := (*enum).AlterToTarget(expectedEnum); ok {
|
||||
return newDriftSummary(
|
||||
expectedEnum.GetName(),
|
||||
fmt.Sprintf("Unexpected properties of enum %q", expectedEnum.GetName()),
|
||||
"alter the type",
|
||||
).withStatements(
|
||||
alterStatements...,
|
||||
)
|
||||
}
|
||||
|
||||
before := ""
|
||||
for i := len(expectedEnum.Labels) - 1; i >= 0; i-- {
|
||||
label := expectedEnum.Labels[i]
|
||||
|
||||
if _, ok := labels[label]; !ok && before != "" {
|
||||
missingLabels = append(missingLabels, missingLabel{label: label, neighbor: before, before: true})
|
||||
}
|
||||
before = label
|
||||
}
|
||||
|
||||
var (
|
||||
ordered []string
|
||||
reachable = groupByName(wrapStrings(enum.Labels))
|
||||
return newDriftSummary(
|
||||
expectedEnum.GetName(),
|
||||
fmt.Sprintf("Unexpected properties of enum %q", expectedEnum.GetName()),
|
||||
"redefine the type",
|
||||
).withDiff(
|
||||
expectedEnum.Labels,
|
||||
enum.Labels,
|
||||
).withStatements(
|
||||
expectedEnum.DropStatement(),
|
||||
expectedEnum.CreateStatement(),
|
||||
)
|
||||
|
||||
outer:
|
||||
for len(missingLabels) > 0 {
|
||||
for _, s := range missingLabels {
|
||||
// Neighbor doesn't exist yet, blocked from creating
|
||||
if _, ok := reachable[s.neighbor]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
rel := "AFTER"
|
||||
if s.before {
|
||||
rel = "BEFORE"
|
||||
}
|
||||
|
||||
filtered := missingLabels[:0]
|
||||
for _, l := range missingLabels {
|
||||
if l.label != s.label {
|
||||
filtered = append(filtered, l)
|
||||
}
|
||||
}
|
||||
|
||||
missingLabels = filtered
|
||||
reachable[s.label] = stringNamer(s.label)
|
||||
ordered = append(ordered, fmt.Sprintf("ALTER TYPE %s ADD VALUE '%s' %s '%s';", expectedEnum.Name, s.label, rel, s.neighbor))
|
||||
continue outer
|
||||
}
|
||||
|
||||
panic("Infinite loop")
|
||||
}
|
||||
|
||||
return ordered, true
|
||||
}
|
||||
|
||||
@ -7,17 +7,19 @@ import (
|
||||
)
|
||||
|
||||
func compareExtensions(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
|
||||
return compareNamedLists(wrapStrings(actual.Extensions), wrapStrings(expected.Extensions), func(extension *stringNamer, expectedExtension stringNamer) Summary {
|
||||
if extension == nil {
|
||||
createExtensionStmt := fmt.Sprintf("CREATE EXTENSION %s;", expectedExtension)
|
||||
|
||||
return newDriftSummary(
|
||||
expectedExtension.GetName(),
|
||||
fmt.Sprintf("Missing extension %q", expectedExtension),
|
||||
"install the extension",
|
||||
).withStatements(createExtensionStmt)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, noopAdditionalCallback[stringNamer])
|
||||
return compareNamedLists(actual.WrappedExtensions(), expected.WrappedExtensions(), compareExtensionsCallback)
|
||||
}
|
||||
|
||||
func compareExtensionsCallback(extension *schemas.ExtensionDescription, expectedExtension schemas.ExtensionDescription) Summary {
|
||||
if extension == nil {
|
||||
return newDriftSummary(
|
||||
expectedExtension.GetName(),
|
||||
fmt.Sprintf("Missing extension %q", expectedExtension.GetName()),
|
||||
"define the extension",
|
||||
).withStatements(
|
||||
expectedExtension.CreateStatement(),
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -2,27 +2,33 @@ package drift
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
|
||||
)
|
||||
|
||||
func compareFunctions(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
|
||||
return compareNamedLists(actual.Functions, expected.Functions, func(function *schemas.FunctionDescription, expectedFunction schemas.FunctionDescription) Summary {
|
||||
definitionStmt := fmt.Sprintf("%s;", strings.TrimSpace(expectedFunction.Definition))
|
||||
|
||||
if function == nil {
|
||||
return newDriftSummary(
|
||||
expectedFunction.Name,
|
||||
fmt.Sprintf("Missing function %q", expectedFunction.Name),
|
||||
"define the function",
|
||||
).withStatements(definitionStmt)
|
||||
}
|
||||
|
||||
return newDriftSummary(
|
||||
expectedFunction.Name,
|
||||
fmt.Sprintf("Unexpected definition of function %q", expectedFunction.Name),
|
||||
"replace the function definition",
|
||||
).withDiff(expectedFunction.Definition, function.Definition).withStatements(definitionStmt)
|
||||
}, noopAdditionalCallback[schemas.FunctionDescription])
|
||||
return compareNamedLists(actual.Functions, expected.Functions, compareFunctionsCallback)
|
||||
}
|
||||
|
||||
func compareFunctionsCallback(function *schemas.FunctionDescription, expectedFunction schemas.FunctionDescription) Summary {
|
||||
if function == nil {
|
||||
return newDriftSummary(
|
||||
expectedFunction.GetName(),
|
||||
fmt.Sprintf("Missing function %q", expectedFunction.GetName()),
|
||||
"define the function",
|
||||
).withStatements(
|
||||
expectedFunction.CreateOrReplaceStatement(),
|
||||
)
|
||||
}
|
||||
|
||||
return newDriftSummary(
|
||||
expectedFunction.GetName(),
|
||||
fmt.Sprintf("Unexpected definition of function %q", expectedFunction.GetName()),
|
||||
"redefine the function",
|
||||
).withDiff(
|
||||
expectedFunction.Definition,
|
||||
function.Definition,
|
||||
).withStatements(
|
||||
expectedFunction.CreateOrReplaceStatement(),
|
||||
)
|
||||
}
|
||||
|
||||
@ -7,45 +7,53 @@ import (
|
||||
)
|
||||
|
||||
func compareIndexes(actualTable, expectedTable schemas.TableDescription) []Summary {
|
||||
return compareNamedLists(actualTable.Indexes, expectedTable.Indexes, func(index *schemas.IndexDescription, expectedIndex schemas.IndexDescription) Summary {
|
||||
var createIndexStmt string
|
||||
switch expectedIndex.ConstraintType {
|
||||
case "u":
|
||||
fallthrough
|
||||
case "p":
|
||||
createIndexStmt = fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", actualTable.Name, expectedIndex.Name, expectedIndex.ConstraintDefinition)
|
||||
default:
|
||||
createIndexStmt = fmt.Sprintf("%s;", expectedIndex.IndexDefinition)
|
||||
}
|
||||
return compareNamedListsStrict(
|
||||
actualTable.Indexes,
|
||||
expectedTable.Indexes,
|
||||
compareIndexesCallbackFor(expectedTable),
|
||||
compareIndexesCallbackAdditionalFor(expectedTable),
|
||||
)
|
||||
}
|
||||
|
||||
func compareIndexesCallbackFor(table schemas.TableDescription) func(_ *schemas.IndexDescription, _ schemas.IndexDescription) Summary {
|
||||
return func(index *schemas.IndexDescription, expectedIndex schemas.IndexDescription) Summary {
|
||||
if index == nil {
|
||||
return newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, expectedIndex.Name),
|
||||
fmt.Sprintf("Missing index %q.%q", expectedTable.Name, expectedIndex.Name),
|
||||
fmt.Sprintf("%q.%q", table.GetName(), expectedIndex.GetName()),
|
||||
fmt.Sprintf("Missing index %q.%q", table.GetName(), expectedIndex.GetName()),
|
||||
"define the index",
|
||||
).withStatements(createIndexStmt)
|
||||
).withStatements(
|
||||
expectedIndex.CreateStatement(table),
|
||||
)
|
||||
}
|
||||
|
||||
dropIndexStmt := fmt.Sprintf("DROP INDEX %s;", expectedIndex.Name)
|
||||
|
||||
return newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, expectedIndex.Name),
|
||||
fmt.Sprintf("Unexpected properties of index %q.%q", expectedTable.Name, expectedIndex.Name),
|
||||
fmt.Sprintf("%q.%q", table.GetName(), expectedIndex.GetName()),
|
||||
fmt.Sprintf("Unexpected properties of index %q.%q", table.GetName(), expectedIndex.GetName()),
|
||||
"redefine the index",
|
||||
).withDiff(expectedIndex, *index).withStatements(dropIndexStmt, createIndexStmt)
|
||||
}, func(additional []schemas.IndexDescription) []Summary {
|
||||
).withDiff(
|
||||
expectedIndex,
|
||||
*index,
|
||||
).withStatements(
|
||||
expectedIndex.DropStatement(),
|
||||
expectedIndex.CreateStatement(table),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func compareIndexesCallbackAdditionalFor(table schemas.TableDescription) func(_ []schemas.IndexDescription) []Summary {
|
||||
return func(additional []schemas.IndexDescription) []Summary {
|
||||
summaries := []Summary{}
|
||||
for _, index := range additional {
|
||||
dropIndexStmt := fmt.Sprintf("DROP INDEX %s;", index.Name)
|
||||
|
||||
summary := newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, index.Name),
|
||||
fmt.Sprintf("Unexpected index %q.%q", expectedTable.Name, index.Name),
|
||||
summaries = append(summaries, newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", table.GetName(), index.GetName()),
|
||||
fmt.Sprintf("Unexpected index %q.%q", table.GetName(), index.GetName()),
|
||||
"drop the index",
|
||||
).withStatements(dropIndexStmt)
|
||||
summaries = append(summaries, summary)
|
||||
).withStatements(
|
||||
index.DropStatement(),
|
||||
))
|
||||
}
|
||||
|
||||
return summaries
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,24 +7,43 @@ import (
|
||||
)
|
||||
|
||||
func compareSequences(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
|
||||
return compareNamedLists(actual.Sequences, expected.Sequences, func(sequence *schemas.SequenceDescription, expectedSequence schemas.SequenceDescription) Summary {
|
||||
definitionStmt := makeSearchURL(schemaName, version,
|
||||
fmt.Sprintf("CREATE SEQUENCE %s", expectedSequence.Name),
|
||||
fmt.Sprintf("nextval('%s'::regclass);", expectedSequence.Name),
|
||||
)
|
||||
return compareNamedLists(actual.Sequences, expected.Sequences, compareSequencesCallbackFor(schemaName, version))
|
||||
}
|
||||
|
||||
func compareSequencesCallbackFor(schemaName, version string) func(_ *schemas.SequenceDescription, _ schemas.SequenceDescription) Summary {
|
||||
return func(sequence *schemas.SequenceDescription, expectedSequence schemas.SequenceDescription) Summary {
|
||||
if sequence == nil {
|
||||
return newDriftSummary(
|
||||
expectedSequence.Name,
|
||||
fmt.Sprintf("Missing sequence %q", expectedSequence.Name),
|
||||
expectedSequence.GetName(),
|
||||
fmt.Sprintf("Missing sequence %q", expectedSequence.GetName()),
|
||||
"define the sequence",
|
||||
).withStatements(definitionStmt)
|
||||
).withStatements(
|
||||
expectedSequence.CreateStatement(),
|
||||
)
|
||||
}
|
||||
|
||||
if alterStatements, ok := (*sequence).AlterToTarget(expectedSequence); ok {
|
||||
return newDriftSummary(
|
||||
expectedSequence.GetName(),
|
||||
fmt.Sprintf("Unexpected properties of sequence %q", expectedSequence.GetName()),
|
||||
"alter the sequence",
|
||||
).withStatements(
|
||||
alterStatements...,
|
||||
)
|
||||
}
|
||||
|
||||
return newDriftSummary(
|
||||
expectedSequence.Name,
|
||||
fmt.Sprintf("Unexpected properties of sequence %q", expectedSequence.Name),
|
||||
expectedSequence.GetName(),
|
||||
fmt.Sprintf("Unexpected properties of sequence %q", expectedSequence.GetName()),
|
||||
"redefine the sequence",
|
||||
).withDiff(expectedSequence, *sequence).withStatements(definitionStmt)
|
||||
}, noopAdditionalCallback[schemas.SequenceDescription])
|
||||
).withDiff(
|
||||
expectedSequence,
|
||||
*sequence,
|
||||
).withURLHint(
|
||||
makeSearchURL(schemaName, version,
|
||||
fmt.Sprintf("CREATE SEQUENCE %s", expectedSequence.GetName()),
|
||||
fmt.Sprintf("nextval('%s'::regclass);", expectedSequence.GetName()),
|
||||
),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,19 +7,23 @@ import (
|
||||
)
|
||||
|
||||
func compareTables(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
|
||||
return compareNamedListsMulti(actual.Tables, expected.Tables, func(table *schemas.TableDescription, expectedTable schemas.TableDescription) []Summary {
|
||||
if table == nil {
|
||||
url := makeSearchURL(schemaName, version,
|
||||
fmt.Sprintf("CREATE TABLE %s", expectedTable.Name),
|
||||
fmt.Sprintf("ALTER TABLE ONLY %s", expectedTable.Name),
|
||||
fmt.Sprintf("CREATE .*(INDEX|TRIGGER).* ON %s", expectedTable.Name),
|
||||
)
|
||||
return compareNamedListsMulti(actual.Tables, expected.Tables, compareTablesCallbackFor(schemaName, version))
|
||||
}
|
||||
|
||||
func compareTablesCallbackFor(schemaName, version string) func(_ *schemas.TableDescription, _ schemas.TableDescription) []Summary {
|
||||
return func(table *schemas.TableDescription, expectedTable schemas.TableDescription) []Summary {
|
||||
if table == nil {
|
||||
return singleton(newDriftSummary(
|
||||
expectedTable.Name,
|
||||
fmt.Sprintf("Missing table %q", expectedTable.Name),
|
||||
expectedTable.GetName(),
|
||||
fmt.Sprintf("Missing table %q", expectedTable.GetName()),
|
||||
"define the table",
|
||||
).withURLHint(url))
|
||||
).withURLHint(
|
||||
makeSearchURL(schemaName, version,
|
||||
fmt.Sprintf("CREATE TABLE %s", expectedTable.GetName()),
|
||||
fmt.Sprintf("ALTER TABLE ONLY %s", expectedTable.GetName()),
|
||||
fmt.Sprintf("CREATE .*(INDEX|TRIGGER).* ON %s", expectedTable.GetName()),
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
summaries := []Summary(nil)
|
||||
@ -28,5 +32,5 @@ func compareTables(schemaName, version string, actual, expected schemas.SchemaDe
|
||||
summaries = append(summaries, compareIndexes(*table, expectedTable)...)
|
||||
summaries = append(summaries, compareTriggers(*table, expectedTable)...)
|
||||
return summaries
|
||||
}, noopAdditionalCallback[schemas.TableDescription])
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,36 +7,53 @@ import (
|
||||
)
|
||||
|
||||
func compareTriggers(actualTable, expectedTable schemas.TableDescription) []Summary {
|
||||
return compareNamedLists(actualTable.Triggers, expectedTable.Triggers, func(trigger *schemas.TriggerDescription, expectedTrigger schemas.TriggerDescription) Summary {
|
||||
createTriggerStmt := fmt.Sprintf("%s;", expectedTrigger.Definition)
|
||||
dropTriggerStmt := fmt.Sprintf("DROP TRIGGER %s ON %s;", expectedTrigger.Name, expectedTable.Name)
|
||||
return compareNamedListsStrict(
|
||||
actualTable.Triggers,
|
||||
expectedTable.Triggers,
|
||||
compareNamedListsCallbackFor(expectedTable),
|
||||
compareNamedListsAdditionalCallbackFor(expectedTable),
|
||||
)
|
||||
}
|
||||
|
||||
func compareNamedListsCallbackFor(table schemas.TableDescription) func(_ *schemas.TriggerDescription, _ schemas.TriggerDescription) Summary {
|
||||
return func(trigger *schemas.TriggerDescription, expectedTrigger schemas.TriggerDescription) Summary {
|
||||
if trigger == nil {
|
||||
return newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, expectedTrigger.Name),
|
||||
fmt.Sprintf("Missing trigger %q.%q", expectedTable.Name, expectedTrigger.Name),
|
||||
fmt.Sprintf("%q.%q", table.GetName(), expectedTrigger.GetName()),
|
||||
fmt.Sprintf("Missing trigger %q.%q", table.GetName(), expectedTrigger.GetName()),
|
||||
"define the trigger",
|
||||
).withStatements(createTriggerStmt)
|
||||
).withStatements(
|
||||
expectedTrigger.CreateStatement(),
|
||||
)
|
||||
}
|
||||
|
||||
return newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, expectedTrigger.Name),
|
||||
fmt.Sprintf("Unexpected properties of trigger %q.%q", expectedTable.Name, expectedTrigger.Name),
|
||||
fmt.Sprintf("%q.%q", table.GetName(), expectedTrigger.GetName()),
|
||||
fmt.Sprintf("Unexpected properties of trigger %q.%q", table.GetName(), expectedTrigger.GetName()),
|
||||
"redefine the trigger",
|
||||
).withDiff(expectedTrigger, *trigger).withStatements(dropTriggerStmt, createTriggerStmt)
|
||||
}, func(additional []schemas.TriggerDescription) []Summary {
|
||||
).withDiff(
|
||||
expectedTrigger,
|
||||
*trigger,
|
||||
).withStatements(
|
||||
expectedTrigger.DropStatement(table),
|
||||
expectedTrigger.CreateStatement(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func compareNamedListsAdditionalCallbackFor(table schemas.TableDescription) func(_ []schemas.TriggerDescription) []Summary {
|
||||
return func(additional []schemas.TriggerDescription) []Summary {
|
||||
summaries := []Summary{}
|
||||
for _, trigger := range additional {
|
||||
dropTriggerStmt := fmt.Sprintf("DROP TRIGGER %s ON %s;", trigger.Name, expectedTable.Name)
|
||||
|
||||
summary := newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", expectedTable.Name, trigger.Name),
|
||||
fmt.Sprintf("Unexpected trigger %q.%q", expectedTable.Name, trigger.Name),
|
||||
summaries = append(summaries, newDriftSummary(
|
||||
fmt.Sprintf("%q.%q", table.GetName(), trigger.GetName()),
|
||||
fmt.Sprintf("Unexpected trigger %q.%q", table.GetName(), trigger.GetName()),
|
||||
"drop the trigger",
|
||||
).withStatements(dropTriggerStmt)
|
||||
summaries = append(summaries, summary)
|
||||
).withStatements(
|
||||
trigger.DropStatement(table),
|
||||
))
|
||||
}
|
||||
|
||||
return summaries
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -2,30 +2,34 @@ package drift
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
|
||||
)
|
||||
|
||||
func compareViews(schemaName, version string, actual, expected schemas.SchemaDescription) []Summary {
|
||||
return compareNamedLists(actual.Views, expected.Views, func(view *schemas.ViewDescription, expectedView schemas.ViewDescription) Summary {
|
||||
// pgsql has weird indents here
|
||||
viewDefinition := strings.TrimSpace(stripIndent(" " + expectedView.Definition))
|
||||
createViewStmt := fmt.Sprintf("CREATE VIEW %s AS %s", expectedView.Name, viewDefinition)
|
||||
dropViewStmt := fmt.Sprintf("DROP VIEW %s;", expectedView.Name)
|
||||
|
||||
if view == nil {
|
||||
return newDriftSummary(
|
||||
expectedView.Name,
|
||||
fmt.Sprintf("Missing view %q", expectedView.Name),
|
||||
"define the view",
|
||||
).withStatements(createViewStmt)
|
||||
}
|
||||
|
||||
return newDriftSummary(
|
||||
expectedView.Name,
|
||||
fmt.Sprintf("Unexpected definition of view %q", expectedView.Name),
|
||||
"redefine the view",
|
||||
).withDiff(expectedView.Definition, view.Definition).withStatements(dropViewStmt, createViewStmt)
|
||||
}, noopAdditionalCallback[schemas.ViewDescription])
|
||||
return compareNamedLists(actual.Views, expected.Views, compareViewsCallback)
|
||||
}
|
||||
|
||||
func compareViewsCallback(view *schemas.ViewDescription, expectedView schemas.ViewDescription) Summary {
|
||||
if view == nil {
|
||||
return newDriftSummary(
|
||||
expectedView.GetName(),
|
||||
fmt.Sprintf("Missing view %q", expectedView.GetName()),
|
||||
"define the view",
|
||||
).withStatements(
|
||||
expectedView.CreateStatement(),
|
||||
)
|
||||
}
|
||||
|
||||
return newDriftSummary(
|
||||
expectedView.GetName(),
|
||||
fmt.Sprintf("Unexpected definition of view %q", expectedView.GetName()),
|
||||
"redefine the view",
|
||||
).withDiff(
|
||||
expectedView.Definition,
|
||||
view.Definition,
|
||||
).withStatements(
|
||||
expectedView.DropStatement(),
|
||||
expectedView.CreateStatement(),
|
||||
)
|
||||
}
|
||||
|
||||
@ -1,46 +0,0 @@
|
||||
package drift
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/lib/output"
|
||||
)
|
||||
|
||||
type ConsoleFormatter struct {
|
||||
out OutputWriter
|
||||
}
|
||||
|
||||
type OutputWriter interface {
|
||||
Write(s string)
|
||||
Writef(format string, args ...any)
|
||||
WriteLine(line output.FancyLine)
|
||||
WriteCode(languageName, str string) error
|
||||
}
|
||||
|
||||
func NewConsoleFormatter(out OutputWriter) *ConsoleFormatter {
|
||||
return &ConsoleFormatter{out: out}
|
||||
}
|
||||
|
||||
func (f *ConsoleFormatter) Display(s Summary) {
|
||||
f.out.WriteLine(output.Line(output.EmojiFailure, output.StyleBold, s.Problem()))
|
||||
|
||||
if a, b, ok := s.Diff(); ok {
|
||||
_ = f.out.WriteCode("diff", strings.TrimSpace(cmp.Diff(a, b)))
|
||||
}
|
||||
|
||||
f.out.WriteLine(output.Line(output.EmojiLightbulb, output.StyleItalic, fmt.Sprintf("Suggested action: %s.", s.Solution())))
|
||||
|
||||
if statements, ok := s.Statements(); ok {
|
||||
_ = f.out.WriteCode("sql", strings.Join(statements, "\n"))
|
||||
}
|
||||
|
||||
if urlHint, ok := s.URLHint(); ok {
|
||||
f.out.WriteLine(output.Line(output.EmojiLightbulb, output.StyleItalic, fmt.Sprintf("Hint: Reproduce %s as defined at the following URL:", s.Name())))
|
||||
f.out.Write("")
|
||||
f.out.WriteLine(output.Line(output.EmojiFingerPointRight, output.StyleUnderline, urlHint))
|
||||
f.out.Write("")
|
||||
}
|
||||
}
|
||||
@ -1,29 +0,0 @@
|
||||
package drift
|
||||
|
||||
import (
|
||||
"sort"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/migration/schemas"
|
||||
)
|
||||
|
||||
// keys returns the ordered keys of the given map.
|
||||
func keys[T any](m map[string]T) []string {
|
||||
keys := make([]string, 0, len(m))
|
||||
for k := range m {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// groupByName converts the given element slice into a map indexed by
|
||||
// each element's name.
|
||||
func groupByName[T schemas.Namer](ts []T) map[string]T {
|
||||
m := make(map[string]T, len(ts))
|
||||
for _, t := range ts {
|
||||
m[t.GetName()] = t
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
@ -7,16 +7,6 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
// quoteTerm converts the given literal search term into a regular expression.
|
||||
func quoteTerm(searchTerm string) string {
|
||||
terms := strings.Split(searchTerm, " ")
|
||||
for i, term := range terms {
|
||||
terms[i] = regexp.QuoteMeta(term)
|
||||
}
|
||||
|
||||
return "(^|\\b)" + strings.Join(terms, "\\s") + "($|\\b)"
|
||||
}
|
||||
|
||||
// makeSearchURL returns a URL to a sourcegraph.com search query within the squashed
|
||||
// definition of the given schema.
|
||||
func makeSearchURL(schemaName, version string, searchTerms ...string) string {
|
||||
@ -39,3 +29,13 @@ func makeSearchURL(schemaName, version string, searchTerms ...string) string {
|
||||
searchUrl.RawQuery = qs.Encode()
|
||||
return searchUrl.String()
|
||||
}
|
||||
|
||||
// quoteTerm converts the given literal search term into a regular expression.
|
||||
func quoteTerm(searchTerm string) string {
|
||||
terms := strings.Split(searchTerm, " ")
|
||||
for i, term := range terms {
|
||||
terms[i] = regexp.QuoteMeta(term)
|
||||
}
|
||||
|
||||
return "(^|\\b)" + strings.Join(terms, "\\s") + "($|\\b)"
|
||||
}
|
||||
|
||||
@ -1,37 +0,0 @@
|
||||
package drift
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type stringNamer string
|
||||
|
||||
func (s stringNamer) GetName() string { return string(s) }
|
||||
|
||||
// wrapStrings converts a string slice into a string slice with GetName
|
||||
// on each element.
|
||||
func wrapStrings(ss []string) []stringNamer {
|
||||
sn := make([]stringNamer, 0, len(ss))
|
||||
for _, s := range ss {
|
||||
sn = append(sn, stringNamer(s))
|
||||
}
|
||||
|
||||
return sn
|
||||
}
|
||||
|
||||
// stripIndent removes the largest common indent from the given text.
|
||||
func stripIndent(s string) string {
|
||||
lines := strings.Split(strings.TrimRight(s, "\n"), "\n")
|
||||
|
||||
min := len(lines[0])
|
||||
for _, line := range lines {
|
||||
if indent := len(line) - len(strings.TrimLeft(line, " ")); indent < min {
|
||||
min = indent
|
||||
}
|
||||
}
|
||||
for i, line := range lines {
|
||||
lines[i] = line[min:]
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
@ -13,6 +13,7 @@ type Store interface {
|
||||
Done(err error) error
|
||||
|
||||
Versions(ctx context.Context) (appliedVersions, pendingVersions, failedVersions []int, _ error)
|
||||
RunDDLStatements(ctx context.Context, statements []string) error
|
||||
TryLock(ctx context.Context) (bool, func(err error) error, error)
|
||||
Up(ctx context.Context, migration definition.Definition) error
|
||||
Down(ctx context.Context, migration definition.Definition) error
|
||||
|
||||
121
internal/database/migration/runner/mocks_test.go
generated
121
internal/database/migration/runner/mocks_test.go
generated
@ -32,6 +32,9 @@ type MockStore struct {
|
||||
// IndexStatusFunc is an instance of a mock function object controlling
|
||||
// the behavior of the method IndexStatus.
|
||||
IndexStatusFunc *StoreIndexStatusFunc
|
||||
// RunDDLStatementsFunc is an instance of a mock function object
|
||||
// controlling the behavior of the method RunDDLStatements.
|
||||
RunDDLStatementsFunc *StoreRunDDLStatementsFunc
|
||||
// TransactFunc is an instance of a mock function object controlling the
|
||||
// behavior of the method Transact.
|
||||
TransactFunc *StoreTransactFunc
|
||||
@ -73,6 +76,11 @@ func NewMockStore() *MockStore {
|
||||
return
|
||||
},
|
||||
},
|
||||
RunDDLStatementsFunc: &StoreRunDDLStatementsFunc{
|
||||
defaultHook: func(context.Context, []string) (r0 error) {
|
||||
return
|
||||
},
|
||||
},
|
||||
TransactFunc: &StoreTransactFunc{
|
||||
defaultHook: func(context.Context) (r0 Store, r1 error) {
|
||||
return
|
||||
@ -125,6 +133,11 @@ func NewStrictMockStore() *MockStore {
|
||||
panic("unexpected invocation of MockStore.IndexStatus")
|
||||
},
|
||||
},
|
||||
RunDDLStatementsFunc: &StoreRunDDLStatementsFunc{
|
||||
defaultHook: func(context.Context, []string) error {
|
||||
panic("unexpected invocation of MockStore.RunDDLStatements")
|
||||
},
|
||||
},
|
||||
TransactFunc: &StoreTransactFunc{
|
||||
defaultHook: func(context.Context) (Store, error) {
|
||||
panic("unexpected invocation of MockStore.Transact")
|
||||
@ -169,6 +182,9 @@ func NewMockStoreFrom(i Store) *MockStore {
|
||||
IndexStatusFunc: &StoreIndexStatusFunc{
|
||||
defaultHook: i.IndexStatus,
|
||||
},
|
||||
RunDDLStatementsFunc: &StoreRunDDLStatementsFunc{
|
||||
defaultHook: i.RunDDLStatements,
|
||||
},
|
||||
TransactFunc: &StoreTransactFunc{
|
||||
defaultHook: i.Transact,
|
||||
},
|
||||
@ -609,6 +625,111 @@ func (c StoreIndexStatusFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0, c.Result1, c.Result2}
|
||||
}
|
||||
|
||||
// StoreRunDDLStatementsFunc describes the behavior when the
|
||||
// RunDDLStatements method of the parent MockStore instance is invoked.
|
||||
type StoreRunDDLStatementsFunc struct {
|
||||
defaultHook func(context.Context, []string) error
|
||||
hooks []func(context.Context, []string) error
|
||||
history []StoreRunDDLStatementsFuncCall
|
||||
mutex sync.Mutex
|
||||
}
|
||||
|
||||
// RunDDLStatements delegates to the next hook function in the queue and
|
||||
// stores the parameter and result values of this invocation.
|
||||
func (m *MockStore) RunDDLStatements(v0 context.Context, v1 []string) error {
|
||||
r0 := m.RunDDLStatementsFunc.nextHook()(v0, v1)
|
||||
m.RunDDLStatementsFunc.appendCall(StoreRunDDLStatementsFuncCall{v0, v1, r0})
|
||||
return r0
|
||||
}
|
||||
|
||||
// SetDefaultHook sets function that is called when the RunDDLStatements
|
||||
// method of the parent MockStore instance is invoked and the hook queue is
|
||||
// empty.
|
||||
func (f *StoreRunDDLStatementsFunc) SetDefaultHook(hook func(context.Context, []string) error) {
|
||||
f.defaultHook = hook
|
||||
}
|
||||
|
||||
// PushHook adds a function to the end of hook queue. Each invocation of the
|
||||
// RunDDLStatements method of the parent MockStore instance invokes the hook
|
||||
// at the front of the queue and discards it. After the queue is empty, the
|
||||
// default hook function is invoked for any future action.
|
||||
func (f *StoreRunDDLStatementsFunc) PushHook(hook func(context.Context, []string) error) {
|
||||
f.mutex.Lock()
|
||||
f.hooks = append(f.hooks, hook)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// SetDefaultReturn calls SetDefaultHook with a function that returns the
|
||||
// given values.
|
||||
func (f *StoreRunDDLStatementsFunc) SetDefaultReturn(r0 error) {
|
||||
f.SetDefaultHook(func(context.Context, []string) error {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
// PushReturn calls PushHook with a function that returns the given values.
|
||||
func (f *StoreRunDDLStatementsFunc) PushReturn(r0 error) {
|
||||
f.PushHook(func(context.Context, []string) error {
|
||||
return r0
|
||||
})
|
||||
}
|
||||
|
||||
func (f *StoreRunDDLStatementsFunc) nextHook() func(context.Context, []string) error {
|
||||
f.mutex.Lock()
|
||||
defer f.mutex.Unlock()
|
||||
|
||||
if len(f.hooks) == 0 {
|
||||
return f.defaultHook
|
||||
}
|
||||
|
||||
hook := f.hooks[0]
|
||||
f.hooks = f.hooks[1:]
|
||||
return hook
|
||||
}
|
||||
|
||||
func (f *StoreRunDDLStatementsFunc) appendCall(r0 StoreRunDDLStatementsFuncCall) {
|
||||
f.mutex.Lock()
|
||||
f.history = append(f.history, r0)
|
||||
f.mutex.Unlock()
|
||||
}
|
||||
|
||||
// History returns a sequence of StoreRunDDLStatementsFuncCall objects
|
||||
// describing the invocations of this function.
|
||||
func (f *StoreRunDDLStatementsFunc) History() []StoreRunDDLStatementsFuncCall {
|
||||
f.mutex.Lock()
|
||||
history := make([]StoreRunDDLStatementsFuncCall, len(f.history))
|
||||
copy(history, f.history)
|
||||
f.mutex.Unlock()
|
||||
|
||||
return history
|
||||
}
|
||||
|
||||
// StoreRunDDLStatementsFuncCall is an object that describes an invocation
|
||||
// of method RunDDLStatements on an instance of MockStore.
|
||||
type StoreRunDDLStatementsFuncCall struct {
|
||||
// Arg0 is the value of the 1st argument passed to this method
|
||||
// invocation.
|
||||
Arg0 context.Context
|
||||
// Arg1 is the value of the 2nd argument passed to this method
|
||||
// invocation.
|
||||
Arg1 []string
|
||||
// Result0 is the value of the 1st result returned from this method
|
||||
// invocation.
|
||||
Result0 error
|
||||
}
|
||||
|
||||
// Args returns an interface slice containing the arguments of this
|
||||
// invocation.
|
||||
func (c StoreRunDDLStatementsFuncCall) Args() []interface{} {
|
||||
return []interface{}{c.Arg0, c.Arg1}
|
||||
}
|
||||
|
||||
// Results returns an interface slice containing the results of this
|
||||
// invocation.
|
||||
func (c StoreRunDDLStatementsFuncCall) Results() []interface{} {
|
||||
return []interface{}{c.Result0}
|
||||
}
|
||||
|
||||
// StoreTransactFunc describes the behavior when the Transact method of the
|
||||
// parent MockStore instance is invoked.
|
||||
type StoreTransactFunc struct {
|
||||
|
||||
1
internal/database/migration/schemas/BUILD.bazel
generated
1
internal/database/migration/schemas/BUILD.bazel
generated
@ -19,6 +19,7 @@ go_library(
|
||||
"//internal/lazyregexp",
|
||||
"//lib/errors",
|
||||
"//migrations",
|
||||
"@com_github_google_go_cmp//cmp",
|
||||
],
|
||||
)
|
||||
|
||||
|
||||
@ -1,8 +1,11 @@
|
||||
package schemas
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/google/go-cmp/cmp"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/lazyregexp"
|
||||
)
|
||||
|
||||
@ -15,16 +18,129 @@ type SchemaDescription struct {
|
||||
Views []ViewDescription
|
||||
}
|
||||
|
||||
func (d SchemaDescription) WrappedExtensions() []ExtensionDescription {
|
||||
extensions := make([]ExtensionDescription, 0, len(d.Extensions))
|
||||
for _, name := range d.Extensions {
|
||||
extensions = append(extensions, ExtensionDescription{Name: name})
|
||||
}
|
||||
|
||||
return extensions
|
||||
}
|
||||
|
||||
type ExtensionDescription struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (d ExtensionDescription) CreateStatement() string {
|
||||
return fmt.Sprintf("CREATE EXTENSION %s;", d.Name)
|
||||
}
|
||||
|
||||
type EnumDescription struct {
|
||||
Name string
|
||||
Labels []string
|
||||
}
|
||||
|
||||
func (d EnumDescription) CreateStatement() string {
|
||||
quotedLabels := make([]string, 0, len(d.Labels))
|
||||
for _, label := range d.Labels {
|
||||
quotedLabels = append(quotedLabels, fmt.Sprintf("'%s'", label))
|
||||
}
|
||||
|
||||
return fmt.Sprintf("CREATE TYPE %s AS ENUM (%s);", d.Name, strings.Join(quotedLabels, ", "))
|
||||
}
|
||||
|
||||
func (d EnumDescription) DropStatement() string {
|
||||
return fmt.Sprintf("DROP TYPE IF EXISTS %s;", d.Name)
|
||||
}
|
||||
|
||||
// AlterToTarget returns a set of `ALTER ENUM ADD VALUE` statements to make the given enum equivalent to
|
||||
// the expected enum, then additive statements cannot bring the enum to the expected state and we return
|
||||
// a false-valued flag. In this case the existing type must be dropped and re-created as there's currently
|
||||
// no way to *remove* values from an enum type.
|
||||
func (d EnumDescription) AlterToTarget(target EnumDescription) ([]string, bool) {
|
||||
labels := GroupByName(wrapStrings(d.Labels))
|
||||
expectedLabels := GroupByName(wrapStrings(target.Labels))
|
||||
|
||||
for label := range labels {
|
||||
if _, ok := expectedLabels[label]; !ok {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
// If we're here then we're strictly missing labels and can add them in-place.
|
||||
// Try to reconstruct the data we need to make the proper create type statement.
|
||||
|
||||
type missingLabel struct {
|
||||
label string
|
||||
neighbor string
|
||||
before bool
|
||||
}
|
||||
missingLabels := make([]missingLabel, 0, len(target.Labels))
|
||||
|
||||
after := ""
|
||||
for _, label := range target.Labels {
|
||||
if _, ok := labels[label]; !ok && after != "" {
|
||||
missingLabels = append(missingLabels, missingLabel{label: label, neighbor: after, before: false})
|
||||
}
|
||||
after = label
|
||||
}
|
||||
|
||||
before := ""
|
||||
for i := len(target.Labels) - 1; i >= 0; i-- {
|
||||
label := target.Labels[i]
|
||||
|
||||
if _, ok := labels[label]; !ok && before != "" {
|
||||
missingLabels = append(missingLabels, missingLabel{label: label, neighbor: before, before: true})
|
||||
}
|
||||
before = label
|
||||
}
|
||||
|
||||
var (
|
||||
ordered []string
|
||||
reachable = GroupByName(wrapStrings(d.Labels))
|
||||
)
|
||||
|
||||
outer:
|
||||
for len(missingLabels) > 0 {
|
||||
for _, s := range missingLabels {
|
||||
// Neighbor doesn't exist yet, blocked from creating
|
||||
if _, ok := reachable[s.neighbor]; !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
rel := "AFTER"
|
||||
if s.before {
|
||||
rel = "BEFORE"
|
||||
}
|
||||
|
||||
filtered := missingLabels[:0]
|
||||
for _, l := range missingLabels {
|
||||
if l.label != s.label {
|
||||
filtered = append(filtered, l)
|
||||
}
|
||||
}
|
||||
|
||||
missingLabels = filtered
|
||||
reachable[s.label] = stringNamer(s.label)
|
||||
ordered = append(ordered, fmt.Sprintf("ALTER TYPE %s ADD VALUE '%s' %s '%s';", target.GetName(), s.label, rel, s.neighbor))
|
||||
continue outer
|
||||
}
|
||||
|
||||
panic("Infinite loop")
|
||||
}
|
||||
|
||||
return ordered, true
|
||||
}
|
||||
|
||||
type FunctionDescription struct {
|
||||
Name string
|
||||
Definition string
|
||||
}
|
||||
|
||||
func (d FunctionDescription) CreateOrReplaceStatement() string {
|
||||
return fmt.Sprintf("%s;", d.Definition)
|
||||
}
|
||||
|
||||
type SequenceDescription struct {
|
||||
Name string
|
||||
TypeName string
|
||||
@ -35,6 +151,44 @@ type SequenceDescription struct {
|
||||
CycleOption string
|
||||
}
|
||||
|
||||
func (d SequenceDescription) CreateStatement() string {
|
||||
minValue := "NO MINVALUE"
|
||||
if d.MinimumValue != 0 {
|
||||
minValue = fmt.Sprintf("MINVALUE %d", d.MinimumValue)
|
||||
}
|
||||
maxValue := "NO MAXVALUE"
|
||||
if d.MaximumValue != 0 {
|
||||
maxValue = fmt.Sprintf("MAXVALUE %d", d.MaximumValue)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(
|
||||
"CREATE SEQUENCE %s AS %s INCREMENT BY %d %s %s START WITH %d %s CYCLE;",
|
||||
d.Name,
|
||||
d.TypeName,
|
||||
d.Increment,
|
||||
minValue,
|
||||
maxValue,
|
||||
d.StartValue,
|
||||
d.CycleOption,
|
||||
)
|
||||
}
|
||||
|
||||
func (d SequenceDescription) AlterToTarget(target SequenceDescription) ([]string, bool) {
|
||||
statements := []string{}
|
||||
|
||||
if d.TypeName != target.TypeName {
|
||||
statements = append(statements, fmt.Sprintf("ALTER SEQUENCE %s AS %s MAXVALUE %d;", d.Name, target.TypeName, target.MaximumValue))
|
||||
|
||||
// Remove from diff below
|
||||
d.TypeName = target.TypeName
|
||||
d.MaximumValue = target.MaximumValue
|
||||
}
|
||||
|
||||
// Abort if there are other fields we haven't addressed
|
||||
hasAdditionalDiff := cmp.Diff(d, target) != ""
|
||||
return statements, !hasAdditionalDiff
|
||||
}
|
||||
|
||||
type TableDescription struct {
|
||||
Name string
|
||||
Comment string
|
||||
@ -58,6 +212,57 @@ type ColumnDescription struct {
|
||||
Comment string
|
||||
}
|
||||
|
||||
func (d ColumnDescription) CreateStatement(table TableDescription) string {
|
||||
nullableExpr := ""
|
||||
if !d.IsNullable {
|
||||
nullableExpr = " NOT NULL"
|
||||
}
|
||||
defaultExpr := ""
|
||||
if d.Default != "" {
|
||||
defaultExpr = fmt.Sprintf(" DEFAULT %s", d.Default)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("ALTER TABLE %s ADD COLUMN %s %s%s%s;", table.Name, d.Name, d.TypeName, nullableExpr, defaultExpr)
|
||||
}
|
||||
|
||||
func (d ColumnDescription) DropStatement(table TableDescription) string {
|
||||
return fmt.Sprintf("ALTER TABLE %s DROP COLUMN IF EXISTS %s;", table.Name, d.Name)
|
||||
}
|
||||
|
||||
func (d ColumnDescription) AlterToTarget(table TableDescription, target ColumnDescription) ([]string, bool) {
|
||||
statements := []string{}
|
||||
|
||||
if d.TypeName != target.TypeName {
|
||||
statements = append(statements, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DATA TYPE %s;", table.Name, target.Name, target.TypeName))
|
||||
|
||||
// Remove from diff below
|
||||
d.TypeName = target.TypeName
|
||||
}
|
||||
if d.IsNullable != target.IsNullable {
|
||||
var verb string
|
||||
if target.IsNullable {
|
||||
verb = "DROP"
|
||||
} else {
|
||||
verb = "SET"
|
||||
}
|
||||
|
||||
statements = append(statements, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s %s NOT NULL;", table.Name, target.Name, verb))
|
||||
|
||||
// Remove from diff below
|
||||
d.IsNullable = target.IsNullable
|
||||
}
|
||||
if d.Default != target.Default {
|
||||
statements = append(statements, fmt.Sprintf("ALTER TABLE %s ALTER COLUMN %s SET DEFAULT %s;", table.Name, target.Name, target.Default))
|
||||
|
||||
// Remove from diff below
|
||||
d.Default = target.Default
|
||||
}
|
||||
|
||||
// Abort if there are other fields we haven't addressed
|
||||
hasAdditionalDiff := cmp.Diff(d, target) != ""
|
||||
return statements, !hasAdditionalDiff
|
||||
}
|
||||
|
||||
type IndexDescription struct {
|
||||
Name string
|
||||
IsPrimaryKey bool
|
||||
@ -69,6 +274,18 @@ type IndexDescription struct {
|
||||
ConstraintDefinition string
|
||||
}
|
||||
|
||||
func (d IndexDescription) CreateStatement(table TableDescription) string {
|
||||
if d.ConstraintType == "u" || d.ConstraintType == "p" {
|
||||
return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", table.Name, d.Name, d.ConstraintDefinition)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s;", d.IndexDefinition)
|
||||
}
|
||||
|
||||
func (d IndexDescription) DropStatement() string {
|
||||
return fmt.Sprintf("DROP INDEX IF EXISTS %s;", d.GetName())
|
||||
}
|
||||
|
||||
type ConstraintDescription struct {
|
||||
Name string
|
||||
ConstraintType string
|
||||
@ -77,16 +294,58 @@ type ConstraintDescription struct {
|
||||
ConstraintDefinition string
|
||||
}
|
||||
|
||||
func (d ConstraintDescription) CreateStatement(table TableDescription) string {
|
||||
return fmt.Sprintf("ALTER TABLE %s ADD CONSTRAINT %s %s;", table.Name, d.Name, d.ConstraintDefinition)
|
||||
}
|
||||
|
||||
func (d ConstraintDescription) DropStatement(table TableDescription) string {
|
||||
return fmt.Sprintf("ALTER TABLE %s DROP CONSTRAINT %s;", table.Name, d.Name)
|
||||
}
|
||||
|
||||
type TriggerDescription struct {
|
||||
Name string
|
||||
Definition string
|
||||
}
|
||||
|
||||
func (d TriggerDescription) CreateStatement() string {
|
||||
return fmt.Sprintf("%s;", d.Definition)
|
||||
}
|
||||
|
||||
func (d TriggerDescription) DropStatement(table TableDescription) string {
|
||||
return fmt.Sprintf("DROP TRIGGER IF EXISTS %s ON %s;", d.Name, table.Name)
|
||||
}
|
||||
|
||||
type ViewDescription struct {
|
||||
Name string
|
||||
Definition string
|
||||
}
|
||||
|
||||
func (d ViewDescription) CreateStatement() string {
|
||||
// pgsql indents definitions strangely; we copy that
|
||||
return fmt.Sprintf("CREATE VIEW %s AS %s", d.Name, strings.TrimSpace(stripIndent(" "+d.Definition)))
|
||||
}
|
||||
|
||||
func (d ViewDescription) DropStatement() string {
|
||||
return fmt.Sprintf("DROP VIEW IF EXISTS %s;", d.Name)
|
||||
}
|
||||
|
||||
// stripIndent removes the largest common indent from the given text.
|
||||
func stripIndent(s string) string {
|
||||
lines := strings.Split(strings.TrimRight(s, "\n"), "\n")
|
||||
|
||||
min := len(lines[0])
|
||||
for _, line := range lines {
|
||||
if indent := len(line) - len(strings.TrimLeft(line, " ")); indent < min {
|
||||
min = indent
|
||||
}
|
||||
}
|
||||
for i, line := range lines {
|
||||
lines[i] = line[min:]
|
||||
}
|
||||
|
||||
return strings.Join(lines, "\n")
|
||||
}
|
||||
|
||||
func Canonicalize(schemaDescription SchemaDescription) {
|
||||
for i := range schemaDescription.Tables {
|
||||
sortColumnsByName(schemaDescription.Tables[i].Columns)
|
||||
@ -104,6 +363,28 @@ func Canonicalize(schemaDescription SchemaDescription) {
|
||||
|
||||
type Namer interface{ GetName() string }
|
||||
|
||||
func GroupByName[T Namer](ts []T) map[string]T {
|
||||
m := make(map[string]T, len(ts))
|
||||
for _, t := range ts {
|
||||
m[t.GetName()] = t
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
type stringNamer string
|
||||
|
||||
func wrapStrings(ss []string) []Namer {
|
||||
sn := make([]Namer, 0, len(ss))
|
||||
for _, s := range ss {
|
||||
sn = append(sn, stringNamer(s))
|
||||
}
|
||||
|
||||
return sn
|
||||
}
|
||||
|
||||
func (n stringNamer) GetName() string { return string(n) }
|
||||
func (d ExtensionDescription) GetName() string { return d.Name }
|
||||
func (d EnumDescription) GetName() string { return d.Name }
|
||||
func (d FunctionDescription) GetName() string { return d.Name }
|
||||
func (d SequenceDescription) GetName() string { return d.Name }
|
||||
|
||||
@ -16,6 +16,7 @@ type Operations struct {
|
||||
tryLock *observation.Operation
|
||||
up *observation.Operation
|
||||
versions *observation.Operation
|
||||
runDDLStatements *observation.Operation
|
||||
withMigrationLog *observation.Operation
|
||||
}
|
||||
|
||||
@ -49,6 +50,7 @@ func NewOperations(observationCtx *observation.Context) *Operations {
|
||||
tryLock: op("TryLock"),
|
||||
up: op("Up"),
|
||||
versions: op("Versions"),
|
||||
runDDLStatements: op("RunDDLStatements"),
|
||||
withMigrationLog: op("WithMigrationLog"),
|
||||
}
|
||||
})
|
||||
|
||||
@ -282,6 +282,25 @@ WHERE row_number = 1
|
||||
ORDER BY version
|
||||
`
|
||||
|
||||
func (s *Store) RunDDLStatements(ctx context.Context, statements []string) (err error) {
|
||||
ctx, _, endObservation := s.operations.runDDLStatements.With(ctx, &err, observation.Args{})
|
||||
defer endObservation(1, observation.Args{})
|
||||
|
||||
tx, err := s.Transact(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = tx.Done(err) }()
|
||||
|
||||
for _, statement := range statements {
|
||||
if err := tx.Exec(ctx, sqlf.Sprintf(statement)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// TryLock attempts to create hold an advisory lock. This method returns a function that should be
|
||||
// called once the lock should be released. This method accepts the current function's error output
|
||||
// and wraps any additional errors that occur on close. Calling this method when the lock was not
|
||||
|
||||
Loading…
Reference in New Issue
Block a user