feat/enterpriseportal: db layer for subscription licenses (#63792)

Implements CRUD on the new licenses DB. I had to make significant
changes from the initial setup after spending more time working on this.

There's lots of schema changes but that's okay, as we have no data yet.

As in the RPC design, this is intended to accommodate new "types" of
licensing in the future, and so the DB is structured as such as well.
There's also feedback that context around license management events is
very useful - this is encoded in the conditions table, and can be
extended to include more types of conditions in the future.

Part of https://linear.app/sourcegraph/issue/CORE-158
Part of https://linear.app/sourcegraph/issue/CORE-100

## Test plan

Integration tests

Locally, running `sg run enterprise-portal` indicates migrations proceed
as expected
This commit is contained in:
Robert Lin 2024-07-15 11:47:51 -07:00 committed by GitHub
parent 879646a20e
commit 795f0bbc72
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
20 changed files with 964 additions and 80 deletions

View File

@ -5,12 +5,14 @@ go_library(
srcs = [
"database.go",
"migrate.go",
"types.go",
],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database",
tags = [TAG_INFRA_CORESERVICES],
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//cmd/enterprise-portal/internal/database/internal/tables",
"//cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
"//cmd/enterprise-portal/internal/database/subscriptions",
"//lib/errors",
"//lib/managedservicesplatform/runtime",

View File

@ -4,7 +4,7 @@ import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/databa
type CodyGatewayAccess struct {
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint.
Subscription *subscriptions.Subscription `gorm:"foreignKey:SubscriptionID"`
Subscription *subscriptions.TableSubscription `gorm:"foreignKey:SubscriptionID"`
// SubscriptionID is the internal unprefixed UUID of the related subscription.
SubscriptionID string `gorm:"type:uuid;not null;unique"`

View File

@ -7,7 +7,9 @@ go_library(
tags = [TAG_INFRA_CORESERVICES],
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
"//internal/database/dbtest",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
"@com_github_stretchr_testify//require",
"@io_gorm_driver_postgres//:postgres",

View File

@ -3,17 +3,20 @@ package databasetest
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/require"
"gorm.io/driver/postgres"
"gorm.io/gorm"
"gorm.io/gorm/schema"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator"
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
)
@ -57,6 +60,11 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx
for _, table := range tables {
err = db.AutoMigrate(table)
require.NoError(t, err)
if m, ok := table.(custommigrator.CustomTableMigrator); ok {
if err := m.RunCustomMigrations(db.Migrator()); err != nil {
require.NoError(t, err)
}
}
}
// Close the connection used to auto-migrate the database.
@ -66,7 +74,12 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx
require.NoError(t, err)
// Open a new connection to the test suite database.
testDB, err := pgxpool.New(context.Background(), dsn.String())
dbConfig, err := pgxpool.ParseConfig(dsn.String())
require.NoError(t, err)
if testing.Verbose() {
dbConfig.ConnConfig.Tracer = pgxTestTracer{TB: t}
}
testDB, err := pgxpool.NewWithConfig(context.Background(), dbConfig)
require.NoError(t, err)
t.Cleanup(func() {
@ -110,3 +123,43 @@ func ClearTablesAfterTest(t *testing.T, db *pgxpool.Pool, tables ...schema.Table
}
})
}
// pgxTestTracer implements various pgx tracing hooks for dumping diagnostics
// in testing.
type pgxTestTracer struct{ testing.TB }
// Select tracing hooks we want to implement.
var (
_ pgx.QueryTracer = pgxTestTracer{}
)
func (t pgxTestTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
var args []string
if len(data.Args) > 0 {
// Divider for readability
args = append(args, "\n---")
}
for _, arg := range data.Args {
data, err := json.MarshalIndent(arg, "", " ")
if err != nil {
args = append(args, fmt.Sprintf("marshal %T: %+v", arg, err))
}
args = append(args, string(data))
}
t.Logf(`pgx.QueryStart db=%q
%s%s`,
conn.Config().Database,
strings.TrimSpace(data.SQL),
strings.Join(args, "\n"))
return ctx
}
func (t pgxTestTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
if data.Err != nil {
t.Logf(`pgx.QueryEnd db=%q tag=%q error=%q`,
conn.Config().Database,
data.CommandTag.String(),
data.Err)
}
}

View File

@ -0,0 +1,9 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "custommigrator",
srcs = ["custommigrator.go"],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = ["@io_gorm_gorm//:gorm"],
)

View File

@ -0,0 +1,9 @@
package custommigrator
import "gorm.io/gorm"
type CustomTableMigrator interface {
// RunCustomMigrations is called after all other migrations have been run.
// It can implement custom migrations.
RunCustomMigrations(migrator gorm.Migrator) error
}

View File

@ -12,9 +12,9 @@ import (
// ⚠️ WARNING: This list is meant to be read-only.
func All() []schema.Tabler {
return []schema.Tabler{
&subscriptions.Subscription{},
&subscriptions.TableSubscription{},
&subscriptions.SubscriptionCondition{},
&subscriptions.SubscriptionLicense{},
&subscriptions.TableSubscriptionLicense{},
&subscriptions.SubscriptionLicenseCondition{},
&codyaccess.CodyGatewayAccess{},

View File

@ -0,0 +1,12 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "utctime",
srcs = ["utctime.go"],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime",
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//lib/errors",
"//lib/pointers",
],
)

View File

@ -0,0 +1,80 @@
package utctime
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"time"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
// Time is a wrapper around time.Time that implements the database/sql.Scanner
// and database/sql/driver.Valuer interfaces to serialize and deserialize time
// in UTC time zone.
//
// Time ensures that time.Time values are always:
//
// - represented in UTC for consistency
// - rounded to microsecond precision
//
// We round the time because PostgreSQL times are represented in microseconds:
// https://www.postgresql.org/docs/current/datatype-datetime.html
type Time time.Time
// Now returns the current time in UTC.
func Now() Time { return Time(time.Now()) }
// FromTime returns a utctime.Time from a time.Time.
func FromTime(t time.Time) Time { return Time(t.UTC().Round(time.Microsecond)) }
var _ sql.Scanner = (*Time)(nil)
func (t *Time) Scan(src any) error {
if src == nil {
return nil
}
if v, ok := src.(time.Time); ok {
*t = FromTime(v)
return nil
}
return errors.Newf("value %T is not time.Time", src)
}
var _ driver.Valuer = (*Time)(nil)
// Value must be called with a non-nil Time. driver.Valuer callers will first
// check that the value is non-nil, so this is safe.
func (t Time) Value() (driver.Value, error) {
stdTime := t.GetTime()
return *stdTime, nil
}
var _ json.Marshaler = (*Time)(nil)
func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.GetTime()) }
var _ json.Unmarshaler = (*Time)(nil)
func (t *Time) UnmarshalJSON(data []byte) error {
var stdTime time.Time
if err := json.Unmarshal(data, &stdTime); err != nil {
return err
}
*t = FromTime(stdTime)
return nil
}
// GetTime returns the underlying time.GetTime value, or nil if it is nil.
func (t *Time) GetTime() *time.Time {
if t == nil {
return nil
}
return pointers.Ptr(t.AsTime())
}
// Time casts the Time as a standard time.Time value.
func (t Time) AsTime() time.Time {
return time.Time(t).UTC().Round(time.Microsecond)
}

