mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 17:31:43 +00:00
feat/enterpriseportal: db layer for subscription licenses (#63792)
Implements CRUD on the new licenses DB. I had to make significant changes from the initial setup after spending more time working on this. There's lots of schema changes but that's okay, as we have no data yet. As in the RPC design, this is intended to accommodate new "types" of licensing in the future, and so the DB is structured as such as well. There's also feedback that context around license management events is very useful - this is encoded in the conditions table, and can be extended to include more types of conditions in the future. Part of https://linear.app/sourcegraph/issue/CORE-158 Part of https://linear.app/sourcegraph/issue/CORE-100 ## Test plan Integration tests Locally, running `sg run enterprise-portal` indicates migrations proceed as expected
This commit is contained in:
parent
879646a20e
commit
795f0bbc72
@ -5,12 +5,14 @@ go_library(
|
||||
srcs = [
|
||||
"database.go",
|
||||
"migrate.go",
|
||||
"types.go",
|
||||
],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database",
|
||||
tags = [TAG_INFRA_CORESERVICES],
|
||||
visibility = ["//cmd/enterprise-portal:__subpackages__"],
|
||||
deps = [
|
||||
"//cmd/enterprise-portal/internal/database/internal/tables",
|
||||
"//cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
|
||||
"//cmd/enterprise-portal/internal/database/subscriptions",
|
||||
"//lib/errors",
|
||||
"//lib/managedservicesplatform/runtime",
|
||||
|
||||
@ -4,7 +4,7 @@ import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/databa
|
||||
|
||||
type CodyGatewayAccess struct {
|
||||
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint.
|
||||
Subscription *subscriptions.Subscription `gorm:"foreignKey:SubscriptionID"`
|
||||
Subscription *subscriptions.TableSubscription `gorm:"foreignKey:SubscriptionID"`
|
||||
|
||||
// SubscriptionID is the internal unprefixed UUID of the related subscription.
|
||||
SubscriptionID string `gorm:"type:uuid;not null;unique"`
|
||||
|
||||
@ -7,7 +7,9 @@ go_library(
|
||||
tags = [TAG_INFRA_CORESERVICES],
|
||||
visibility = ["//cmd/enterprise-portal:__subpackages__"],
|
||||
deps = [
|
||||
"//cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
|
||||
"//internal/database/dbtest",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_jackc_pgx_v5//pgxpool",
|
||||
"@com_github_stretchr_testify//require",
|
||||
"@io_gorm_driver_postgres//:postgres",
|
||||
|
||||
@ -3,17 +3,20 @@ package databasetest
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
|
||||
)
|
||||
|
||||
@ -57,6 +60,11 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx
|
||||
for _, table := range tables {
|
||||
err = db.AutoMigrate(table)
|
||||
require.NoError(t, err)
|
||||
if m, ok := table.(custommigrator.CustomTableMigrator); ok {
|
||||
if err := m.RunCustomMigrations(db.Migrator()); err != nil {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Close the connection used to auto-migrate the database.
|
||||
@ -66,7 +74,12 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx
|
||||
require.NoError(t, err)
|
||||
|
||||
// Open a new connection to the test suite database.
|
||||
testDB, err := pgxpool.New(context.Background(), dsn.String())
|
||||
dbConfig, err := pgxpool.ParseConfig(dsn.String())
|
||||
require.NoError(t, err)
|
||||
if testing.Verbose() {
|
||||
dbConfig.ConnConfig.Tracer = pgxTestTracer{TB: t}
|
||||
}
|
||||
testDB, err := pgxpool.NewWithConfig(context.Background(), dbConfig)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
@ -110,3 +123,43 @@ func ClearTablesAfterTest(t *testing.T, db *pgxpool.Pool, tables ...schema.Table
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// pgxTestTracer implements various pgx tracing hooks for dumping diagnostics
|
||||
// in testing.
|
||||
type pgxTestTracer struct{ testing.TB }
|
||||
|
||||
// Select tracing hooks we want to implement.
|
||||
var (
|
||||
_ pgx.QueryTracer = pgxTestTracer{}
|
||||
)
|
||||
|
||||
func (t pgxTestTracer) TraceQueryStart(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryStartData) context.Context {
|
||||
var args []string
|
||||
if len(data.Args) > 0 {
|
||||
// Divider for readability
|
||||
args = append(args, "\n---")
|
||||
}
|
||||
for _, arg := range data.Args {
|
||||
data, err := json.MarshalIndent(arg, "", " ")
|
||||
if err != nil {
|
||||
args = append(args, fmt.Sprintf("marshal %T: %+v", arg, err))
|
||||
}
|
||||
args = append(args, string(data))
|
||||
}
|
||||
|
||||
t.Logf(`pgx.QueryStart db=%q
|
||||
%s%s`,
|
||||
conn.Config().Database,
|
||||
strings.TrimSpace(data.SQL),
|
||||
strings.Join(args, "\n"))
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (t pgxTestTracer) TraceQueryEnd(ctx context.Context, conn *pgx.Conn, data pgx.TraceQueryEndData) {
|
||||
if data.Err != nil {
|
||||
t.Logf(`pgx.QueryEnd db=%q tag=%q error=%q`,
|
||||
conn.Config().Database,
|
||||
data.CommandTag.String(),
|
||||
data.Err)
|
||||
}
|
||||
}
|
||||
|
||||
@ -0,0 +1,9 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "custommigrator",
|
||||
srcs = ["custommigrator.go"],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
|
||||
visibility = ["//cmd/enterprise-portal:__subpackages__"],
|
||||
deps = ["@io_gorm_gorm//:gorm"],
|
||||
)
|
||||
@ -0,0 +1,9 @@
|
||||
package custommigrator
|
||||
|
||||
import "gorm.io/gorm"
|
||||
|
||||
type CustomTableMigrator interface {
|
||||
// RunCustomMigrations is called after all other migrations have been run.
|
||||
// It can implement custom migrations.
|
||||
RunCustomMigrations(migrator gorm.Migrator) error
|
||||
}
|
||||
@ -12,9 +12,9 @@ import (
|
||||
// ⚠️ WARNING: This list is meant to be read-only.
|
||||
func All() []schema.Tabler {
|
||||
return []schema.Tabler{
|
||||
&subscriptions.Subscription{},
|
||||
&subscriptions.TableSubscription{},
|
||||
&subscriptions.SubscriptionCondition{},
|
||||
&subscriptions.SubscriptionLicense{},
|
||||
&subscriptions.TableSubscriptionLicense{},
|
||||
&subscriptions.SubscriptionLicenseCondition{},
|
||||
|
||||
&codyaccess.CodyGatewayAccess{},
|
||||
|
||||
@ -0,0 +1,12 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "utctime",
|
||||
srcs = ["utctime.go"],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime",
|
||||
visibility = ["//cmd/enterprise-portal:__subpackages__"],
|
||||
deps = [
|
||||
"//lib/errors",
|
||||
"//lib/pointers",
|
||||
],
|
||||
)
|
||||
@ -0,0 +1,80 @@
|
||||
package utctime
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
// Time is a wrapper around time.Time that implements the database/sql.Scanner
|
||||
// and database/sql/driver.Valuer interfaces to serialize and deserialize time
|
||||
// in UTC time zone.
|
||||
//
|
||||
// Time ensures that time.Time values are always:
|
||||
//
|
||||
// - represented in UTC for consistency
|
||||
// - rounded to microsecond precision
|
||||
//
|
||||
// We round the time because PostgreSQL times are represented in microseconds:
|
||||
// https://www.postgresql.org/docs/current/datatype-datetime.html
|
||||
type Time time.Time
|
||||
|
||||
// Now returns the current time in UTC.
|
||||
func Now() Time { return Time(time.Now()) }
|
||||
|
||||
// FromTime returns a utctime.Time from a time.Time.
|
||||
func FromTime(t time.Time) Time { return Time(t.UTC().Round(time.Microsecond)) }
|
||||
|
||||
var _ sql.Scanner = (*Time)(nil)
|
||||
|
||||
func (t *Time) Scan(src any) error {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
if v, ok := src.(time.Time); ok {
|
||||
*t = FromTime(v)
|
||||
return nil
|
||||
}
|
||||
return errors.Newf("value %T is not time.Time", src)
|
||||
}
|
||||
|
||||
var _ driver.Valuer = (*Time)(nil)
|
||||
|
||||
// Value must be called with a non-nil Time. driver.Valuer callers will first
|
||||
// check that the value is non-nil, so this is safe.
|
||||
func (t Time) Value() (driver.Value, error) {
|
||||
stdTime := t.GetTime()
|
||||
return *stdTime, nil
|
||||
}
|
||||
|
||||
var _ json.Marshaler = (*Time)(nil)
|
||||
|
||||
func (t Time) MarshalJSON() ([]byte, error) { return json.Marshal(t.GetTime()) }
|
||||
|
||||
var _ json.Unmarshaler = (*Time)(nil)
|
||||
|
||||
func (t *Time) UnmarshalJSON(data []byte) error {
|
||||
var stdTime time.Time
|
||||
if err := json.Unmarshal(data, &stdTime); err != nil {
|
||||
return err
|
||||
}
|
||||
*t = FromTime(stdTime)
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetTime returns the underlying time.GetTime value, or nil if it is nil.
|
||||
func (t *Time) GetTime() *time.Time {
|
||||
if t == nil {
|
||||
return nil
|
||||
}
|
||||
return pointers.Ptr(t.AsTime())
|
||||
}
|
||||
|
||||
// Time casts the Time as a standard time.Time value.
|
||||
func (t Time) AsTime() time.Time {
|
||||
return time.Time(t).UTC().Round(time.Microsecond)
|
||||
}
|
||||
@ -21,6 +21,7 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/lib/redislock"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator"
|
||||
)
|
||||
|
||||
// maybeMigrate runs the auto-migration for the database when needed based on
|
||||
@ -42,6 +43,12 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr
|
||||
}
|
||||
span.End()
|
||||
}()
|
||||
logger = logger.
|
||||
WithTrace(log.TraceContext{
|
||||
TraceID: span.SpanContext().TraceID().String(),
|
||||
SpanID: span.SpanContext().SpanID().String(),
|
||||
}).
|
||||
With(log.String("database", dbName))
|
||||
|
||||
sqlDB, err := contract.PostgreSQL.OpenDatabase(ctx, dbName)
|
||||
if err != nil {
|
||||
@ -83,17 +90,20 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr
|
||||
span.AddEvent("lock.acquired")
|
||||
|
||||
versionKey := fmt.Sprintf("%s:db_version", dbName)
|
||||
liveVersion := redisClient.Get(ctx, versionKey).Val()
|
||||
if shouldSkipMigration(
|
||||
redisClient.Get(ctx, versionKey).Val(),
|
||||
liveVersion,
|
||||
currentVersion,
|
||||
) {
|
||||
logger.Info("skipped auto-migration",
|
||||
log.String("database", dbName),
|
||||
log.String("currentVersion", currentVersion),
|
||||
)
|
||||
span.SetAttributes(attribute.Bool("skipped", true))
|
||||
return nil
|
||||
}
|
||||
logger.Info("executing auto-migration",
|
||||
log.String("liveVersion", liveVersion),
|
||||
log.String("currentVersion", currentVersion))
|
||||
span.SetAttributes(attribute.Bool("skipped", false))
|
||||
|
||||
// Create a session that ignore debug logging.
|
||||
@ -108,6 +118,11 @@ func maybeMigrate(ctx context.Context, logger log.Logger, contract runtime.Contr
|
||||
if err != nil {
|
||||
return errors.Wrapf(err, "auto migrating table for %s", errors.Safe(fmt.Sprintf("%T", table)))
|
||||
}
|
||||
if m, ok := table.(custommigrator.CustomTableMigrator); ok {
|
||||
if err := m.RunCustomMigrations(sess.Migrator()); err != nil {
|
||||
return errors.Wrapf(err, "running custom migrations for %s", errors.Safe(fmt.Sprintf("%T", table)))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return redisClient.Set(ctx, versionKey, currentVersion, 0).Err()
|
||||
|
||||
@ -14,27 +14,39 @@ go_library(
|
||||
visibility = ["//cmd/enterprise-portal:__subpackages__"],
|
||||
deps = [
|
||||
"//cmd/enterprise-portal/internal/database/internal/upsert",
|
||||
"//cmd/enterprise-portal/internal/database/internal/utctime",
|
||||
"//internal/license",
|
||||
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
|
||||
"//lib/errors",
|
||||
"//lib/pointers",
|
||||
"@com_github_jackc_pgtype//:pgtype",
|
||||
"@com_github_google_uuid//:uuid",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_jackc_pgx_v5//pgxpool",
|
||||
"@io_gorm_gorm//:gorm",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "subscriptions_test",
|
||||
srcs = ["subscriptions_test.go"],
|
||||
srcs = [
|
||||
"licenses_test.go",
|
||||
"subscriptions_test.go",
|
||||
],
|
||||
tags = [
|
||||
TAG_INFRA_CORESERVICES,
|
||||
"requires-network",
|
||||
],
|
||||
deps = [
|
||||
":subscriptions",
|
||||
"//cmd/enterprise-portal/internal/database",
|
||||
"//cmd/enterprise-portal/internal/database/databasetest",
|
||||
"//cmd/enterprise-portal/internal/database/internal/tables",
|
||||
"//cmd/enterprise-portal/internal/database/internal/utctime",
|
||||
"//internal/license",
|
||||
"//lib/pointers",
|
||||
"@com_github_google_uuid//:uuid",
|
||||
"@com_github_hexops_autogold_v2//:autogold",
|
||||
"@com_github_hexops_valast//:valast",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_stretchr_testify//assert",
|
||||
"@com_github_stretchr_testify//require",
|
||||
|
||||
@ -1,11 +1,17 @@
|
||||
package subscriptions
|
||||
|
||||
import "time"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
|
||||
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
type SubscriptionLicenseCondition struct {
|
||||
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint.
|
||||
License *SubscriptionLicense `gorm:"foreignKey:LicenseID"`
|
||||
|
||||
// SubscriptionID is the internal unprefixed UUID of the related license.
|
||||
LicenseID string `gorm:"type:uuid;not null"`
|
||||
// Status is the type of status corresponding to this condition, corresponding
|
||||
@ -15,9 +21,73 @@ type SubscriptionLicenseCondition struct {
|
||||
Message *string `gorm:"size:256"`
|
||||
// TransitionTime is the time at which the condition was created, i.e. when
|
||||
// the license transitioned into this status.
|
||||
TransitionTime time.Time `gorm:"not null;default:current_timestamp"`
|
||||
TransitionTime utctime.Time `gorm:"not null;default:current_timestamp"`
|
||||
}
|
||||
|
||||
func (s *SubscriptionLicenseCondition) TableName() string {
|
||||
func (*SubscriptionLicenseCondition) TableName() string {
|
||||
return "enterprise_portal_subscription_license_conditions"
|
||||
}
|
||||
|
||||
// subscriptionLicenseConditionJSONBAgg must be used with:
|
||||
//
|
||||
// LEFT JOIN
|
||||
// enterprise_portal_subscription_license_conditions license_condition
|
||||
// ON license_condition.license_id = id
|
||||
// GROUP BY
|
||||
// id
|
||||
//
|
||||
// The conditions are aggregated in JSON to 'conditions', which can be directly
|
||||
// unmarshaled into the 'SubscriptionLicenseCondition' type using 'pgx'.
|
||||
func subscriptionLicenseConditionJSONBAgg() string {
|
||||
return `
|
||||
jsonb_agg(
|
||||
jsonb_build_object(
|
||||
'Status', license_condition.status,
|
||||
'Message', license_condition.message,
|
||||
'TransitionTime', license_condition.transition_time
|
||||
)
|
||||
ORDER BY license_condition.transition_time DESC
|
||||
) AS conditions`
|
||||
}
|
||||
|
||||
type licenseConditionsStore struct{ tx pgx.Tx }
|
||||
|
||||
// newLicenseConditionsStore is meant to be used exclusively in the context of
|
||||
// a transaction, where the parent license is being updated at the same time.
|
||||
//
|
||||
// The caller owns the transaction lifecycle.
|
||||
func newLicenseConditionsStore(tx pgx.Tx) *licenseConditionsStore {
|
||||
return &licenseConditionsStore{tx: tx}
|
||||
}
|
||||
|
||||
type createLicenseConditionOpts struct {
|
||||
Status subscriptionsv1.EnterpriseSubscriptionLicenseCondition_Status
|
||||
Message string
|
||||
TransitionTime utctime.Time
|
||||
}
|
||||
|
||||
func (s *licenseConditionsStore) createLicenseCondition(ctx context.Context, licenseID string, opts createLicenseConditionOpts) error {
|
||||
if opts.TransitionTime.GetTime().IsZero() {
|
||||
return errors.New("transition time is required")
|
||||
}
|
||||
_, err := s.tx.Exec(ctx, `
|
||||
INSERT INTO enterprise_portal_subscription_license_conditions (
|
||||
license_id,
|
||||
status,
|
||||
message,
|
||||
transition_time
|
||||
)
|
||||
VALUES (
|
||||
@licenseID,
|
||||
@status,
|
||||
@message,
|
||||
@transitionTime
|
||||
)`, pgx.NamedArgs{
|
||||
"licenseID": licenseID,
|
||||
// Convert to string representation of EnterpriseSubscriptionLicenseCondition
|
||||
"status": subscriptionsv1.EnterpriseSubscriptionLicenseCondition_Status_name[int32(opts.Status)],
|
||||
"message": pointers.NilIfZero(opts.Message),
|
||||
"transitionTime": opts.TransitionTime,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
@ -1,15 +1,49 @@
|
||||
package subscriptions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"gorm.io/gorm"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
|
||||
internallicense "github.com/sourcegraph/sourcegraph/internal/license"
|
||||
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
type SubscriptionLicense struct {
|
||||
// ⚠️ DO NOT USE: This type is only used for creating foreign key constraints
|
||||
// and initializing tables with gorm.
|
||||
type TableSubscriptionLicense struct {
|
||||
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint.
|
||||
Subscription *Subscription `gorm:"foreignKey:SubscriptionID"`
|
||||
Conditions *[]SubscriptionLicenseCondition `gorm:"foreignKey:LicenseID"`
|
||||
|
||||
SubscriptionLicense
|
||||
}
|
||||
|
||||
func (*TableSubscriptionLicense) TableName() string {
|
||||
return "enterprise_portal_subscription_licenses"
|
||||
}
|
||||
|
||||
// Implement tables.CustomMigrator
|
||||
func (s *TableSubscriptionLicense) RunCustomMigrations(migrator gorm.Migrator) error {
|
||||
if migrator.HasColumn(s, "license_kind") {
|
||||
if err := migrator.DropColumn(s, "license_kind"); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type SubscriptionLicense struct {
|
||||
// SubscriptionID is the internal unprefixed UUID of the related subscription.
|
||||
SubscriptionID string `gorm:"type:uuid;not null"`
|
||||
// ID is the internal unprefixed UUID of this license.
|
||||
@ -19,21 +53,331 @@ type SubscriptionLicense struct {
|
||||
// to this subscription.
|
||||
//
|
||||
// Condition transition details are tracked in 'enterprise_portal_subscription_license_conditions'.
|
||||
CreatedAt time.Time `gorm:"not null;default:current_timestamp"`
|
||||
RevokedAt *time.Time // Null indicates the licnese is not revoked.
|
||||
CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"`
|
||||
RevokedAt *utctime.Time // Null indicates the license is not revoked.
|
||||
|
||||
// LicenseKind is the kind of license stored in LicenseData, corresponding
|
||||
// ExpireAt is the time at which the license should expire. Expiration does
|
||||
// NOT get a corresponding condition entry in 'enterprise_portal_subscription_license_conditions'.
|
||||
ExpireAt utctime.Time `gorm:"not null"`
|
||||
|
||||
// LicenseType is the kind of license stored in LicenseData, corresponding
|
||||
// to the API 'EnterpriseSubscriptionLicenseType'.
|
||||
LicenseKind string `gorm:"not null"`
|
||||
LicenseType string `gorm:"not null"`
|
||||
// LicenseData is the license data stored in JSON format. It is read-only
|
||||
// and generally never queried in conditions - properties that are should
|
||||
// be stored at the subscription or license level.
|
||||
//
|
||||
// Value shapes correspond to API types appropriate for each
|
||||
// 'EnterpriseSubscriptionLicenseType'.
|
||||
LicenseData pgtype.JSONB `gorm:"type:jsonb"`
|
||||
LicenseData json.RawMessage `gorm:"type:jsonb"`
|
||||
}
|
||||
|
||||
func (s *SubscriptionLicense) TableName() string {
|
||||
return "enterprise_portal_subscription_licenses"
|
||||
// subscriptionLicenseWithConditionsColumns must match scanSubscriptionLicense()
|
||||
// values.
|
||||
func subscriptionLicenseWithConditionsColumns() []string {
|
||||
return []string{
|
||||
"subscription_id",
|
||||
"id",
|
||||
|
||||
"created_at",
|
||||
"revoked_at",
|
||||
"expire_at",
|
||||
|
||||
"license_type",
|
||||
"license_data",
|
||||
|
||||
subscriptionLicenseConditionJSONBAgg(),
|
||||
}
|
||||
}
|
||||
|
||||
type LicenseWithConditions struct {
|
||||
SubscriptionLicense
|
||||
Conditions []SubscriptionLicenseCondition
|
||||
}
|
||||
|
||||
// scanSubscription matches subscriptionTableColumns() values.
|
||||
func scanSubscriptionLicenseWithConditions(row pgx.Row) (*LicenseWithConditions, error) {
|
||||
var l LicenseWithConditions
|
||||
err := row.Scan(
|
||||
&l.SubscriptionID,
|
||||
&l.ID,
|
||||
&l.CreatedAt,
|
||||
&l.RevokedAt,
|
||||
&l.ExpireAt,
|
||||
&l.LicenseType,
|
||||
&l.LicenseData,
|
||||
&l.Conditions, // see subscriptionLicenseConditionJSONBAgg docstring
|
||||
)
|
||||
return &l, err
|
||||
}
|
||||
|
||||
// LicensesStore manages licenses belonging to Enterprise subscriptions.
|
||||
//
|
||||
// Licenses can only be created and revoked - they can never be updated.
|
||||
type LicensesStore struct {
|
||||
db *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewLicensesStore(db *pgxpool.Pool) *LicensesStore {
|
||||
return &LicensesStore{
|
||||
db: db,
|
||||
}
|
||||
}
|
||||
|
||||
type ListLicensesOpts struct {
|
||||
SubscriptionID string
|
||||
// PageSize is the maximum number of licenses to return.
|
||||
PageSize int
|
||||
}
|
||||
|
||||
func (opts ListLicensesOpts) toQueryConditions() (where, limitClause string, _ pgx.NamedArgs) {
|
||||
whereConds := []string{"TRUE"}
|
||||
namedArgs := pgx.NamedArgs{}
|
||||
if opts.SubscriptionID != "" {
|
||||
whereConds = append(whereConds, "subscription_id = @subscriptionID")
|
||||
namedArgs["subscriptionID"] = opts.SubscriptionID
|
||||
}
|
||||
where = strings.Join(whereConds, " AND ")
|
||||
|
||||
if opts.PageSize > 0 {
|
||||
limitClause = "LIMIT @pageSize"
|
||||
namedArgs["pageSize"] = opts.PageSize
|
||||
}
|
||||
return where, limitClause, namedArgs
|
||||
}
|
||||
|
||||
func (s *LicensesStore) List(ctx context.Context, opts ListLicensesOpts) ([]*LicenseWithConditions, error) {
|
||||
where, limitClause, namedArgs := opts.toQueryConditions()
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
%s
|
||||
FROM
|
||||
enterprise_portal_subscription_licenses
|
||||
LEFT JOIN
|
||||
enterprise_portal_subscription_license_conditions license_condition
|
||||
ON license_condition.license_id = id
|
||||
WHERE
|
||||
%s
|
||||
GROUP BY
|
||||
id
|
||||
ORDER BY
|
||||
created_at DESC
|
||||
%s`,
|
||||
strings.Join(subscriptionLicenseWithConditionsColumns(), ", "),
|
||||
where, limitClause)
|
||||
|
||||
rows, err := s.db.Query(ctx, query, namedArgs)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query rows")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var licenses []*LicenseWithConditions
|
||||
for rows.Next() {
|
||||
license, err := scanSubscriptionLicenseWithConditions(rows)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "scan row")
|
||||
}
|
||||
licenses = append(licenses, license)
|
||||
}
|
||||
return licenses, rows.Err()
|
||||
}
|
||||
|
||||
func (s *LicensesStore) Get(ctx context.Context, licenseID string) (*LicenseWithConditions, error) {
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
%s
|
||||
FROM
|
||||
enterprise_portal_subscription_licenses
|
||||
LEFT JOIN
|
||||
enterprise_portal_subscription_license_conditions license_condition
|
||||
ON license_condition.license_id = id
|
||||
WHERE
|
||||
id = @licenseID
|
||||
GROUP BY
|
||||
id`,
|
||||
strings.Join(subscriptionLicenseWithConditionsColumns(), ", "))
|
||||
|
||||
license, err := scanSubscriptionLicenseWithConditions(
|
||||
s.db.QueryRow(ctx, query, pgx.NamedArgs{
|
||||
"licenseID": licenseID,
|
||||
}),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query rows")
|
||||
}
|
||||
return license, nil
|
||||
}
|
||||
|
||||
type CreateLicenseOpts struct {
|
||||
Message string
|
||||
// If nil, the creation time will be set to the current time.
|
||||
Time *utctime.Time
|
||||
// Expiration time of the license.
|
||||
ExpireTime utctime.Time
|
||||
}
|
||||
|
||||
// LicenseKey corresponds to *subscriptionsv1.EnterpriseSubscriptionLicenseKey
|
||||
// and the 'ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY' license type.
|
||||
type LicenseKey struct {
|
||||
Info internallicense.Info
|
||||
// Signed license key with the license information in Info.
|
||||
SignedKey string
|
||||
}
|
||||
|
||||
// CreateLicense creates a new classic offline license for the given subscription.
|
||||
func (s *LicensesStore) CreateLicenseKey(
|
||||
ctx context.Context,
|
||||
subscriptionID string,
|
||||
license *LicenseKey,
|
||||
opts CreateLicenseOpts,
|
||||
) (_ *LicenseWithConditions, err error) {
|
||||
// Special behaviour: the license key embeds the creation time, and it must
|
||||
// match the time provided in the options.
|
||||
if opts.Time == nil {
|
||||
return nil, errors.New("creation time must be specified for licensekeys")
|
||||
} else if !opts.Time.GetTime().Equal(utctime.FromTime(license.Info.CreatedAt).AsTime()) {
|
||||
return nil, errors.New("creation time must match the license key information")
|
||||
}
|
||||
if !opts.ExpireTime.GetTime().Equal(utctime.FromTime(license.Info.ExpiresAt).AsTime()) {
|
||||
return nil, errors.New("expiration time must match the license key information")
|
||||
}
|
||||
|
||||
return s.create(
|
||||
ctx,
|
||||
subscriptionID,
|
||||
subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY,
|
||||
license,
|
||||
opts,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *LicensesStore) create(
|
||||
ctx context.Context,
|
||||
subscriptionID string,
|
||||
licenseType subscriptionsv1.EnterpriseSubscriptionLicenseType,
|
||||
license any,
|
||||
opts CreateLicenseOpts,
|
||||
) (_ *LicenseWithConditions, err error) {
|
||||
if subscriptionID == "" {
|
||||
return nil, errors.New("subscription ID must be specified")
|
||||
}
|
||||
if opts.Time == nil {
|
||||
opts.Time = pointers.Ptr(utctime.Now())
|
||||
} else if opts.Time.GetTime().After(time.Now()) {
|
||||
return nil, errors.New("creation time cannot be in the future")
|
||||
}
|
||||
if licenseType == subscriptionsv1.EnterpriseSubscriptionLicenseType_ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_UNSPECIFIED {
|
||||
return nil, errors.New("license type must be specified")
|
||||
}
|
||||
|
||||
licenseID, err := uuid.NewV7()
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "generate uuid")
|
||||
}
|
||||
licenseData, err := json.Marshal(license)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "marshal license data")
|
||||
}
|
||||
tx, err := s.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "begin transaction")
|
||||
}
|
||||
defer func() {
|
||||
if rollbackErr := tx.Rollback(context.Background()); rollbackErr != nil {
|
||||
err = errors.Append(err, errors.Wrap(err, "rollback"))
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := tx.Exec(ctx, `
|
||||
INSERT INTO enterprise_portal_subscription_licenses (
|
||||
id,
|
||||
subscription_id,
|
||||
license_type,
|
||||
license_data,
|
||||
created_at,
|
||||
expire_at
|
||||
)
|
||||
VALUES (
|
||||
@licenseID,
|
||||
@subscriptionID,
|
||||
@licenseType,
|
||||
@licenseData,
|
||||
@createdAt,
|
||||
@expireAt
|
||||
)
|
||||
`, pgx.NamedArgs{
|
||||
"licenseID": licenseID.String(),
|
||||
"subscriptionID": subscriptionID,
|
||||
"licenseType": subscriptionsv1.EnterpriseSubscriptionLicenseType_name[int32(licenseType)],
|
||||
"licenseData": licenseData,
|
||||
"createdAt": opts.Time,
|
||||
"expireAt": opts.ExpireTime,
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "create license")
|
||||
}
|
||||
|
||||
if err := newLicenseConditionsStore(tx).createLicenseCondition(ctx, licenseID.String(), createLicenseConditionOpts{
|
||||
Status: subscriptionsv1.EnterpriseSubscriptionLicenseCondition_STATUS_CREATED,
|
||||
Message: opts.Message,
|
||||
TransitionTime: *opts.Time,
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "create license condition")
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return nil, errors.Wrap(err, "commit transaction")
|
||||
}
|
||||
|
||||
return s.Get(ctx, licenseID.String())
|
||||
}
|
||||
|
||||
type RevokeLicenseOpts struct {
|
||||
Message string
|
||||
// If nil, the revocation time will be set to the current time.
|
||||
Time *utctime.Time
|
||||
}
|
||||
|
||||
// Revoke marks the given license as revoked.
|
||||
func (s *LicensesStore) Revoke(ctx context.Context, licenseID string, opts RevokeLicenseOpts) (*LicenseWithConditions, error) {
|
||||
if opts.Time == nil {
|
||||
opts.Time = pointers.Ptr(utctime.Now())
|
||||
} else if opts.Time.GetTime().After(time.Now()) {
|
||||
return nil, errors.New("revocation time cannot be in the future")
|
||||
}
|
||||
|
||||
tx, err := s.db.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "begin transaction")
|
||||
}
|
||||
defer func() {
|
||||
if rollbackErr := tx.Rollback(context.Background()); rollbackErr != nil {
|
||||
err = errors.Append(err, rollbackErr)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := tx.Exec(ctx, `
|
||||
UPDATE enterprise_portal_subscription_licenses
|
||||
SET revoked_at = COALESCE(revoked_at, @revokedAt) -- use existing revoke time if already revoked
|
||||
WHERE id = @licenseID
|
||||
`, pgx.NamedArgs{
|
||||
"revokedAt": opts.Time,
|
||||
"licenseID": licenseID,
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "revoke license")
|
||||
}
|
||||
|
||||
if err := newLicenseConditionsStore(tx).createLicenseCondition(ctx, licenseID, createLicenseConditionOpts{
|
||||
Status: subscriptionsv1.EnterpriseSubscriptionLicenseCondition_STATUS_REVOKED,
|
||||
Message: opts.Message,
|
||||
TransitionTime: *opts.Time,
|
||||
}); err != nil {
|
||||
return nil, errors.Wrap(err, "create license condition")
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return nil, errors.Wrap(err, "commit transaction")
|
||||
}
|
||||
|
||||
return s.Get(ctx, licenseID)
|
||||
}
|
||||
|
||||
@ -0,0 +1,238 @@
|
||||
package subscriptions_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/hexops/autogold/v2"
|
||||
"github.com/hexops/valast"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
|
||||
"github.com/sourcegraph/sourcegraph/internal/license"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
func TestLicensesStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := databasetest.NewTestDB(t, "enterprise-portal", t.Name(), tables.All()...)
|
||||
|
||||
subscriptionID1 := uuid.NewString()
|
||||
subscriptionID2 := uuid.NewString()
|
||||
|
||||
subs := subscriptions.NewStore(db)
|
||||
_, err := subs.Upsert(ctx, subscriptionID1, subscriptions.UpsertSubscriptionOptions{
|
||||
DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 1")),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = subs.Upsert(ctx, subscriptionID2, subscriptions.UpsertSubscriptionOptions{
|
||||
DisplayName: pointers.Ptr(database.NewNullString("Acme, Inc. 2")),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
licenses := subscriptions.NewLicensesStore(db)
|
||||
|
||||
var createdLicenses []*subscriptions.LicenseWithConditions
|
||||
getCreatedByLicenseID := func(t *testing.T, licenseID string) *subscriptions.LicenseWithConditions {
|
||||
for _, l := range createdLicenses {
|
||||
if l.ID == licenseID {
|
||||
return l
|
||||
}
|
||||
}
|
||||
t.Errorf("license %q not found", licenseID)
|
||||
t.FailNow()
|
||||
return nil
|
||||
}
|
||||
t.Run("CreateLicenseKey", func(t *testing.T) {
|
||||
testLicense := func(
|
||||
got *subscriptions.LicenseWithConditions,
|
||||
wantMessage autogold.Value,
|
||||
wantLicenseData autogold.Value,
|
||||
) {
|
||||
assert.NotEmpty(t, got.ID)
|
||||
assert.NotZero(t, got.CreatedAt)
|
||||
assert.NotZero(t, got.ExpireAt)
|
||||
assert.Equal(t, "ENTERPRISE_SUBSCRIPTION_LICENSE_TYPE_KEY", got.LicenseType)
|
||||
wantLicenseData.Equal(t, string(got.LicenseData))
|
||||
|
||||
assert.Len(t, got.Conditions, 1)
|
||||
wantMessage.Equal(t, got.Conditions[0].Message)
|
||||
assert.Equal(t, "STATUS_CREATED", got.Conditions[0].Status)
|
||||
assert.Equal(t, got.CreatedAt, got.Conditions[0].TransitionTime)
|
||||
}
|
||||
|
||||
got, err := licenses.CreateLicenseKey(ctx, subscriptionID1,
|
||||
&subscriptions.LicenseKey{
|
||||
Info: license.Info{
|
||||
Tags: []string{"foo"},
|
||||
CreatedAt: time.Time{}.Add(1 * time.Hour),
|
||||
ExpiresAt: time.Time{}.Add(48 * time.Hour),
|
||||
},
|
||||
SignedKey: "asdfasdf",
|
||||
},
|
||||
subscriptions.CreateLicenseOpts{
|
||||
Message: t.Name() + " 1 old",
|
||||
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(1 * time.Hour))),
|
||||
ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testLicense(
|
||||
got,
|
||||
autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1 old")),
|
||||
autogold.Expect(`{"Info": {"c": "0001-01-01T01:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["foo"], "u": 0}, "SignedKey": "asdfasdf"}`),
|
||||
)
|
||||
createdLicenses = append(createdLicenses, got)
|
||||
|
||||
got, err = licenses.CreateLicenseKey(ctx, subscriptionID1,
|
||||
&subscriptions.LicenseKey{
|
||||
Info: license.Info{
|
||||
Tags: []string{"baz"},
|
||||
CreatedAt: time.Time{}.Add(24 * time.Hour),
|
||||
ExpiresAt: time.Time{}.Add(48 * time.Hour),
|
||||
},
|
||||
SignedKey: "barasdf",
|
||||
},
|
||||
subscriptions.CreateLicenseOpts{
|
||||
Message: t.Name() + " 1",
|
||||
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))),
|
||||
ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testLicense(
|
||||
got,
|
||||
autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 1")),
|
||||
autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["baz"], "u": 0}, "SignedKey": "barasdf"}`),
|
||||
)
|
||||
createdLicenses = append(createdLicenses, got)
|
||||
|
||||
got, err = licenses.CreateLicenseKey(ctx, subscriptionID2,
|
||||
&subscriptions.LicenseKey{
|
||||
Info: license.Info{
|
||||
Tags: []string{"tag"},
|
||||
CreatedAt: time.Time{}.Add(24 * time.Hour),
|
||||
ExpiresAt: time.Time{}.Add(48 * time.Hour),
|
||||
},
|
||||
SignedKey: "asdffdsadf",
|
||||
},
|
||||
subscriptions.CreateLicenseOpts{
|
||||
Message: t.Name() + " 2",
|
||||
Time: pointers.Ptr(utctime.FromTime(time.Time{}.Add(24 * time.Hour))),
|
||||
ExpireTime: utctime.FromTime(time.Time{}.Add(48 * time.Hour)),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
testLicense(
|
||||
got,
|
||||
autogold.Expect(valast.Ptr("TestLicensesStore/CreateLicenseKey 2")),
|
||||
autogold.Expect(`{"Info": {"c": "0001-01-02T00:00:00Z", "e": "0001-01-03T00:00:00Z", "t": ["tag"], "u": 0}, "SignedKey": "asdffdsadf"}`),
|
||||
)
|
||||
createdLicenses = append(createdLicenses, got)
|
||||
|
||||
t.Run("createdAt does not match", func(t *testing.T) {
|
||||
_, err = licenses.CreateLicenseKey(ctx, subscriptionID2,
|
||||
&subscriptions.LicenseKey{
|
||||
Info: license.Info{
|
||||
Tags: []string{"tag"},
|
||||
CreatedAt: time.Time{}.Add(24 * time.Hour),
|
||||
},
|
||||
SignedKey: "asdffdsadf",
|
||||
},
|
||||
subscriptions.CreateLicenseOpts{
|
||||
Message: t.Name(),
|
||||
Time: pointers.Ptr(utctime.Now()),
|
||||
})
|
||||
require.Error(t, err)
|
||||
autogold.Expect("creation time must match the license key information").Equal(t, err.Error())
|
||||
})
|
||||
t.Run("expiresAt does not match", func(t *testing.T) {
|
||||
_, err = licenses.CreateLicenseKey(ctx, subscriptionID2,
|
||||
&subscriptions.LicenseKey{
|
||||
Info: license.Info{
|
||||
Tags: []string{"tag"},
|
||||
CreatedAt: time.Time{},
|
||||
ExpiresAt: time.Time{}.Add(48 * time.Hour),
|
||||
},
|
||||
SignedKey: "asdffdsadf",
|
||||
},
|
||||
subscriptions.CreateLicenseOpts{
|
||||
Message: t.Name(),
|
||||
Time: pointers.Ptr(utctime.FromTime(time.Time{})),
|
||||
ExpireTime: utctime.Now(),
|
||||
})
|
||||
require.Error(t, err)
|
||||
autogold.Expect("expiration time must match the license key information").Equal(t, err.Error())
|
||||
})
|
||||
})
|
||||
|
||||
// No point continuing if test licenses did not create, all tests after this
|
||||
// will fail
|
||||
if t.Failed() {
|
||||
t.FailNow()
|
||||
}
|
||||
|
||||
t.Run("List", func(t *testing.T) {
|
||||
listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, listedLicenses, len(createdLicenses))
|
||||
for _, l := range listedLicenses {
|
||||
created := getCreatedByLicenseID(t, l.ID)
|
||||
assert.Equal(t, *created, *l)
|
||||
}
|
||||
|
||||
t.Run("List by subscription", func(t *testing.T) {
|
||||
listedLicenses, err := licenses.List(ctx, subscriptions.ListLicensesOpts{
|
||||
SubscriptionID: subscriptionID1,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, listedLicenses, 2)
|
||||
for _, l := range listedLicenses {
|
||||
assert.Equal(t, subscriptionID1, l.SubscriptionID)
|
||||
assert.Equal(t, *getCreatedByLicenseID(t, l.ID), *l)
|
||||
}
|
||||
|
||||
listedLicenses, err = licenses.List(ctx, subscriptions.ListLicensesOpts{
|
||||
SubscriptionID: subscriptionID2,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, listedLicenses, 1)
|
||||
for _, l := range listedLicenses {
|
||||
assert.Equal(t, subscriptionID2, l.SubscriptionID)
|
||||
assert.Equal(t, *getCreatedByLicenseID(t, l.ID), *l)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
for _, license := range createdLicenses {
|
||||
got, err := licenses.Get(ctx, license.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, *license, *got)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Revoke", func(t *testing.T) {
|
||||
for idx, license := range createdLicenses {
|
||||
revokeTime := utctime.FromTime(time.Now().Add(-time.Second))
|
||||
got, err := licenses.Revoke(ctx, license.ID, subscriptions.RevokeLicenseOpts{
|
||||
Message: fmt.Sprintf("%s %d", t.Name(), idx),
|
||||
Time: pointers.Ptr(revokeTime),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, revokeTime.AsTime(), got.RevokedAt.AsTime())
|
||||
require.Len(t, got.Conditions, 2)
|
||||
// Most recent condition is sorted first, and should be the revocation
|
||||
assert.Equal(t, "STATUS_REVOKED", got.Conditions[0].Status)
|
||||
assert.Equal(t, revokeTime.AsTime(), got.Conditions[0].TransitionTime.AsTime())
|
||||
assert.Equal(t, "STATUS_CREATED", got.Conditions[1].Status)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -2,6 +2,7 @@ package subscriptions
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
@ -10,10 +11,26 @@ import (
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
// ⚠️ DO NOT USE: This type is only used for creating foreign key constraints
|
||||
// and initializing tables with gorm.
|
||||
type TableSubscription struct {
|
||||
// Each Subscription has many Licenses.
|
||||
Licenses []*TableSubscriptionLicense `gorm:"foreignKey:SubscriptionID"`
|
||||
|
||||
// Each Subscription has many Conditions.
|
||||
Conditions *[]SubscriptionCondition `gorm:"foreignKey:SubscriptionID"`
|
||||
|
||||
Subscription
|
||||
}
|
||||
|
||||
func (*TableSubscription) TableName() string {
|
||||
return "enterprise_portal_subscriptions"
|
||||
}
|
||||
|
||||
// Subscription is an Enterprise subscription record.
|
||||
type Subscription struct {
|
||||
// ID is the internal (unprefixed) UUID-format identifier for the subscription.
|
||||
@ -22,7 +39,7 @@ type Subscription struct {
|
||||
// "acme.sourcegraphcloud.com". This is set explicitly.
|
||||
//
|
||||
// It must be unique across all currently un-archived subscriptions.
|
||||
InstanceDomain string `gorm:"uniqueIndex:,where:archived_at IS NULL"`
|
||||
InstanceDomain *string `gorm:"uniqueIndex:,where:archived_at IS NULL"`
|
||||
|
||||
// WARNING: The below fields are not yet used in production.
|
||||
|
||||
@ -33,15 +50,15 @@ type Subscription struct {
|
||||
//
|
||||
// TODO: Clean up the database post-deploy and remove the 'Unnamed subscription'
|
||||
// part of the constraint.
|
||||
DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"`
|
||||
DisplayName *string `gorm:"size:256;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"`
|
||||
|
||||
// Timestamps representing the latest timestamps of key conditions related
|
||||
// to this subscription.
|
||||
//
|
||||
// Condition transition details are tracked in 'enterprise_portal_subscription_conditions'.
|
||||
CreatedAt time.Time `gorm:"not null;default:current_timestamp"`
|
||||
UpdatedAt time.Time `gorm:"not null;default:current_timestamp"`
|
||||
ArchivedAt *time.Time // Null indicates the subscription is not archived.
|
||||
CreatedAt utctime.Time `gorm:"not null;default:current_timestamp"`
|
||||
UpdatedAt utctime.Time `gorm:"not null;default:current_timestamp"`
|
||||
ArchivedAt *utctime.Time // Null indicates the subscription is not archived.
|
||||
|
||||
// SalesforceSubscriptionID associated with this Enterprise subscription.
|
||||
SalesforceSubscriptionID *string
|
||||
@ -49,11 +66,7 @@ type Subscription struct {
|
||||
SalesforceOpportunityID *string
|
||||
}
|
||||
|
||||
func (s Subscription) TableName() string {
|
||||
return "enterprise_portal_subscriptions"
|
||||
}
|
||||
|
||||
// subscriptionTableColumns must match s.scan() values.
|
||||
// subscriptionTableColumns must match scanSubscription() values.
|
||||
func subscriptionTableColumns() []string {
|
||||
return []string{
|
||||
"id",
|
||||
@ -67,7 +80,7 @@ func subscriptionTableColumns() []string {
|
||||
}
|
||||
}
|
||||
|
||||
// scanSubscription matches s.columns() values.
|
||||
// scanSubscription matches subscriptionTableColumns() values.
|
||||
func scanSubscription(row pgx.Row) (*Subscription, error) {
|
||||
var s Subscription
|
||||
err := row.Scan(
|
||||
@ -83,13 +96,6 @@ func scanSubscription(row pgx.Row) (*Subscription, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.CreatedAt = s.CreatedAt.UTC()
|
||||
s.UpdatedAt = s.UpdatedAt.UTC()
|
||||
if s.ArchivedAt != nil {
|
||||
s.ArchivedAt = pointers.Ptr(s.ArchivedAt.UTC())
|
||||
}
|
||||
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
@ -172,8 +178,8 @@ WHERE %s
|
||||
}
|
||||
|
||||
type UpsertSubscriptionOptions struct {
|
||||
InstanceDomain string
|
||||
DisplayName string
|
||||
InstanceDomain *sql.NullString
|
||||
DisplayName *sql.NullString
|
||||
|
||||
CreatedAt time.Time
|
||||
ArchivedAt *time.Time
|
||||
|
||||
@ -1,14 +1,9 @@
|
||||
package subscriptions
|
||||
|
||||
import (
|
||||
"time"
|
||||
)
|
||||
import "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/utctime"
|
||||
|
||||
// Subscription is an Enterprise subscription condition record.
|
||||
type SubscriptionCondition struct {
|
||||
// ⚠️ DO NOT USE: This field is only used for creating foreign key constraint.
|
||||
Subscription *Subscription `gorm:"foreignKey:SubscriptionID"`
|
||||
|
||||
// SubscriptionID is the internal unprefixed UUID of the related subscription.
|
||||
SubscriptionID string `gorm:"type:uuid;not null"`
|
||||
// Status is the type of status corresponding to this condition, corresponding
|
||||
@ -18,7 +13,7 @@ type SubscriptionCondition struct {
|
||||
Message *string `gorm:"size:256"`
|
||||
// TransitionTime is the time at which the condition was created, i.e. when
|
||||
// the subscription transitioned into this status.
|
||||
TransitionTime time.Time `gorm:"not null;default:current_timestamp"`
|
||||
TransitionTime utctime.Time `gorm:"not null;default:current_timestamp"`
|
||||
}
|
||||
|
||||
func (s *SubscriptionCondition) TableName() string {
|
||||
|
||||
@ -10,6 +10,7 @@ import (
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
|
||||
@ -20,7 +21,7 @@ func TestSubscriptionsStore(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
db := databasetest.NewTestDB(t, "enterprise-portal", "SubscriptionsStore", tables.All()...)
|
||||
db := databasetest.NewTestDB(t, "enterprise-portal", t.Name(), tables.All()...)
|
||||
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
@ -45,19 +46,25 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
|
||||
s1, err := s.Upsert(
|
||||
ctx,
|
||||
uuid.New().String(),
|
||||
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"},
|
||||
subscriptions.UpsertSubscriptionOptions{
|
||||
InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
s2, err := s.Upsert(
|
||||
ctx,
|
||||
uuid.New().String(),
|
||||
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s2.sourcegraph.com"},
|
||||
subscriptions.UpsertSubscriptionOptions{
|
||||
InstanceDomain: pointers.Ptr(database.NewNullString("s2.sourcegraph.com")),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
_, err = s.Upsert(
|
||||
ctx,
|
||||
uuid.New().String(),
|
||||
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s3.sourcegraph.com"},
|
||||
subscriptions.UpsertSubscriptionOptions{
|
||||
InstanceDomain: pointers.Ptr(database.NewNullString("s3.sourcegraph.com")),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
@ -79,7 +86,7 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
|
||||
|
||||
t.Run("list by instance domains", func(t *testing.T) {
|
||||
ss, err := s.List(ctx, subscriptions.ListEnterpriseSubscriptionsOptions{
|
||||
InstanceDomains: []string{s1.InstanceDomain, s2.InstanceDomain}},
|
||||
InstanceDomains: []string{*s1.InstanceDomain, *s2.InstanceDomain}},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, ss, 2)
|
||||
@ -115,14 +122,16 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
|
||||
currentSubscription, err := s.Upsert(
|
||||
ctx,
|
||||
uuid.New().String(),
|
||||
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"},
|
||||
subscriptions.UpsertSubscriptionOptions{
|
||||
InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.Get(ctx, currentSubscription.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, currentSubscription.ID, got.ID)
|
||||
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
|
||||
assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain)
|
||||
assert.Empty(t, got.DisplayName)
|
||||
assert.NotZero(t, got.CreatedAt)
|
||||
assert.NotZero(t, got.UpdatedAt)
|
||||
@ -133,17 +142,19 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
|
||||
|
||||
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
|
||||
assert.Equal(t,
|
||||
pointers.DerefZero(currentSubscription.InstanceDomain),
|
||||
pointers.DerefZero(got.InstanceDomain))
|
||||
})
|
||||
|
||||
t.Run("update only domain", func(t *testing.T) {
|
||||
t.Cleanup(func() { currentSubscription = got })
|
||||
|
||||
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
|
||||
InstanceDomain: "s1-new.sourcegraph.com",
|
||||
InstanceDomain: pointers.Ptr(database.NewNullString("s1-new.sourcegraph.com")),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "s1-new.sourcegraph.com", got.InstanceDomain)
|
||||
assert.Equal(t, "s1-new.sourcegraph.com", pointers.DerefZero(got.InstanceDomain))
|
||||
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
|
||||
})
|
||||
|
||||
@ -151,11 +162,11 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
|
||||
t.Cleanup(func() { currentSubscription = got })
|
||||
|
||||
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
|
||||
DisplayName: "My New Display Name",
|
||||
DisplayName: pointers.Ptr(database.NewNullString("My New Display Name")),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
|
||||
assert.Equal(t, "My New Display Name", got.DisplayName)
|
||||
assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain)
|
||||
assert.Equal(t, "My New Display Name", pointers.DerefZero(got.DisplayName))
|
||||
})
|
||||
|
||||
t.Run("update only created at", func(t *testing.T) {
|
||||
@ -166,10 +177,12 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
|
||||
CreatedAt: yesterday,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
|
||||
assert.Equal(t,
|
||||
pointers.DerefZero(currentSubscription.InstanceDomain),
|
||||
pointers.DerefZero(got.InstanceDomain))
|
||||
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
|
||||
// Round times to allow for some precision drift in CI
|
||||
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Round(time.Second))
|
||||
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.GetTime().Round(time.Second))
|
||||
})
|
||||
|
||||
t.Run("update only archived at", func(t *testing.T) {
|
||||
@ -180,11 +193,11 @@ func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscription
|
||||
ArchivedAt: pointers.Ptr(yesterday),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
|
||||
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
|
||||
assert.Equal(t, *currentSubscription.InstanceDomain, *got.InstanceDomain)
|
||||
assert.Equal(t, *currentSubscription.DisplayName, *got.DisplayName)
|
||||
assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt)
|
||||
// Round times to allow for some precision drift in CI
|
||||
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Round(time.Second))
|
||||
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.GetTime().Round(time.Second))
|
||||
})
|
||||
|
||||
t.Run("force update to zero values", func(t *testing.T) {
|
||||
@ -209,7 +222,9 @@ func SubscriptionsStoreGet(t *testing.T, ctx context.Context, s *subscriptions.S
|
||||
s1, err := s.Upsert(
|
||||
ctx,
|
||||
uuid.New().String(),
|
||||
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"},
|
||||
subscriptions.UpsertSubscriptionOptions{
|
||||
InstanceDomain: pointers.Ptr(database.NewNullString("s1.sourcegraph.com")),
|
||||
},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
10
cmd/enterprise-portal/internal/database/types.go
Normal file
10
cmd/enterprise-portal/internal/database/types.go
Normal file
@ -0,0 +1,10 @@
|
||||
package database
|
||||
|
||||
import "database/sql"
|
||||
|
||||
func NewNullString(v string) sql.NullString {
|
||||
return sql.NullString{
|
||||
String: v,
|
||||
Valid: v != "",
|
||||
}
|
||||
}
|
||||
@ -68,8 +68,8 @@ func convertSubscriptionToProto(subscription *subscriptions.Subscription, attrs
|
||||
return &subscriptionsv1.EnterpriseSubscription{
|
||||
Id: subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.ID,
|
||||
Conditions: conds,
|
||||
InstanceDomain: subscription.InstanceDomain,
|
||||
DisplayName: subscription.DisplayName,
|
||||
InstanceDomain: pointers.DerefZero(subscription.InstanceDomain),
|
||||
DisplayName: pointers.DerefZero(subscription.DisplayName),
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -16,8 +16,10 @@ import (
|
||||
subscriptionsv1connect "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1/v1connect"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/managedservicesplatform/iam"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/connectutil"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/samsm2m"
|
||||
@ -328,31 +330,41 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne
|
||||
// Empty field paths means update all non-empty fields.
|
||||
if len(fieldPaths) == 0 {
|
||||
if v := req.Msg.GetSubscription().GetInstanceDomain(); v != "" {
|
||||
opts.InstanceDomain = v
|
||||
opts.InstanceDomain = pointers.Ptr(database.NewNullString(v))
|
||||
}
|
||||
if v := req.Msg.GetSubscription().GetDisplayName(); v != "" {
|
||||
opts.DisplayName = v
|
||||
opts.DisplayName = pointers.Ptr(database.NewNullString(v))
|
||||
}
|
||||
} else {
|
||||
for _, p := range fieldPaths {
|
||||
switch p {
|
||||
case "instance_domain":
|
||||
opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain()
|
||||
opts.InstanceDomain = pointers.Ptr(
|
||||
database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()),
|
||||
)
|
||||
case "display_name":
|
||||
opts.DisplayName = req.Msg.GetSubscription().GetDisplayName()
|
||||
opts.DisplayName = pointers.Ptr(
|
||||
database.NewNullString(req.Msg.GetSubscription().GetDisplayName()),
|
||||
)
|
||||
case "*":
|
||||
opts.ForceUpdate = true
|
||||
opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain()
|
||||
opts.InstanceDomain = pointers.Ptr(
|
||||
database.NewNullString(req.Msg.GetSubscription().GetInstanceDomain()),
|
||||
)
|
||||
opts.DisplayName = pointers.Ptr(
|
||||
database.NewNullString(req.Msg.GetSubscription().GetDisplayName()),
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Validate and normalize the domain
|
||||
if opts.InstanceDomain != "" {
|
||||
opts.InstanceDomain, err = subscriptionsv1.NormalizeInstanceDomain(opts.InstanceDomain)
|
||||
if opts.InstanceDomain != nil && opts.InstanceDomain.Valid {
|
||||
normalizedDomain, err := subscriptionsv1.NormalizeInstanceDomain(opts.InstanceDomain.String)
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, errors.Wrap(err, "invalid instance domain"))
|
||||
}
|
||||
opts.InstanceDomain.String = normalizedDomain
|
||||
}
|
||||
|
||||
subscription, err := s.store.UpsertEnterpriseSubscription(ctx, subscriptionID, opts)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user