diff --git a/cmd/enterprise-portal/internal/database/BUILD.bazel b/cmd/enterprise-portal/internal/database/BUILD.bazel index d0daee3def4..dbbb9871bcb 100644 --- a/cmd/enterprise-portal/internal/database/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/BUILD.bazel @@ -5,12 +5,14 @@ go_library( srcs = [ "database.go", "migrate.go", + "types.go", ], importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database", tags = [TAG_INFRA_CORESERVICES], visibility = ["//cmd/enterprise-portal:__subpackages__"], deps = [ "//cmd/enterprise-portal/internal/database/internal/tables", + "//cmd/enterprise-portal/internal/database/internal/tables/custommigrator", "//cmd/enterprise-portal/internal/database/subscriptions", "//lib/errors", "//lib/managedservicesplatform/runtime", diff --git a/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go b/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go index 792dd8b5c02..89d2b738bba 100644 --- a/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go +++ b/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go @@ -4,7 +4,7 @@ import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/databa type CodyGatewayAccess struct { // ⚠️ DO NOT USE: This field is only used for creating foreign key constraint. - Subscription *subscriptions.Subscription `gorm:"foreignKey:SubscriptionID"` + Subscription *subscriptions.TableSubscription `gorm:"foreignKey:SubscriptionID"` // SubscriptionID is the internal unprefixed UUID of the related subscription. SubscriptionID string `gorm:"type:uuid;not null;unique"` diff --git a/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel b/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel index 045d275fc4c..d2a6fe3656d 100644 --- a/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel @@ -7,7 +7,9 @@ go_library( tags = [TAG_INFRA_CORESERVICES], visibility = ["//cmd/enterprise-portal:__subpackages__"], deps = [ + "//cmd/enterprise-portal/internal/database/internal/tables/custommigrator", "//internal/database/dbtest", + "@com_github_jackc_pgx_v5//:pgx", "@com_github_jackc_pgx_v5//pgxpool", "@com_github_stretchr_testify//require", "@io_gorm_driver_postgres//:postgres", diff --git a/cmd/enterprise-portal/internal/database/databasetest/databasetest.go b/cmd/enterprise-portal/internal/database/databasetest/databasetest.go index 7ecce936498..2671aa6fc52 100644 --- a/cmd/enterprise-portal/internal/database/databasetest/databasetest.go +++ b/cmd/enterprise-portal/internal/database/databasetest/databasetest.go @@ -3,17 +3,20 @@ package databasetest import ( "context" "database/sql" + "encoding/json" "fmt" "strings" "testing" "time" + "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/require" "gorm.io/driver/postgres" "gorm.io/gorm" "gorm.io/gorm/schema" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator" "github.com/sourcegraph/sourcegraph/internal/database/dbtest" ) @@ -57,6 +60,11 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx for _, table := range tables { err = db.AutoMigrate(table) require.NoError(t, err) + if m, ok := table.(custommigrator.CustomTableMigrator); ok { + if err := m.RunCustomMigrations(db.Migrator()); err != nil { + require.NoError(t, err) + } + } } // Close the connection used to auto-migrate the database. @@ -66,7 +74,12 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx require.NoError(t, err) // Open a new connection to the test suite database. - testDB, err := pgxpool.New(context.Background(), dsn.String()) + dbConfig, err := pgxpool.ParseConfig(dsn.String()) + require.NoError(t, err) + if testing.Verbose() { + dbConfig.ConnConfig.Tracer = pgxTestTracer{TB: t} + } + testDB, err := pgxpool.NewWithConfig(context.Background(), dbConfig) require.NoError(t, err) t.Cleanup(func() { @@ -110,3 +123,43 @@ func ClearTablesAfterTest(t *testing.T, db *pgxpool.Pool, tables ...schema.Table } }) } + +// pgxTestTracer implements various pgx tracing hooks for dumping diagnostics +// in testing. +type pgxTestTracer struct{ testing.TB } + +// Select tracing hooks we want to implement. +var ( + _ pgx.QueryTracer = pgxTestTracer{} +) + +func (t pgxTestTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context { + var args []string + if len(data.Args) > 0 { + // Divider for readability + args = append(args, "\n---") + } + for _, arg := range data.Args { + data, err := json.MarshalIndent(arg, "", " ") + if err != nil { + args = append(args, fmt.Sprintf("marshal %T: %+v", arg, err)) + } + args = append(args, string(data)) + } + + t.Logf(`pgx.QueryStart db=%q +%s%s`, + conn.Config().Database, + strings.TrimSpace(data.SQL), + strings.Join(args, "\n")) + return ctx +} + +func (t pgxTestTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) { + if data.Err != nil { + t.Logf(`pgx.QueryEnd db=%q tag=%q error=%q`, + conn.Config().Database, + data.CommandTag.String(), + data.Err) + } +} diff --git a/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/BUILD.bazel b/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/BUILD.bazel new file mode 100644 index 00000000000..d832660e58a --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/BUILD.bazel @@ -0,0 +1,9 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "custommigrator", + srcs = ["custommigrator.go"], + importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator", + visibility = ["//cmd/enterprise-portal:__subpackages__"], + deps = ["@io_gorm_gorm//:gorm"], +) diff --git a/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/custommigrator.go b/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/custommigrator.go new file mode 100644 index 00000000000..2686eb8ab7e --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/tables/custommigrator/custommigrator.go @@ -0,0 +1,9 @@ +package custommigrator + +import "gorm.io/gorm" + +type CustomTableMigrator interface { + // RunCustomMigrations is called after all other migrations have been run. + // It can implement custom migrations. + RunCustomMigrations(migrator gorm.Migrator) error +} diff --git a/cmd/enterprise-portal/internal/database/internal/tables/tables.go b/cmd/enterprise-portal/internal/database/internal/tables/tables.go index 3a08337fb70..7d84ec06e43 100644 --- a/cmd/enterprise-portal/internal/database/internal/tables/tables.go +++ b/cmd/enterprise-portal/internal/database/internal/tables/tables.go @@ -12,9 +12,9 @@ import ( // ⚠️ WARNING: This list is meant to be read-only. func All() []schema.Tabler { return []schema.Tabler{ - &subscriptions.Subscription{}, + &subscriptions.TableSubscription{}, &subscriptions.SubscriptionCondition{}, - &subscriptions.SubscriptionLicense{}, + &subscriptions.TableSubscriptionLicense{}, &subscriptions.SubscriptionLicenseCondition{}, &codyaccess.CodyGatewayAccess{}, diff --git a/cmd/enterprise-portal/internal/database/internal/utctime/BUILD.bazel b/cmd/enterprise-portal/internal/database/internal/utctime/BUILD.bazel new file mode 100644 index 00000000000..8e3be7b56e3 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/utctime/BUILD.bazel @@ -0,0 +1,12 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "utctime", + srcs = ["utctime.go"], + importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime", + visibility = ["//cmd/enterprise-portal:__subpackages__"], + deps = [ + "//lib/errors", + "//lib/pointers", + ], +) diff --git a/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go new file mode 100644 index 00000000000..0e90287687e --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/utctime/utctime.go @@ -0,0 +1,80 @@ +package utctime + +import ( + "database/sql" + "database/sql/driver" + "encoding/json" + "time" + + "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/sourcegraph/lib/pointers" +) + +// Time is a wrapper around time.Time that implements the database/sql.Scanner +// and database/sql/driver.Valuer interfaces to serialize and deserialize time +// in UTC time zone. +// +// Time ensures that time.Time values are always: +// +// - represented in UTC for consistency +// - rounded to microsecond precision +// +// We round the time because PostgreSQL times are represented in microseconds: +// https://www.postgresql.org/docs/current/datatype-datetime.html +type Time time.Time + +// Now returns the current time in UTC. +func Now() Time { return Time(time.Now()) } + +// FromTime returns a utctime.Time from a time.Time. +func FromTime(t time.Time) Time { return Time(t.UTC().Round(time.Microsecond)) } + +var _ sql.Scanner = (*Time)(nil) + +func (t *Time) Scan(src any) error { + if src == nil { + return nil + } + if v, ok := src.(time.Time); ok { + *t = FromTime(v) + return nil + } + return errors.Newf("value %T is not time.Time", src) +} + +var _ driver.Valuer = (*Time)(nil) + +// Value must be called with a non-nil Time. driver.Valuer callers will first +// check that the value is non-nil, so this is safe. +func (t Time) Value() (driver.Value, error) { + stdTime := t.GetTime() + return *stdTime, nil +} + +var _ json.Marshaler = (*Time)(nil) + +func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.GetTime()) } + +var _ json.Unmarshaler = (*Time)(nil) + +func (t *Time) UnmarshalJSON(data []byte) error { + var stdTime time.Time + if err := json.Unmarshal(data, &stdTime); err != nil { + return err + } + *t = FromTime(stdTime) + return nil +} + +// GetTime returns the underlying time.GetTime value, or nil if it is nil. +func (t *Time) GetTime() *time.Time { + if t == nil { + return nil + } + return pointers.Ptr(t.AsTime()) +} + +// Time casts the Time as a standard time.Time value. +func (t Time) AsTime() time.Time { + return time.Time(t).UTC().Round(time.Microsecond) +} diff --git a/cmd/enterprise-portal/internal/database/migrate.go b/cmd/enterprise-portal/internal/database/migrate.go index 5a03474975b..380d9125966 100644 --- a/cmd/enterprise-portal/internal/database/migrate.go +++ b/cmd/enterprise-portal/internal/database/migrate.go @@ -21,6 +21,7 @@ import ( "github.com/sourcegraph/sourcegraph/lib/redislock" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator" ) // maybeMigrate runs the auto-migration for the database when needed based on @@ -42,6 +43,12 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr } span.End() }() + logger = logger. + WithTrace(log.TraceContext{ + TraceID: span.SpanContext().TraceID().String(), + SpanID: span.SpanContext().SpanID().String(), + }). + With(log.String("database", dbName)) sqlDB, err := contract.PostgreSQL.OpenDatabase(ctx, dbName) if err != nil { @@ -83,17 +90,20 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr span.AddEvent("lock.acquired") versionKey := fmt.Sprintf("%s:db_version", dbName) + liveVersion := redisClient.Get(ctx, versionKey).Val() if shouldSkipMigration( - redisClient.Get(ctx, versionKey).Val(), + liveVersion, currentVersion, ) { logger.Info("skipped auto-migration", - log.String("database", dbName), log.String("currentVersion", currentVersion), ) span.SetAttributes(attribute.Bool("skipped", true)) return nil } + logger.Info("executing auto-migration", + log.String("liveVersion", liveVersion), + log.String("currentVersion", currentVersion)) span.SetAttributes(attribute.Bool("skipped", false)) // Create a session that ignore debug logging. @@ -108,6 +118,11 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr if err != nil { return errors.Wrapf(err, "auto migrating table for %s", errors.Safe(fmt.Sprintf("%T", table))) } + if m, ok := table.(custommigrator.CustomTableMigrator); ok { + if err := m.RunCustomMigrations(sess.Migrator()); err != nil { + return errors.Wrapf(err, "running custom migrations for %s", errors.Safe(fmt.Sprintf("%T", table))) + } + } } return redisClient.Set(ctx, versionKey, currentVersion, 0).Err() diff --git a/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel b/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel index d73ea1939da..8cdb672afb0 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel @@ -14,27 +14,39 @@ go_library( visibility = ["//cmd/enterprise-portal:__subpackages__"], deps = [ "//cmd/enterprise-portal/internal/database/internal/upsert", + "//cmd/enterprise-portal/internal/database/internal/utctime", + "//internal/license", + "//lib/enterpriseportal/subscriptions/v1:subscriptions", "//lib/errors", "//lib/pointers", - "@com_github_jackc_pgtype//:pgtype", + "@com_github_google_uuid//:uuid", "@com_github_jackc_pgx_v5//:pgx", "@com_github_jackc_pgx_v5//pgxpool", + "@io_gorm_gorm//:gorm", ], ) go_test( name = "subscriptions_test", - srcs = ["subscriptions_test.go"], + srcs = [ + "licenses_test.go", + "subscriptions_test.go", + ], tags = [ TAG_INFRA_CORESERVICES, "requires-network", ], deps = [ ":subscriptions", + "//cmd/enterprise-portal/internal/database", "//cmd/enterprise-portal/internal/database/databasetest", "//cmd/enterprise-portal/internal/database/internal/tables", + "//cmd/enterprise-portal/internal/database/internal/utctime", + "//internal/license", "//lib/pointers", "@com_github_google_uuid//:uuid", + "@com_github_hexops_autogold_v2//:autogold", + "@com_github_hexops_valast//:valast", "@com_github_jackc_pgx_v5//:pgx", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", diff --git a/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go b/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go index 80cb11211b5..e3cee4179dc 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/license_conditions.go @@ -1,11 +1,17 @@ package subscriptions -import "time" +import ( + "context" + + "github.com/jackc/pgx/v5" + + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" + subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1" + "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/sourcegraph/lib/pointers" +) type SubscriptionLicenseCondition struct { - // ⚠️ DO NOT USE: This field is only used for creating foreign key constraint. - License *SubscriptionLicense `gorm:"foreignKey:LicenseID"` - // SubscriptionID is the internal unprefixed UUID of the related license. LicenseID string `gorm:"type:uuid;not null"` // Status is the type of status corresponding to this condition, corresponding @@ -15,9 +21,73 @@ type SubscriptionLicenseCondition struct { Message *string `gorm:"size:256"` // TransitionTime is the time at which the condition was created, i.e. when // the license transitioned into this status. - TransitionTime time.Time `gorm:"not null;default:current_timestamp"` + TransitionTime utctime.Time `gorm:"not null;default:current_timestamp"` } -func (s *SubscriptionLicenseCondition) TableName() string { +func (*SubscriptionLicenseCondition) TableName() string { return "enterprise_portal_subscription_license_conditions" } + +// subscriptionLicenseConditionJSONBAgg must be used with: +// +// LEFT JOIN +// enterprise_portal_subscription_license_conditions license_condition +// ON license_condition.license_id = id +// GROUP BY +// id +// +// The conditions are aggregated in JSON to 'conditions', which can be directly +// unmarshaled into the 'SubscriptionLicenseCondition' type using 'pgx'. +func subscriptionLicenseConditionJSONBAgg() string { + return ` +jsonb_agg( + jsonb_build_object( + 'Status', license_condition.status, + 'Message', license_condition.message, + 'TransitionTime', license_condition.transition_time + ) + ORDER BY license_condition.transition_time DESC +) AS conditions` +} + +type licenseConditionsStore struct{ tx pgx.Tx } + +// newLicenseConditionsStore is meant to be used exclusively in the context of +// a transaction, where the parent license is being updated at the same time. +// +// The caller owns the transaction lifecycle. +func newLicenseConditionsStore(tx pgx.Tx) *licenseConditionsStore { + return &licenseConditionsStore{tx: tx} +} + +type createLicenseConditionOpts struct { + Status subscriptionsv1.EnterpriseSubscriptionLicenseCondition_Status + Message string + TransitionTime utctime.Time +} + +func (s *licenseConditionsStore) createLicenseCondition(ctx context.Context, licenseID string, opts createLicenseConditionOpts) error { + if opts.TransitionTime.GetTime().IsZero() { + return errors.New("transition time is required") + } + _, err := s.tx.Exec(ctx, ` +INSERT INTO enterprise_portal_subscription_license_conditions ( + license_id, + status, + message, + transition_time +) +VALUES ( + @licenseID, + @status, + @message, + @transitionTime +)`, pgx.NamedArgs{ + "licenseID": licenseID, + // Convert to string representation of EnterpriseSubscriptionLicenseCondition + "status": subscriptionsv1.EnterpriseSubscriptionLicenseCondition_Status_name[int32(opts.Status)], + "message": pointers.NilIfZero(opts.Message), + "transitionTime": opts.TransitionTime, + }) + return err +} diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses.go index 2b64bbea108..386130a5549 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/licenses.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses.go @@ -1,15 +1,49 @@ package subscriptions import ( + "context" + "encoding/json" + "fmt" + "strings" "time" - "github.com/jackc/pgtype" + "github.com/google/uuid" + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + "gorm.io/gorm" + + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" + internallicense "github.com/sourcegraph/sourcegraph/internal/license" + subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1" + "github.com/sourcegraph/sourcegraph/lib/pointers" + + "github.com/sourcegraph/sourcegraph/lib/errors" ) -type SubscriptionLicense struct { +// ⚠️ DO NOT USE: This type is only used for creating foreign key constraints +// and initializing tables with gorm. +type TableSubscriptionLicense struct { // ⚠️ DO NOT USE: This field is only used for creating foreign key constraint. - Subscription *Subscription `gorm:"foreignKey:SubscriptionID"` + Conditions *[]SubscriptionLicenseCondition `gorm:"foreignKey:LicenseID"` + SubscriptionLicense +} + +func (*TableSubscriptionLicense) TableName() string { + return "enterprise_portal_subscription_licenses" +} + +// Implement tables.CustomMigrator +func (s *TableSubscriptionLicense) RunCustomMigrations(migrator gorm.Migrator) error { + if migrator.HasColumn(s, "license_kind") { + if err := migrator.DropColumn(s, "license_kind"); err != nil { + return err + } + } + return nil +} + +type SubscriptionLicense struct { // SubscriptionID is the internal unprefixed UUID of the related subscription. SubscriptionID string `gorm:"type:uuid;not null"` // ID is the internal unprefixed UUID of this license. @@ -19,21 +53,331 @@ type SubscriptionLicense struct { // to this subscription. // // Condition transition details are tracked in 'enterprise_portal_subscription_license_conditions'. - CreatedAt time.Time `gorm:"not null;default:current_timestamp"` - RevokedAt *time.Time // Null indicates the licnese is not revoked. + CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"` + RevokedAt *utctime.Time // Null indicates the license is not revoked. - // LicenseKind is the kind of license stored in LicenseData, corresponding + // ExpireAt is the time at which the license should expire. Expiration does + // NOT get a corresponding condition entry in 'enterprise_portal_subscription_license_conditions'. + ExpireAt utctime.Time `gorm:"not null"` + + // LicenseType is the kind of license stored in LicenseData, corresponding // to the API 'EnterpriseSubscriptionLicenseType'. - LicenseKind string `gorm:"not null"` + LicenseType string `gorm:"not null"` // LicenseData is the license data stored in JSON format. It is read-only // and generally never queried in conditions - properties that are should // be stored at the subscription or license level. // // Value shapes correspond to API types appropriate for each // 'EnterpriseSubscriptionLicenseType'. - LicenseData pgtype.JSONB `gorm:"type:jsonb"` + LicenseData json.RawMessage `gorm:"type:jsonb"` } -func (s *SubscriptionLicense) TableName() string { - return "enterprise_portal_subscription_licenses" +// subscriptionLicenseWithConditionsColumns must match scanSubscriptionLicense() +// values. +func subscriptionLicenseWithConditionsColumns() []string { + return []string{ + "subscription_id", + "id", + + "created_at", + "revoked_at", + "expire_at", + + "license_type", + "license_data", + + subscriptionLicenseConditionJSONBAgg(), + } +} + +type LicenseWithConditions struct { + SubscriptionLicense + Conditions []SubscriptionLicenseCondition +} + +// scanSubscription matches subscriptionTableColumns() values. +func scanSubscriptionLicenseWithConditions(row pgx.Row) (*LicenseWithConditions, error) { + var l LicenseWithConditions + err := row.Scan( + &l.SubscriptionID, + &l.ID, + &l.CreatedAt, + &l.RevokedAt, + &l.ExpireAt, + &l.LicenseType, + &l.LicenseData, + &l.Conditions, // see subscriptionLicenseConditionJSONBAgg docstring + ) + return &l, err +} + +// LicensesStore manages licenses belonging to Enterprise subscriptions. +// +// Licenses can only be created and revoked - they can never be updated. +type LicensesStore struct { + db *pgxpool.Pool +} + +func NewLicensesStore(db *pgxpool.Pool) *LicensesStore { + return &LicensesStore{ + db: db, + } +} + +type ListLicensesOpts struct { + SubscriptionID string + // PageSize is the maximum number of licenses to return. + PageSize int +} + +func (opts ListLicensesOpts) toQueryConditions() (where, limitClause string, _ pgx.NamedArgs) { + whereConds := []string{"TRUE"} + namedArgs := pgx.NamedArgs{} + if opts.SubscriptionID != "" { + whereConds = append(whereConds, "subscription_id = @subscriptionID") + namedArgs["subscriptionID"] = opts.SubscriptionID + } + where = strings.Join(whereConds, " AND ") + + if opts.PageSize > 0 { + limitClause = "LIMIT @pageSize" + namedArgs["pageSize"] = opts.PageSize + } + return where, limitClause, namedArgs +} + +func (s *LicensesStore) List(ctx context.Context, opts ListLicensesOpts) ([]*LicenseWithConditions, error) { + where, limitClause, namedArgs := opts.toQueryConditions() + query := fmt.Sprintf(` +SELECT + %s +FROM + enterprise_portal_subscription_licenses +LEFT JOIN + enterprise_portal_subscription_license_conditions license_condition + ON license_condition.license_id = id +WHERE + %s +GROUP BY + id +ORDER BY + created_at DESC +%s`, + strings.Join(subscriptionLicenseWithConditionsColumns(), ", "), + where, limitClause) + + rows, err := s.db.Query(ctx, query, namedArgs) + if err != nil { + return nil, errors.Wrap(err, "query rows") + } + defer rows.Close() + + var licenses []*LicenseWithConditions + for rows.Next() { + license, err := scanSubscriptionLicenseWithConditions(rows) + if err != nil { + return nil, errors.Wrap(err, "scan row") + } + licenses = append(licenses, license) + } + return licenses, rows.Err() +} + +func (s *LicensesStore) Get(ctx context.Context, licenseID string) (*LicenseWithConditions, error) { + query := fmt.Sprintf(` +SELECT + %s +FROM + enterprise_portal_subscription_licenses +LEFT JOIN + enterprise_portal_subscription_license_conditions license_condition + ON license_condition.license_id = id +WHERE + id = @licenseID +GROUP BY + id`, + strings.Join(subscriptionLicenseWithConditionsColumns(), ", ")) + + license, err := scanSubscriptionLicenseWithConditions( + s.db.QueryRow(ctx, query, pgx.NamedArgs{ + "licenseID": licenseID, + }), + ) + if err != nil { + return nil, errors.Wrap(err, "query rows") + } + return license, nil +} + +type CreateLicenseOpts struct { + Message string + // If nil, the creation time will be set to the current time. + Time *utctime.Time + // Expiration time of the license. + ExpireTime utctime.Time +} + +// LicenseKey corresponds to *subscriptionsv1.EnterpriseSubscriptionLicenseKey +// and the 'ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY' license type. +type LicenseKey struct { + Info internallicense.Info + // Signed license key with the license information in Info. + SignedKey string +} + +// CreateLicense creates a new classic offline license for the given subscription. +func (s *LicensesStore) CreateLicenseKey( + ctx context.Context, + subscriptionID string, + license *LicenseKey, + opts CreateLicenseOpts, +) (_ *LicenseWithConditions, err error) { + // Special behaviour: the license key embeds the creation time, and it must + // match the time provided in the options. + if opts.Time == nil { + return nil, errors.New("creation time must be specified for licensekeys") + } else if !opts.Time.GetTime().Equal(utctime.FromTime(license.Info.CreatedAt).AsTime()) { + return nil, errors.New("creation time must match the license key information") + } + if !opts.ExpireTime.GetTime().Equal(utctime.FromTime(license.Info.ExpiresAt).AsTime()) { + return nil, errors.New("expiration time must match the license key information") + } + + return s.create( + ctx, + subscriptionID, + subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY, + license, + opts, + ) +} + +func (s *LicensesStore) create( + ctx context.Context, + subscriptionID string, + licenseType subscriptionsv1.EnterpriseSubscriptionLicenseType, + license any, + opts CreateLicenseOpts, +) (_ *LicenseWithConditions, err error) { + if subscriptionID == "" { + return nil, errors.New("subscription ID must be specified") + } + if opts.Time == nil { + opts.Time = pointers.Ptr(utctime.Now()) + } else if opts.Time.GetTime().After(time.Now()) { + return nil, errors.New("creation time cannot be in the future") + } + if licenseType == subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_UNSPECIFIED { + return nil, errors.New("license type must be specified") + } + + licenseID, err := uuid.NewV7() + if err != nil { + return nil, errors.Wrap(err, "generate uuid") + } + licenseData, err := json.Marshal(license) + if err != nil { + return nil, errors.Wrap(err, "marshal license data") + } + tx, err := s.db.Begin(ctx) + if err != nil { + return nil, errors.Wrap(err, "begin transaction") + } + defer func() { + if rollbackErr := tx.Rollback(context.Background()); rollbackErr != nil { + err = errors.Append(err, errors.Wrap(err, "rollback")) + } + }() + + if _, err := tx.Exec(ctx, ` +INSERT INTO enterprise_portal_subscription_licenses ( + id, + subscription_id, + license_type, + license_data, + created_at, + expire_at +) +VALUES ( + @licenseID, + @subscriptionID, + @licenseType, + @licenseData, + @createdAt, + @expireAt +) +`, pgx.NamedArgs{ + "licenseID": licenseID.String(), + "subscriptionID": subscriptionID, + "licenseType": subscriptionsv1.EnterpriseSubscriptionLicenseType_name[int32(licenseType)], + "licenseData": licenseData, + "createdAt": opts.Time, + "expireAt": opts.ExpireTime, + }); err != nil { + return nil, errors.Wrap(err, "create license") + } + + if err := newLicenseConditionsStore(tx).createLicenseCondition(ctx, licenseID.String(), createLicenseConditionOpts{ + Status: subscriptionsv1.EnterpriseSubscriptionLicenseCondition_STATUS_CREATED, + Message: opts.Message, + TransitionTime: *opts.Time, + }); err != nil { + return nil, errors.Wrap(err, "create license condition") + } + + if err := tx.Commit(ctx); err != nil { + return nil, errors.Wrap(err, "commit transaction") + } + + return s.Get(ctx, licenseID.String()) +} + +type RevokeLicenseOpts struct { + Message string + // If nil, the revocation time will be set to the current time. + Time *utctime.Time +} + +// Revoke marks the given license as revoked. +func (s *LicensesStore) Revoke(ctx context.Context, licenseID string, opts RevokeLicenseOpts) (*LicenseWithConditions, error) { + if opts.Time == nil { + opts.Time = pointers.Ptr(utctime.Now()) + } else if opts.Time.GetTime().After(time.Now()) { + return nil, errors.New("revocation time cannot be in the future") + } + + tx, err := s.db.Begin(ctx) + if err != nil { + return nil, errors.Wrap(err, "begin transaction") + } + defer func() { + if rollbackErr := tx.Rollback(context.Background()); rollbackErr != nil { + err = errors.Append(err, rollbackErr) + } + }() + + if _, err := tx.Exec(ctx, ` +UPDATE enterprise_portal_subscription_licenses +SET revoked_at = COALESCE(revoked_at, @revokedAt) -- use existing revoke time if already revoked +WHERE id = @licenseID +`, pgx.NamedArgs{ + "revokedAt": opts.Time, + "licenseID": licenseID, + }); err != nil { + return nil, errors.Wrap(err, "revoke license") + } + + if err := newLicenseConditionsStore(tx).createLicenseCondition(ctx, licenseID, createLicenseConditionOpts{ + Status: subscriptionsv1.EnterpriseSubscriptionLicenseCondition_STATUS_REVOKED, + Message: opts.Message, + TransitionTime: *opts.Time, + }); err != nil { + return nil, errors.Wrap(err, "create license condition") + } + + if err := tx.Commit(ctx); err != nil { + return nil, errors.Wrap(err, "commit transaction") + } + + return s.Get(ctx, licenseID) } diff --git a/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go new file mode 100644 index 00000000000..c79657b7045 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/subscriptions/licenses_test.go @@ -0,0 +1,238 @@ +package subscriptions_test + +import ( + "context" + "fmt" + "testing" + "time" + + "github.com/google/uuid" + "github.com/hexops/autogold/v2" + "github.com/hexops/valast" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" + "github.com/sourcegraph/sourcegraph/internal/license" + "github.com/sourcegraph/sourcegraph/lib/pointers" +) + +func TestLicensesStore(t *testing.T) { + t.Parallel() + + ctx := context.Background() + db := databasetest.NewTestDB(t, "enterprise-portal", t.Name(), tables.All()...) + + subscriptionID1 := uuid.NewString() + subscriptionID2 := uuid.NewString() + + subs := subscriptions.NewStore(db) + _, err := subs.Upsert(ctx, subscriptionID1, subscriptions.UpsertSubscriptionOptions{ + DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 1")), + }) + require.NoError(t, err) + _, err = subs.Upsert(ctx, subscriptionID2, subscriptions.UpsertSubscriptionOptions{ + DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 2")), + }) + require.NoError(t, err) + + licenses := subscriptions.NewLicensesStore(db) + + var createdLicenses []*subscriptions.LicenseWithConditions + getCreatedByLicenseID := func(t *testing.T, licenseID string) *subscriptions.LicenseWithConditions { + for _, l := range createdLicenses { + if l.ID == licenseID { + return l + } + } + t.Errorf("license %q not found", licenseID) + t.FailNow() + return nil + } + t.Run("CreateLicenseKey", func(t *testing.T) { + testLicense := func( + got *subscriptions.LicenseWithConditions, + wantMessage autogold.Value, + wantLicenseData autogold.Value, + ) { + assert.NotEmpty(t, got.ID) + assert.NotZero(t, got.CreatedAt) + assert.NotZero(t, got.ExpireAt) + assert.Equal(t, "ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY", got.LicenseType) + wantLicenseData.Equal(t, string(got.LicenseData)) + + assert.Len(t, got.Conditions, 1) + wantMessage.Equal(t, got.Conditions[0].Message) + assert.Equal(t, "STATUS_CREATED", got.Conditions[0].Status) + assert.Equal(t, got.CreatedAt, got.Conditions[0].TransitionTime) + } + + got, err := licenses.CreateLicenseKey(ctx, subscriptionID1, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"foo"}, + CreatedAt: time.Time{}.Add(1 * time.Hour), + ExpiresAt: time.Time{}.Add(48 * time.Hour), + }, + SignedKey: "asdfasdf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name() + " 1 old", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(1 * time.Hour))), + ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)), + }) + require.NoError(t, err) + testLicense( + got, + autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1 old")), + autogold.Expect(`{"Info": {"c": "0001-01-01T01:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["foo"], "u": 0}, "SignedKey": "asdfasdf"}`), + ) + createdLicenses = append(createdLicenses, got) + + got, err = licenses.CreateLicenseKey(ctx, subscriptionID1, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"baz"}, + CreatedAt: time.Time{}.Add(24 * time.Hour), + ExpiresAt: time.Time{}.Add(48 * time.Hour), + }, + SignedKey: "barasdf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name() + " 1", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)), + }) + require.NoError(t, err) + testLicense( + got, + autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1")), + autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`), + ) + createdLicenses = append(createdLicenses, got) + + got, err = licenses.CreateLicenseKey(ctx, subscriptionID2, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"tag"}, + CreatedAt: time.Time{}.Add(24 * time.Hour), + ExpiresAt: time.Time{}.Add(48 * time.Hour), + }, + SignedKey: "asdffdsadf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name() + " 2", + Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))), + ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)), + }) + require.NoError(t, err) + testLicense( + got, + autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 2")), + autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["tag"], "u": 0}, "SignedKey": "asdffdsadf"}`), + ) + createdLicenses = append(createdLicenses, got) + + t.Run("createdAt does not match", func(t *testing.T) { + _, err = licenses.CreateLicenseKey(ctx, subscriptionID2, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"tag"}, + CreatedAt: time.Time{}.Add(24 * time.Hour), + }, + SignedKey: "asdffdsadf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name(), + Time: pointers.Ptr(utctime.Now()), + }) + require.Error(t, err) + autogold.Expect("creation time must match the license key information").Equal(t, err.Error()) + }) + t.Run("expiresAt does not match", func(t *testing.T) { + _, err = licenses.CreateLicenseKey(ctx, subscriptionID2, + &subscriptions.LicenseKey{ + Info: license.Info{ + Tags: []string{"tag"}, + CreatedAt: time.Time{}, + ExpiresAt: time.Time{}.Add(48 * time.Hour), + }, + SignedKey: "asdffdsadf", + }, + subscriptions.CreateLicenseOpts{ + Message: t.Name(), + Time: pointers.Ptr(utctime.FromTime(time.Time{})), + ExpireTime: utctime.Now(), + }) + require.Error(t, err) + autogold.Expect("expiration time must match the license key information").Equal(t, err.Error()) + }) + }) + + // No point continuing if test licenses did not create, all tests after this + // will fail + if t.Failed() { + t.FailNow() + } + + t.Run("List", func(t *testing.T) { + listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{}) + require.NoError(t, err) + assert.Len(t, listedLicenses, len(createdLicenses)) + for _, l := range listedLicenses { + created := getCreatedByLicenseID(t, l.ID) + assert.Equal(t, *created, *l) + } + + t.Run("List by subscription", func(t *testing.T) { + listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{ + SubscriptionID: subscriptionID1, + }) + require.NoError(t, err) + assert.Len(t, listedLicenses, 2) + for _, l := range listedLicenses { + assert.Equal(t, subscriptionID1, l.SubscriptionID) + assert.Equal(t, *getCreatedByLicenseID(t, l.ID), *l) + } + + listedLicenses, err = licenses.List(ctx, subscriptions.ListLicensesOpts{ + SubscriptionID: subscriptionID2, + }) + require.NoError(t, err) + assert.Len(t, listedLicenses, 1) + for _, l := range listedLicenses { + assert.Equal(t, subscriptionID2, l.SubscriptionID) + assert.Equal(t, *getCreatedByLicenseID(t, l.ID), *l) + } + }) + }) + + t.Run("Get", func(t *testing.T) { + for _, license := range createdLicenses { + got, err := licenses.Get(ctx, license.ID) + require.NoError(t, err) + assert.Equal(t, *license, *got) + } + }) + + t.Run("Revoke", func(t *testing.T) { + for idx, license := range createdLicenses { + revokeTime := utctime.FromTime(time.Now().Add(-time.Second)) + got, err := licenses.Revoke(ctx, license.ID, subscriptions.RevokeLicenseOpts{ + Message: fmt.Sprintf("%s %d", t.Name(), idx), + Time: pointers.Ptr(revokeTime), + }) + require.NoError(t, err) + assert.Equal(t, revokeTime.AsTime(), got.RevokedAt.AsTime()) + require.Len(t, got.Conditions, 2) + // Most recent condition is sorted first, and should be the revocation + assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status) + assert.Equal(t, revokeTime.AsTime(), got.Conditions[0].TransitionTime.AsTime()) + assert.Equal(t, "STATUS_CREATED", got.Conditions[1].Status) + } + }) +} diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go index 94e8d8e0297..da454bcf4ee 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go @@ -2,6 +2,7 @@ package subscriptions import ( "context" + "database/sql" "fmt" "strings" "time" @@ -10,10 +11,26 @@ import ( "github.com/jackc/pgx/v5/pgxpool" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" "github.com/sourcegraph/sourcegraph/lib/errors" - "github.com/sourcegraph/sourcegraph/lib/pointers" ) +// ⚠️ DO NOT USE: This type is only used for creating foreign key constraints +// and initializing tables with gorm. +type TableSubscription struct { + // Each Subscription has many Licenses. + Licenses []*TableSubscriptionLicense `gorm:"foreignKey:SubscriptionID"` + + // Each Subscription has many Conditions. + Conditions *[]SubscriptionCondition `gorm:"foreignKey:SubscriptionID"` + + Subscription +} + +func (*TableSubscription) TableName() string { + return "enterprise_portal_subscriptions" +} + // Subscription is an Enterprise subscription record. type Subscription struct { // ID is the internal (unprefixed) UUID-format identifier for the subscription. @@ -22,7 +39,7 @@ type Subscription struct { // "acme.sourcegraphcloud.com". This is set explicitly. // // It must be unique across all currently un-archived subscriptions. - InstanceDomain string `gorm:"uniqueIndex:,where:archived_at IS NULL"` + InstanceDomain *string `gorm:"uniqueIndex:,where:archived_at IS NULL"` // WARNING: The below fields are not yet used in production. @@ -33,15 +50,15 @@ type Subscription struct { // // TODO: Clean up the database post-deploy and remove the 'Unnamed subscription' // part of the constraint. - DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"` + DisplayName *string `gorm:"size:256;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"` // Timestamps representing the latest timestamps of key conditions related // to this subscription. // // Condition transition details are tracked in 'enterprise_portal_subscription_conditions'. - CreatedAt time.Time `gorm:"not null;default:current_timestamp"` - UpdatedAt time.Time `gorm:"not null;default:current_timestamp"` - ArchivedAt *time.Time // Null indicates the subscription is not archived. + CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"` + UpdatedAt utctime.Time `gorm:"not null;default:current_timestamp"` + ArchivedAt *utctime.Time // Null indicates the subscription is not archived. // SalesforceSubscriptionID associated with this Enterprise subscription. SalesforceSubscriptionID *string @@ -49,11 +66,7 @@ type Subscription struct { SalesforceOpportunityID *string } -func (s Subscription) TableName() string { - return "enterprise_portal_subscriptions" -} - -// subscriptionTableColumns must match s.scan() values. +// subscriptionTableColumns must match scanSubscription() values. func subscriptionTableColumns() []string { return []string{ "id", @@ -67,7 +80,7 @@ func subscriptionTableColumns() []string { } } -// scanSubscription matches s.columns() values. +// scanSubscription matches subscriptionTableColumns() values. func scanSubscription(row pgx.Row) (*Subscription, error) { var s Subscription err := row.Scan( @@ -83,13 +96,6 @@ func scanSubscription(row pgx.Row) (*Subscription, error) { if err != nil { return nil, err } - - s.CreatedAt = s.CreatedAt.UTC() - s.UpdatedAt = s.UpdatedAt.UTC() - if s.ArchivedAt != nil { - s.ArchivedAt = pointers.Ptr(s.ArchivedAt.UTC()) - } - return &s, nil } @@ -172,8 +178,8 @@ WHERE %s } type UpsertSubscriptionOptions struct { - InstanceDomain string - DisplayName string + InstanceDomain *sql.NullString + DisplayName *sql.NullString CreatedAt time.Time ArchivedAt *time.Time diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_conditions.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_conditions.go index f30419769a3..9de9f3ab821 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_conditions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_conditions.go @@ -1,14 +1,9 @@ package subscriptions -import ( - "time" -) +import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime" // Subscription is an Enterprise subscription condition record. type SubscriptionCondition struct { - // ⚠️ DO NOT USE: This field is only used for creating foreign key constraint. - Subscription *Subscription `gorm:"foreignKey:SubscriptionID"` - // SubscriptionID is the internal unprefixed UUID of the related subscription. SubscriptionID string `gorm:"type:uuid;not null"` // Status is the type of status corresponding to this condition, corresponding @@ -18,7 +13,7 @@ type SubscriptionCondition struct { Message *string `gorm:"size:256"` // TransitionTime is the time at which the condition was created, i.e. when // the subscription transitioned into this status. - TransitionTime time.Time `gorm:"not null;default:current_timestamp"` + TransitionTime utctime.Time `gorm:"not null;default:current_timestamp"` } func (s *SubscriptionCondition) TableName() string { diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go index 4aacc18c881..2cbef7b07df 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go @@ -10,6 +10,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" @@ -20,7 +21,7 @@ func TestSubscriptionsStore(t *testing.T) { t.Parallel() ctx := context.Background() - db := databasetest.NewTestDB(t, "enterprise-portal", "SubscriptionsStore", tables.All()...) + db := databasetest.NewTestDB(t, "enterprise-portal", t.Name(), tables.All()...) for _, tc := range []struct { name string @@ -45,19 +46,25 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions. s1, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")), + }, ) require.NoError(t, err) s2, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s2.sourcegraph.com"}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s2.sourcegraph.com")), + }, ) require.NoError(t, err) _, err = s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s3.sourcegraph.com"}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s3.sourcegraph.com")), + }, ) require.NoError(t, err) @@ -79,7 +86,7 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions. t.Run("list by instance domains", func(t *testing.T) { ss, err := s.List(ctx, subscriptions.ListEnterpriseSubscriptionsOptions{ - InstanceDomains: []string{s1.InstanceDomain, s2.InstanceDomain}}, + InstanceDomains: []string{*s1.InstanceDomain, *s2.InstanceDomain}}, ) require.NoError(t, err) require.Len(t, ss, 2) @@ -115,14 +122,16 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription currentSubscription, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")), + }, ) require.NoError(t, err) got, err := s.Get(ctx, currentSubscription.ID) require.NoError(t, err) assert.Equal(t, currentSubscription.ID, got.ID) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain) assert.Empty(t, got.DisplayName) assert.NotZero(t, got.CreatedAt) assert.NotZero(t, got.UpdatedAt) @@ -133,17 +142,19 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{}) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, + pointers.DerefZero(currentSubscription.InstanceDomain), + pointers.DerefZero(got.InstanceDomain)) }) t.Run("update only domain", func(t *testing.T) { t.Cleanup(func() { currentSubscription = got }) got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ - InstanceDomain: "s1-new.sourcegraph.com", + InstanceDomain: pointers.Ptr(database.NewNullString("s1-new.sourcegraph.com")), }) require.NoError(t, err) - assert.Equal(t, "s1-new.sourcegraph.com", got.InstanceDomain) + assert.Equal(t, "s1-new.sourcegraph.com", pointers.DerefZero(got.InstanceDomain)) assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) }) @@ -151,11 +162,11 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription t.Cleanup(func() { currentSubscription = got }) got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ - DisplayName: "My New Display Name", + DisplayName: pointers.Ptr(database.NewNullString("My New Display Name")), }) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) - assert.Equal(t, "My New Display Name", got.DisplayName) + assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain) + assert.Equal(t, "My New Display Name", pointers.DerefZero(got.DisplayName)) }) t.Run("update only created at", func(t *testing.T) { @@ -166,10 +177,12 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription CreatedAt: yesterday, }) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, + pointers.DerefZero(currentSubscription.InstanceDomain), + pointers.DerefZero(got.InstanceDomain)) assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) // Round times to allow for some precision drift in CI - assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Round(time.Second)) + assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.GetTime().Round(time.Second)) }) t.Run("update only archived at", func(t *testing.T) { @@ -180,11 +193,11 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription ArchivedAt: pointers.Ptr(yesterday), }) require.NoError(t, err) - assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) - assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) + assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain) + assert.Equal(t, *currentSubscription.DisplayName, *got.DisplayName) assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt) // Round times to allow for some precision drift in CI - assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Round(time.Second)) + assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.GetTime().Round(time.Second)) }) t.Run("force update to zero values", func(t *testing.T) { @@ -209,7 +222,9 @@ func SubscriptionsStoreGet(t *testing.T, ctx context.Context, s *subscriptions.S s1, err := s.Upsert( ctx, uuid.New().String(), - subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"}, + subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")), + }, ) require.NoError(t, err) diff --git a/cmd/enterprise-portal/internal/database/types.go b/cmd/enterprise-portal/internal/database/types.go new file mode 100644 index 00000000000..366ed5e5481 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/types.go @@ -0,0 +1,10 @@ +package database + +import "database/sql" + +func NewNullString(v string) sql.NullString { + return sql.NullString{ + String: v, + Valid: v != "", + } +} diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go b/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go index 52ed657b067..97e2c7aba5d 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go @@ -68,8 +68,8 @@ func convertSubscriptionToProto(subscription *subscriptions.Subscription, attrs return &subscriptionsv1.EnterpriseSubscription{ Id: subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.ID, Conditions: conds, - InstanceDomain: subscription.InstanceDomain, - DisplayName: subscription.DisplayName, + InstanceDomain: pointers.DerefZero(subscription.InstanceDomain), + DisplayName: pointers.DerefZero(subscription.DisplayName), } } diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/v1.go b/cmd/enterprise-portal/internal/subscriptionsservice/v1.go index 7e4a16cbc95..0e76abde953 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/v1.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/v1.go @@ -16,8 +16,10 @@ import ( subscriptionsv1connect "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1/v1connect" "github.com/sourcegraph/sourcegraph/lib/errors" "github.com/sourcegraph/sourcegraph/lib/managedservicesplatform/iam" + "github.com/sourcegraph/sourcegraph/lib/pointers" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/connectutil" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/samsm2m" @@ -328,31 +330,41 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne // Empty field paths means update all non-empty fields. if len(fieldPaths) == 0 { if v := req.Msg.GetSubscription().GetInstanceDomain(); v != "" { - opts.InstanceDomain = v + opts.InstanceDomain = pointers.Ptr(database.NewNullString(v)) } if v := req.Msg.GetSubscription().GetDisplayName(); v != "" { - opts.DisplayName = v + opts.DisplayName = pointers.Ptr(database.NewNullString(v)) } } else { for _, p := range fieldPaths { switch p { case "instance_domain": - opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain() + opts.InstanceDomain = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()), + ) case "display_name": - opts.DisplayName = req.Msg.GetSubscription().GetDisplayName() + opts.DisplayName = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetDisplayName()), + ) case "*": opts.ForceUpdate = true - opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain() + opts.InstanceDomain = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()), + ) + opts.DisplayName = pointers.Ptr( + database.NewNullString(req.Msg.GetSubscription().GetDisplayName()), + ) } } } // Validate and normalize the domain - if opts.InstanceDomain != "" { - opts.InstanceDomain, err = subscriptionsv1.NormalizeInstanceDomain(opts.InstanceDomain) + if opts.InstanceDomain != nil && opts.InstanceDomain.Valid { + normalizedDomain, err := subscriptionsv1.NormalizeInstanceDomain(opts.InstanceDomain.String) if err != nil { return nil, connect.NewError(connect.CodeInvalidArgument, errors.Wrap(err, "invalid instance domain")) } + opts.InstanceDomain.String = normalizedDomain } subscription, err := s.store.UpsertEnterpriseSubscription(ctx, subscriptionID, opts)