View File

@ -21,6 +21,7 @@ import (
"github.com/sourcegraph/sourcegraph/lib/redislock"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator"
)
// maybeMigrate runs the auto-migration for the database when needed based on
@ -42,6 +43,12 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr
}
span.End()
}()
logger = logger.
WithTrace(log.TraceContext{
TraceID: span.SpanContext().TraceID().String(),
SpanID: span.SpanContext().SpanID().String(),
}).
With(log.String("database", dbName))
sqlDB, err := contract.PostgreSQL.OpenDatabase(ctx, dbName)
if err != nil {
@ -83,17 +90,20 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr
span.AddEvent("lock.acquired")
versionKey := fmt.Sprintf("%s:db_version", dbName)
liveVersion := redisClient.Get(ctx, versionKey).Val()
if shouldSkipMigration(
redisClient.Get(ctx, versionKey).Val(),
liveVersion,
currentVersion,
) {
logger.Info("skipped auto-migration",
log.String("database", dbName),
log.String("currentVersion", currentVersion),
)
span.SetAttributes(attribute.Bool("skipped", true))
return nil
}
logger.Info("executing auto-migration",
log.String("liveVersion", liveVersion),
log.String("currentVersion", currentVersion))
span.SetAttributes(attribute.Bool("skipped", false))
// Create a session that ignore debug logging.
@ -108,6 +118,11 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr
if err != nil {
return errors.Wrapf(err, "auto migrating table for %s", errors.Safe(fmt.Sprintf("%T", table)))
}
if m, ok := table.(custommigrator.CustomTableMigrator); ok {
if err := m.RunCustomMigrations(sess.Migrator()); err != nil {
return errors.Wrapf(err, "running custom migrations for %s", errors.Safe(fmt.Sprintf("%T", table)))
}
}
}
return redisClient.Set(ctx, versionKey, currentVersion, 0).Err()

View File

@ -14,27 +14,39 @@ go_library(
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//cmd/enterprise-portal/internal/database/internal/upsert",
"//cmd/enterprise-portal/internal/database/internal/utctime",
"//internal/license",
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
"//lib/errors",
"//lib/pointers",
"@com_github_jackc_pgtype//:pgtype",
"@com_github_google_uuid//:uuid",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
"@io_gorm_gorm//:gorm",
],
)
go_test(
name = "subscriptions_test",
srcs = ["subscriptions_test.go"],
srcs = [
"licenses_test.go",
"subscriptions_test.go",
],
tags = [
TAG_INFRA_CORESERVICES,
"requires-network",
],
deps = [
":subscriptions",
"//cmd/enterprise-portal/internal/database",
"//cmd/enterprise-portal/internal/database/databasetest",
"//cmd/enterprise-portal/internal/database/internal/tables",
"//cmd/enterprise-portal/internal/database/internal/utctime",
"//internal/license",
"//lib/pointers",
"@com_github_google_uuid//:uuid",
"@com_github_hexops_autogold_v2//:autogold",
"@com_github_hexops_valast//:valast",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",

View File

@ -1,11 +1,17 @@
package subscriptions
import "time"
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
type SubscriptionLicenseCondition struct {
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint.
License *SubscriptionLicense `gorm:"foreignKey:LicenseID"`
// SubscriptionID is the internal unprefixed UUID of the related license.
LicenseID string `gorm:"type:uuid;not null"`
// Status is the type of status corresponding to this condition, corresponding
@ -15,9 +21,73 @@ type SubscriptionLicenseCondition struct {
Message *string `gorm:"size:256"`
// TransitionTime is the time at which the condition was created, i.e. when
// the license transitioned into this status.
TransitionTime time.Time `gorm:"not null;default:current_timestamp"`
TransitionTime utctime.Time `gorm:"not null;default:current_timestamp"`
}
func (s *SubscriptionLicenseCondition) TableName() string {
func (*SubscriptionLicenseCondition) TableName() string {
return "enterprise_portal_subscription_license_conditions"
}
// subscriptionLicenseConditionJSONBAgg must be used with:
//
// LEFT JOIN
// enterprise_portal_subscription_license_conditions license_condition
// ON license_condition.license_id = id
// GROUP BY
// id
//
// The conditions are aggregated in JSON to 'conditions', which can be directly
// unmarshaled into the 'SubscriptionLicenseCondition' type using 'pgx'.
func subscriptionLicenseConditionJSONBAgg() string {
return `
jsonb_agg(
jsonb_build_object(
'Status', license_condition.status,
'Message', license_condition.message,
'TransitionTime', license_condition.transition_time
)
ORDER BY license_condition.transition_time DESC
) AS conditions`
}
type licenseConditionsStore struct{ tx pgx.Tx }
// newLicenseConditionsStore is meant to be used exclusively in the context of
// a transaction, where the parent license is being updated at the same time.
//
// The caller owns the transaction lifecycle.
func newLicenseConditionsStore(tx pgx.Tx) *licenseConditionsStore {
return &licenseConditionsStore{tx: tx}
}
type createLicenseConditionOpts struct {
Status subscriptionsv1.EnterpriseSubscriptionLicenseCondition_Status
Message string
TransitionTime utctime.Time
}
func (s *licenseConditionsStore) createLicenseCondition(ctx context.Context, licenseID string, opts createLicenseConditionOpts) error {
if opts.TransitionTime.GetTime().IsZero() {
return errors.New("transition time is required")
}
_, err := s.tx.Exec(ctx, `
INSERT INTO enterprise_portal_subscription_license_conditions (
license_id,
status,
message,
transition_time
)
VALUES (
@licenseID,
@status,
@message,
@transitionTime
)`, pgx.NamedArgs{
"licenseID": licenseID,
// Convert to string representation of EnterpriseSubscriptionLicenseCondition
"status": subscriptionsv1.EnterpriseSubscriptionLicenseCondition_Status_name[int32(opts.Status)],
"message": pointers.NilIfZero(opts.Message),
"transitionTime": opts.TransitionTime,
})
return err
}

View File

@ -1,15 +1,49 @@
package subscriptions
import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/jackc/pgtype"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"gorm.io/gorm"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
internallicense "github.com/sourcegraph/sourcegraph/internal/license"
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
"github.com/sourcegraph/sourcegraph/lib/pointers"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
type SubscriptionLicense struct {
// ⚠️ DO NOT USE: This type is only used for creating foreign key constraints
// and initializing tables with gorm.
type TableSubscriptionLicense struct {
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint.
Subscription *Subscription `gorm:"foreignKey:SubscriptionID"`
Conditions *[]SubscriptionLicenseCondition `gorm:"foreignKey:LicenseID"`
SubscriptionLicense
}
func (*TableSubscriptionLicense) TableName() string {
return "enterprise_portal_subscription_licenses"
}
// Implement tables.CustomMigrator
func (s *TableSubscriptionLicense) RunCustomMigrations(migrator gorm.Migrator) error {
if migrator.HasColumn(s, "license_kind") {
if err := migrator.DropColumn(s, "license_kind"); err != nil {
return err
}
}
return nil
}
type SubscriptionLicense struct {
// SubscriptionID is the internal unprefixed UUID of the related subscription.
SubscriptionID string `gorm:"type:uuid;not null"`
// ID is the internal unprefixed UUID of this license.
@ -19,21 +53,331 @@ type SubscriptionLicense struct {
// to this subscription.
//
// Condition transition details are tracked in 'enterprise_portal_subscription_license_conditions'.
CreatedAt time.Time `gorm:"not null;default:current_timestamp"`
RevokedAt *time.Time // Null indicates the licnese is not revoked.
CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"`
RevokedAt *utctime.Time // Null indicates the license is not revoked.
// LicenseKind is the kind of license stored in LicenseData, corresponding
// ExpireAt is the time at which the license should expire. Expiration does
// NOT get a corresponding condition entry in 'enterprise_portal_subscription_license_conditions'.
ExpireAt utctime.Time `gorm:"not null"`
// LicenseType is the kind of license stored in LicenseData, corresponding
// to the API 'EnterpriseSubscriptionLicenseType'.
LicenseKind string `gorm:"not null"`
LicenseType string `gorm:"not null"`
// LicenseData is the license data stored in JSON format. It is read-only
// and generally never queried in conditions - properties that are should
// be stored at the subscription or license level.
//
// Value shapes correspond to API types appropriate for each
// 'EnterpriseSubscriptionLicenseType'.
LicenseData pgtype.JSONB `gorm:"type:jsonb"`
LicenseData json.RawMessage `gorm:"type:jsonb"`
}
func (s *SubscriptionLicense) TableName() string {
return "enterprise_portal_subscription_licenses"
// subscriptionLicenseWithConditionsColumns must match scanSubscriptionLicense()
// values.
func subscriptionLicenseWithConditionsColumns() []string {
return []string{
"subscription_id",
"id",
"created_at",
"revoked_at",
"expire_at",
"license_type",
"license_data",
subscriptionLicenseConditionJSONBAgg(),
}
}
type LicenseWithConditions struct {
SubscriptionLicense
Conditions []SubscriptionLicenseCondition
}
// scanSubscription matches subscriptionTableColumns() values.
func scanSubscriptionLicenseWithConditions(row pgx.Row) (*LicenseWithConditions, error) {
var l LicenseWithConditions
err := row.Scan(
&l.SubscriptionID,
&l.ID,
&l.CreatedAt,
&l.RevokedAt,
&l.ExpireAt,
&l.LicenseType,
&l.LicenseData,
&l.Conditions, // see subscriptionLicenseConditionJSONBAgg docstring
)
return &l, err
}
// LicensesStore manages licenses belonging to Enterprise subscriptions.
//
// Licenses can only be created and revoked - they can never be updated.
type LicensesStore struct {
db *pgxpool.Pool
}
func NewLicensesStore(db *pgxpool.Pool) *LicensesStore {
return &LicensesStore{
db: db,
}
}
type ListLicensesOpts struct {
SubscriptionID string
// PageSize is the maximum number of licenses to return.
PageSize int
}
func (opts ListLicensesOpts) toQueryConditions() (where, limitClause string, _ pgx.NamedArgs) {
whereConds := []string{"TRUE"}
namedArgs := pgx.NamedArgs{}
if opts.SubscriptionID != "" {
whereConds = append(whereConds, "subscription_id = @subscriptionID")
namedArgs["subscriptionID"] = opts.SubscriptionID
}
where = strings.Join(whereConds, " AND ")
if opts.PageSize > 0 {
limitClause = "LIMIT @pageSize"
namedArgs["pageSize"] = opts.PageSize
}
return where, limitClause, namedArgs
}
func (s *LicensesStore) List(ctx context.Context, opts ListLicensesOpts) ([]*LicenseWithConditions, error) {
where, limitClause, namedArgs := opts.toQueryConditions()
query := fmt.Sprintf(`
SELECT
%s
FROM
enterprise_portal_subscription_licenses
LEFT JOIN
enterprise_portal_subscription_license_conditions license_condition
ON license_condition.license_id = id
WHERE
%s
GROUP BY
id
ORDER BY
created_at DESC
%s`,
strings.Join(subscriptionLicenseWithConditionsColumns(), ", "),
where, limitClause)
rows, err := s.db.Query(ctx, query, namedArgs)
if err != nil {
return nil, errors.Wrap(err, "query rows")
}
defer rows.Close()
var licenses []*LicenseWithConditions
for rows.Next() {
license, err := scanSubscriptionLicenseWithConditions(rows)
if err != nil {
return nil, errors.Wrap(err, "scan row")
}
licenses = append(licenses, license)
}
return licenses, rows.Err()
}
func (s *LicensesStore) Get(ctx context.Context, licenseID string) (*LicenseWithConditions, error) {
query := fmt.Sprintf(`
SELECT
%s
FROM
enterprise_portal_subscription_licenses
LEFT JOIN
enterprise_portal_subscription_license_conditions license_condition
ON license_condition.license_id = id
WHERE
id = @licenseID
GROUP BY
id`,
strings.Join(subscriptionLicenseWithConditionsColumns(), ", "))
license, err := scanSubscriptionLicenseWithConditions(
s.db.QueryRow(ctx, query, pgx.NamedArgs{
"licenseID": licenseID,
}),
)
if err != nil {
return nil, errors.Wrap(err, "query rows")
}
return license, nil
}
type CreateLicenseOpts struct {
Message string
// If nil, the creation time will be set to the current time.
Time *utctime.Time
// Expiration time of the license.
ExpireTime utctime.Time
}
// LicenseKey corresponds to *subscriptionsv1.EnterpriseSubscriptionLicenseKey
// and the 'ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY' license type.
type LicenseKey struct {
Info internallicense.Info
// Signed license key with the license information in Info.
SignedKey string
}
// CreateLicense creates a new classic offline license for the given subscription.
func (s *LicensesStore) CreateLicenseKey(
ctx context.Context,
subscriptionID string,
license *LicenseKey,
opts CreateLicenseOpts,
) (_ *LicenseWithConditions, err error) {
// Special behaviour: the license key embeds the creation time, and it must
// match the time provided in the options.
if opts.Time == nil {
return nil, errors.New("creation time must be specified for licensekeys")
} else if !opts.Time.GetTime().Equal(utctime.FromTime(license.Info.CreatedAt).AsTime()) {
return nil, errors.New("creation time must match the license key information")
}
if !opts.ExpireTime.GetTime().Equal(utctime.FromTime(license.Info.ExpiresAt).AsTime()) {
return nil, errors.New("expiration time must match the license key information")
}
return s.create(
ctx,
subscriptionID,
subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY,
license,
opts,
)
}
func (s *LicensesStore) create(
ctx context.Context,
subscriptionID string,
licenseType subscriptionsv1.EnterpriseSubscriptionLicenseType,
license any,
opts CreateLicenseOpts,
) (_ *LicenseWithConditions, err error) {
if subscriptionID == "" {
return nil, errors.New("subscription ID must be specified")
}
if opts.Time == nil {
opts.Time = pointers.Ptr(utctime.Now())
} else if opts.Time.GetTime().After(time.Now()) {
return nil, errors.New("creation time cannot be in the future")
}
if licenseType == subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_UNSPECIFIED {
return nil, errors.New("license type must be specified")
}
licenseID, err := uuid.NewV7()
if err != nil {
return nil, errors.Wrap(err, "generate uuid")
}
licenseData, err := json.Marshal(license)
if err != nil {
return nil, errors.Wrap(err, "marshal license data")
}
tx, err := s.db.Begin(ctx)
if err != nil {
return nil, errors.Wrap(err, "begin transaction")
}
defer func() {
if rollbackErr := tx.Rollback(context.Background()); rollbackErr != nil {
err = errors.Append(err, errors.Wrap(err, "rollback"))
}
}()
if _, err := tx.Exec(ctx, `
INSERT INTO enterprise_portal_subscription_licenses (
id,
subscription_id,
license_type,
license_data,
created_at,
expire_at
)
VALUES (
@licenseID,
@subscriptionID,
@licenseType,
@licenseData,
@createdAt,
@expireAt
)
`, pgx.NamedArgs{
"licenseID": licenseID.String(),
"subscriptionID": subscriptionID,
"licenseType": subscriptionsv1.EnterpriseSubscriptionLicenseType_name[int32(licenseType)],
"licenseData": licenseData,
"createdAt": opts.Time,
"expireAt": opts.ExpireTime,
}); err != nil {
return nil, errors.Wrap(err, "create license")
}
if err := newLicenseConditionsStore(tx).createLicenseCondition(ctx, licenseID.String(), createLicenseConditionOpts{
Status: subscriptionsv1.EnterpriseSubscriptionLicenseCondition_STATUS_CREATED,
Message: opts.Message,
TransitionTime: *opts.Time,
}); err != nil {
return nil, errors.Wrap(err, "create license condition")
}
if err := tx.Commit(ctx); err != nil {
return nil, errors.Wrap(err, "commit transaction")
}
return s.Get(ctx, licenseID.String())
}
type RevokeLicenseOpts struct {
Message string
// If nil, the revocation time will be set to the current time.
Time *utctime.Time
}
// Revoke marks the given license as revoked.
func (s *LicensesStore) Revoke(ctx context.Context, licenseID string, opts RevokeLicenseOpts) (*LicenseWithConditions, error) {
if opts.Time == nil {
opts.Time = pointers.Ptr(utctime.Now())
} else if opts.Time.GetTime().After(time.Now()) {
return nil, errors.New("revocation time cannot be in the future")
}
tx, err := s.db.Begin(ctx)
if err != nil {
return nil, errors.Wrap(err, "begin transaction")
}
defer func() {
if rollbackErr := tx.Rollback(context.Background()); rollbackErr != nil {
err = errors.Append(err, rollbackErr)
}
}()
if _, err := tx.Exec(ctx, `
UPDATE enterprise_portal_subscription_licenses
SET revoked_at = COALESCE(revoked_at, @revokedAt) -- use existing revoke time if already revoked
WHERE id = @licenseID
`, pgx.NamedArgs{
"revokedAt": opts.Time,
"licenseID": licenseID,
}); err != nil {
return nil, errors.Wrap(err, "revoke license")
}
if err := newLicenseConditionsStore(tx).createLicenseCondition(ctx, licenseID, createLicenseConditionOpts{
Status: subscriptionsv1.EnterpriseSubscriptionLicenseCondition_STATUS_REVOKED,
Message: opts.Message,
TransitionTime: *opts.Time,
}); err != nil {
return nil, errors.Wrap(err, "create license condition")
}
if err := tx.Commit(ctx); err != nil {
return nil, errors.Wrap(err, "commit transaction")
}
return s.Get(ctx, licenseID)
}

View File

@ -0,0 +1,238 @@
package subscriptions_test
import (
"context"
"fmt"
"testing"
"time"
"github.com/google/uuid"
"github.com/hexops/autogold/v2"
"github.com/hexops/valast"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
"github.com/sourcegraph/sourcegraph/internal/license"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
func TestLicensesStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := databasetest.NewTestDB(t, "enterprise-portal", t.Name(), tables.All()...)
subscriptionID1 := uuid.NewString()
subscriptionID2 := uuid.NewString()
subs := subscriptions.NewStore(db)
_, err := subs.Upsert(ctx, subscriptionID1, subscriptions.UpsertSubscriptionOptions{
DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 1")),
})
require.NoError(t, err)
_, err = subs.Upsert(ctx, subscriptionID2, subscriptions.UpsertSubscriptionOptions{
DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 2")),
})
require.NoError(t, err)
licenses := subscriptions.NewLicensesStore(db)
var createdLicenses []*subscriptions.LicenseWithConditions
getCreatedByLicenseID := func(t *testing.T, licenseID string) *subscriptions.LicenseWithConditions {
for _, l := range createdLicenses {
if l.ID == licenseID {
return l
}
}
t.Errorf("license %q not found", licenseID)
t.FailNow()
return nil
}
t.Run("CreateLicenseKey", func(t *testing.T) {
testLicense := func(
got *subscriptions.LicenseWithConditions,
wantMessage autogold.Value,
wantLicenseData autogold.Value,
) {
assert.NotEmpty(t, got.ID)
assert.NotZero(t, got.CreatedAt)
assert.NotZero(t, got.ExpireAt)
assert.Equal(t, "ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY", got.LicenseType)
wantLicenseData.Equal(t, string(got.LicenseData))
assert.Len(t, got.Conditions, 1)
wantMessage.Equal(t, got.Conditions[0].Message)
assert.Equal(t, "STATUS_CREATED", got.Conditions[0].Status)
assert.Equal(t, got.CreatedAt, got.Conditions[0].TransitionTime)
}
got, err := licenses.CreateLicenseKey(ctx, subscriptionID1,
&subscriptions.LicenseKey{
Info: license.Info{
Tags: []string{"foo"},
CreatedAt: time.Time{}.Add(1 * time.Hour),
ExpiresAt: time.Time{}.Add(48 * time.Hour),
},
SignedKey: "asdfasdf",
},
subscriptions.CreateLicenseOpts{
Message: t.Name() + " 1 old",
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(1 * time.Hour))),
ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)),
})
require.NoError(t, err)
testLicense(
got,
autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1 old")),
autogold.Expect(`{"Info": {"c": "0001-01-01T01:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["foo"], "u": 0}, "SignedKey": "asdfasdf"}`),
)
createdLicenses = append(createdLicenses, got)
got, err = licenses.CreateLicenseKey(ctx, subscriptionID1,
&subscriptions.LicenseKey{
Info: license.Info{
Tags: []string{"baz"},
CreatedAt: time.Time{}.Add(24 * time.Hour),
ExpiresAt: time.Time{}.Add(48 * time.Hour),
},
SignedKey: "barasdf",
},
subscriptions.CreateLicenseOpts{
Message: t.Name() + " 1",
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))),
ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)),
})
require.NoError(t, err)
testLicense(
got,
autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1")),
autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`),
)
createdLicenses = append(createdLicenses, got)
got, err = licenses.CreateLicenseKey(ctx, subscriptionID2,
&subscriptions.LicenseKey{
Info: license.Info{
Tags: []string{"tag"},
CreatedAt: time.Time{}.Add(24 * time.Hour),
ExpiresAt: time.Time{}.Add(48 * time.Hour),
},
SignedKey: "asdffdsadf",
},
subscriptions.CreateLicenseOpts{
Message: t.Name() + " 2",
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))),
ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)),
})
require.NoError(t, err)
testLicense(
got,
autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 2")),
autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["tag"], "u": 0}, "SignedKey": "asdffdsadf"}`),
)
createdLicenses = append(createdLicenses, got)
t.Run("createdAt does not match", func(t *testing.T) {
_, err = licenses.CreateLicenseKey(ctx, subscriptionID2,
&subscriptions.LicenseKey{
Info: license.Info{
Tags: []string{"tag"},
CreatedAt: time.Time{}.Add(24 * time.Hour),
},
SignedKey: "asdffdsadf",
},
subscriptions.CreateLicenseOpts{
Message: t.Name(),
Time: pointers.Ptr(utctime.Now()),
})
require.Error(t, err)
autogold.Expect("creation time must match the license key information").Equal(t, err.Error())
})
t.Run("expiresAt does not match", func(t *testing.T) {
_, err = licenses.CreateLicenseKey(ctx, subscriptionID2,
&subscriptions.LicenseKey{
Info: license.Info{
Tags: []string{"tag"},
CreatedAt: time.Time{},
ExpiresAt: time.Time{}.Add(48 * time.Hour),
},
SignedKey: "asdffdsadf",
},
subscriptions.CreateLicenseOpts{
Message: t.Name(),
Time: pointers.Ptr(utctime.FromTime(time.Time{})),
ExpireTime: utctime.Now(),
})
require.Error(t, err)
autogold.Expect("expiration time must match the license key information").Equal(t, err.Error())
})
})
// No point continuing if test licenses did not create, all tests after this
// will fail
if t.Failed() {
t.FailNow()
}
t.Run("List", func(t *testing.T) {
listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{})
require.NoError(t, err)
assert.Len(t, listedLicenses, len(createdLicenses))
for _, l := range listedLicenses {
created := getCreatedByLicenseID(t, l.ID)
assert.Equal(t, *created, *l)
}
t.Run("List by subscription", func(t *testing.T) {
listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{
SubscriptionID: subscriptionID1,
})
require.NoError(t, err)
assert.Len(t, listedLicenses, 2)
for _, l := range listedLicenses {
assert.Equal(t, subscriptionID1, l.SubscriptionID)
assert.Equal(t, *getCreatedByLicenseID(t, l.ID), *l)
}
listedLicenses, err = licenses.List(ctx, subscriptions.ListLicensesOpts{
SubscriptionID: subscriptionID2,
})
require.NoError(t, err)
assert.Len(t, listedLicenses, 1)
for _, l := range listedLicenses {
assert.Equal(t, subscriptionID2, l.SubscriptionID)
assert.Equal(t, *getCreatedByLicenseID(t, l.ID), *l)
}
})
})
t.Run("Get", func(t *testing.T) {
for _, license := range createdLicenses {
got, err := licenses.Get(ctx, license.ID)
require.NoError(t, err)
assert.Equal(t, *license, *got)
}
})
t.Run("Revoke", func(t *testing.T) {
for idx, license := range createdLicenses {
revokeTime := utctime.FromTime(time.Now().Add(-time.Second))
got, err := licenses.Revoke(ctx, license.ID, subscriptions.RevokeLicenseOpts{
Message: fmt.Sprintf("%s %d", t.Name(), idx),
Time: pointers.Ptr(revokeTime),
})
require.NoError(t, err)
assert.Equal(t, revokeTime.AsTime(), got.RevokedAt.AsTime())
require.Len(t, got.Conditions, 2)
// Most recent condition is sorted first, and should be the revocation
assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status)
assert.Equal(t, revokeTime.AsTime(), got.Conditions[0].TransitionTime.AsTime())
assert.Equal(t, "STATUS_CREATED", got.Conditions[1].Status)
}
})
}

View File

@ -2,6 +2,7 @@ package subscriptions
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
@ -10,10 +11,26 @@ import (
"github.com/jackc/pgx/v5/pgxpool"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
// ⚠️ DO NOT USE: This type is only used for creating foreign key constraints
// and initializing tables with gorm.
type TableSubscription struct {
// Each Subscription has many Licenses.
Licenses []*TableSubscriptionLicense `gorm:"foreignKey:SubscriptionID"`
// Each Subscription has many Conditions.
Conditions *[]SubscriptionCondition `gorm:"foreignKey:SubscriptionID"`
Subscription
}
func (*TableSubscription) TableName() string {
return "enterprise_portal_subscriptions"
}
// Subscription is an Enterprise subscription record.
type Subscription struct {
// ID is the internal (unprefixed) UUID-format identifier for the subscription.
@ -22,7 +39,7 @@ type Subscription struct {
// "acme.sourcegraphcloud.com". This is set explicitly.
//
// It must be unique across all currently un-archived subscriptions.
InstanceDomain string `gorm:"uniqueIndex:,where:archived_at IS NULL"`
InstanceDomain *string `gorm:"uniqueIndex:,where:archived_at IS NULL"`
// WARNING: The below fields are not yet used in production.
@ -33,15 +50,15 @@ type Subscription struct {
//
// TODO: Clean up the database post-deploy and remove the 'Unnamed subscription'
// part of the constraint.
DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"`
DisplayName *string `gorm:"size:256;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"`
// Timestamps representing the latest timestamps of key conditions related
// to this subscription.
//
// Condition transition details are tracked in 'enterprise_portal_subscription_conditions'.
CreatedAt time.Time `gorm:"not null;default:current_timestamp"`
UpdatedAt time.Time `gorm:"not null;default:current_timestamp"`
ArchivedAt *time.Time // Null indicates the subscription is not archived.
CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"`
UpdatedAt utctime.Time `gorm:"not null;default:current_timestamp"`
ArchivedAt *utctime.Time // Null indicates the subscription is not archived.
// SalesforceSubscriptionID associated with this Enterprise subscription.
SalesforceSubscriptionID *string
@ -49,11 +66,7 @@ type Subscription struct {
SalesforceOpportunityID *string
}
func (s Subscription) TableName() string {
return "enterprise_portal_subscriptions"
}
// subscriptionTableColumns must match s.scan() values.
// subscriptionTableColumns must match scanSubscription() values.
func subscriptionTableColumns() []string {
return []string{
"id",
@ -67,7 +80,7 @@ func subscriptionTableColumns() []string {
}
}
// scanSubscription matches s.columns() values.
// scanSubscription matches subscriptionTableColumns() values.
func scanSubscription(row pgx.Row) (*Subscription, error) {
var s Subscription
err := row.Scan(
@ -83,13 +96,6 @@ func scanSubscription(row pgx.Row) (*Subscription, error) {
if err != nil {
return nil, err
}
s.CreatedAt = s.CreatedAt.UTC()
s.UpdatedAt = s.UpdatedAt.UTC()
if s.ArchivedAt != nil {
s.ArchivedAt = pointers.Ptr(s.ArchivedAt.UTC())
}
return &s, nil
}
@ -172,8 +178,8 @@ WHERE %s
}
type UpsertSubscriptionOptions struct {
InstanceDomain string
DisplayName string
InstanceDomain *sql.NullString
DisplayName *sql.NullString
CreatedAt time.Time
ArchivedAt *time.Time

View File

@ -1,14 +1,9 @@
package subscriptions
import (
"time"
)
import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
// Subscription is an Enterprise subscription condition record.
type SubscriptionCondition struct {
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint.
Subscription *Subscription `gorm:"foreignKey:SubscriptionID"`
// SubscriptionID is the internal unprefixed UUID of the related subscription.
SubscriptionID string `gorm:"type:uuid;not null"`
// Status is the type of status corresponding to this condition, corresponding
@ -18,7 +13,7 @@ type SubscriptionCondition struct {
Message *string `gorm:"size:256"`
// TransitionTime is the time at which the condition was created, i.e. when
// the subscription transitioned into this status.
TransitionTime time.Time `gorm:"not null;default:current_timestamp"`
TransitionTime utctime.Time `gorm:"not null;default:current_timestamp"`
}
func (s *SubscriptionCondition) TableName() string {

View File

@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
@ -20,7 +21,7 @@ func TestSubscriptionsStore(t *testing.T) {
t.Parallel()
ctx := context.Background()
db := databasetest.NewTestDB(t, "enterprise-portal", "SubscriptionsStore", tables.All()...)
db := databasetest.NewTestDB(t, "enterprise-portal", t.Name(), tables.All()...)
for _, tc := range []struct {
name string
@ -45,19 +46,25 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
s1, err := s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"},
subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")),
},
)
require.NoError(t, err)
s2, err := s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s2.sourcegraph.com"},
subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr(database.NewNullString("s2.sourcegraph.com")),
},
)
require.NoError(t, err)
_, err = s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s3.sourcegraph.com"},
subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr(database.NewNullString("s3.sourcegraph.com")),
},
)
require.NoError(t, err)
@ -79,7 +86,7 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
t.Run("list by instance domains", func(t *testing.T) {
ss, err := s.List(ctx, subscriptions.ListEnterpriseSubscriptionsOptions{
InstanceDomains: []string{s1.InstanceDomain, s2.InstanceDomain}},
InstanceDomains: []string{*s1.InstanceDomain, *s2.InstanceDomain}},
)
require.NoError(t, err)
require.Len(t, ss, 2)
@ -115,14 +122,16 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
currentSubscription, err := s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"},
subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")),
},
)
require.NoError(t, err)
got, err := s.Get(ctx, currentSubscription.ID)
require.NoError(t, err)
assert.Equal(t, currentSubscription.ID, got.ID)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain)
assert.Empty(t, got.DisplayName)
assert.NotZero(t, got.CreatedAt)
assert.NotZero(t, got.UpdatedAt)
@ -133,17 +142,19 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t,
pointers.DerefZero(currentSubscription.InstanceDomain),
pointers.DerefZero(got.InstanceDomain))
})
t.Run("update only domain", func(t *testing.T) {
t.Cleanup(func() { currentSubscription = got })
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
InstanceDomain: "s1-new.sourcegraph.com",
InstanceDomain: pointers.Ptr(database.NewNullString("s1-new.sourcegraph.com")),
})
require.NoError(t, err)
assert.Equal(t, "s1-new.sourcegraph.com", got.InstanceDomain)
assert.Equal(t, "s1-new.sourcegraph.com", pointers.DerefZero(got.InstanceDomain))
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
})
@ -151,11 +162,11 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
t.Cleanup(func() { currentSubscription = got })
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
DisplayName: "My New Display Name",
DisplayName: pointers.Ptr(database.NewNullString("My New Display Name")),
})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t, "My New Display Name", got.DisplayName)
assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain)
assert.Equal(t, "My New Display Name", pointers.DerefZero(got.DisplayName))
})
t.Run("update only created at", func(t *testing.T) {
@ -166,10 +177,12 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
CreatedAt: yesterday,
})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t,
pointers.DerefZero(currentSubscription.InstanceDomain),
pointers.DerefZero(got.InstanceDomain))
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
// Round times to allow for some precision drift in CI
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Round(time.Second))
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.GetTime().Round(time.Second))
})
t.Run("update only archived at", func(t *testing.T) {
@ -180,11 +193,11 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
ArchivedAt: pointers.Ptr(yesterday),
})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain)
assert.Equal(t, *currentSubscription.DisplayName, *got.DisplayName)
assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt)
// Round times to allow for some precision drift in CI
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Round(time.Second))
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.GetTime().Round(time.Second))
})
t.Run("force update to zero values", func(t *testing.T) {
@ -209,7 +222,9 @@ func SubscriptionsStoreGet(t *testing.T, ctx context.Context, s *subscriptions.S
s1, err := s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"},
subscriptions.UpsertSubscriptionOptions{
InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")),
},
)
require.NoError(t, err)

View File

@ -0,0 +1,10 @@
package database
import "database/sql"
func NewNullString(v string) sql.NullString {
return sql.NullString{
String: v,
Valid: v != "",
}
}

View File

@ -68,8 +68,8 @@ func convertSubscriptionToProto(subscription *subscriptions.Subscription, attrs
return &subscriptionsv1.EnterpriseSubscription{
Id: subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.ID,
Conditions: conds,
InstanceDomain: subscription.InstanceDomain,
DisplayName: subscription.DisplayName,
InstanceDomain: pointers.DerefZero(subscription.InstanceDomain),
DisplayName: pointers.DerefZero(subscription.DisplayName),
}
}

View File

@ -16,8 +16,10 @@ import (
subscriptionsv1connect "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1/v1connect"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/managedservicesplatform/iam"
"github.com/sourcegraph/sourcegraph/lib/pointers"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/connectutil"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/samsm2m"
@ -328,31 +330,41 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne
// Empty field paths means update all non-empty fields.
if len(fieldPaths) == 0 {
if v := req.Msg.GetSubscription().GetInstanceDomain(); v != "" {
opts.InstanceDomain = v
opts.InstanceDomain = pointers.Ptr(database.NewNullString(v))
}
if v := req.Msg.GetSubscription().GetDisplayName(); v != "" {
opts.DisplayName = v
opts.DisplayName = pointers.Ptr(database.NewNullString(v))
}
} else {
for _, p := range fieldPaths {
switch p {
case "instance_domain":
opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain()
opts.InstanceDomain = pointers.Ptr(
database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()),
)
case "display_name":
opts.DisplayName = req.Msg.GetSubscription().GetDisplayName()
opts.DisplayName = pointers.Ptr(
database.NewNullString(req.Msg.GetSubscription().GetDisplayName()),
)
case "*":
opts.ForceUpdate = true
opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain()
opts.InstanceDomain = pointers.Ptr(
database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()),
)
opts.DisplayName = pointers.Ptr(
database.NewNullString(req.Msg.GetSubscription().GetDisplayName()),
)
}
}
}
// Validate and normalize the domain
if opts.InstanceDomain != "" {
opts.InstanceDomain, err = subscriptionsv1.NormalizeInstanceDomain(opts.InstanceDomain)
if opts.InstanceDomain != nil && opts.InstanceDomain.Valid {
normalizedDomain, err := subscriptionsv1.NormalizeInstanceDomain(opts.InstanceDomain.String)
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.Wrap(err, "invalid instance domain"))
}
opts.InstanceDomain.String = normalizedDomain
}
subscription, err := s.store.UpsertEnterpriseSubscription(ctx, subscriptionID, opts)