diff --git a/cmd/enterprise-portal/internal/database/internal/upsert/BUILD.bazel b/cmd/enterprise-portal/internal/database/internal/upsert/BUILD.bazel new file mode 100644 index 00000000000..458916402c0 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/upsert/BUILD.bazel @@ -0,0 +1,26 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") +load("//dev:go_defs.bzl", "go_test") + +go_library( + name = "upsert", + srcs = ["upsert.go"], + importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert", + visibility = ["//cmd/enterprise-portal:__subpackages__"], + deps = [ + "@com_github_jackc_pgx_v5//:pgx", + "@com_github_jackc_pgx_v5//pgxpool", + ], +) + +go_test( + name = "upsert_test", + srcs = ["upsert_test.go"], + embed = [":upsert"], + deps = [ + "//lib/pointers", + "@com_github_hexops_autogold_v2//:autogold", + "@com_github_hexops_valast//:valast", + "@com_github_jackc_pgx_v5//:pgx", + "@com_github_stretchr_testify//assert", + ], +) diff --git a/cmd/enterprise-portal/internal/database/internal/upsert/upsert.go b/cmd/enterprise-portal/internal/database/internal/upsert/upsert.go new file mode 100644 index 00000000000..91875248a70 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/upsert/upsert.go @@ -0,0 +1,138 @@ +package upsert + +import ( + "context" + "fmt" + "strings" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" +) + +// Trick to avoid user-provided string values - outside this package, this type +// can only be fulfilled by a constant string value. +type constString string + +type Builder struct { + table constString + primaryKey constString + + insertColumns []string + args pgx.NamedArgs + updateColumns []string + + forceUpdate bool +} + +// New instantiates an upsert.Builder that can be used with `upsert.Field(b, ...)` +// to implement the database layer for the upsert pattern common in gRPC 'update', +// methods, per the AIP: https://google.aip.dev/134 +func New(table, primaryKey constString, forceUpdate bool) *Builder { + return &Builder{ + table: table, + primaryKey: primaryKey, + forceUpdate: forceUpdate, + args: pgx.NamedArgs{}, + } +} + +type fieldOptions struct { + useColumnDefault bool + ignoreOnForceUpdate bool +} + +type fieldOptionFn func(*fieldOptions) + +func (fn fieldOptionFn) apply(opt *fieldOptions) { fn(opt) } + +type FieldOption interface { + apply(*fieldOptions) +} + +// WithColumnDefault indicates that the field should not be included in an upsert +// if the field has a zero value, which allows the column default to be used. +// +// It does NOT apply in a force update. +func WithColumnDefault() FieldOption { + return fieldOptionFn(func(opt *fieldOptions) { opt.useColumnDefault = true }) +} + +// WithIgnoreOnForceUpdate indicates that the field should not be updated when +// performing a force update. +func WithIgnoreOnForceUpdate() FieldOption { + return fieldOptionFn(func(opt *fieldOptions) { opt.ignoreOnForceUpdate = true }) +} + +// Field registers a field that can be set in the upsert to value T. If T is +// a zero value, the field is not set on an update, UNLESS the `forceUpdate` +// parameter was provided as `true` to upsert.New(...). +func Field[T comparable](b *Builder, column constString, value T, opts ...FieldOption) { + opt := fieldOptions{} + for _, o := range opts { + o.apply(&opt) + } + var zero T + + // If upsert has a zero value, and we would prefer to use the column default, + // do nothing, unless we are performing a force-update across all fields. + if !b.forceUpdate && (zero == value && opt.useColumnDefault) { + return + } + + // If we are force-updating, and the field is marked to be ignored, do nothing. + if b.forceUpdate && opt.ignoreOnForceUpdate { + return + } + + b.insertColumns = append(b.insertColumns, string(column)) + b.args[string(column)] = value + + // If we are force-updating, or value is not zero, update the column in + // existing rows (on conflict). + if b.forceUpdate || value != zero { + b.updateColumns = append(b.updateColumns, string(column)) + } +} + +func (b *Builder) buildQuery() (string, bool) { + if len(b.updateColumns) == 0 { + return "", false + } + + onConflictSets := make([]string, len(b.updateColumns)) + for i, c := range b.updateColumns { + onConflictSets[i] = fmt.Sprintf("%[1]s = EXCLUDED.%[1]s", c) + } + + insertArgNames := make([]string, len(b.insertColumns)) + for i, c := range b.insertColumns { + insertArgNames[i] = fmt.Sprintf("@%s", c) + } + + return fmt.Sprintf(` +INSERT INTO %[1]s + (%[2]s) +VALUES + (%[3]s) +ON CONFLICT + (%[4]s) +DO UPDATE SET + %[5]s`, + b.table, // %[1]s + strings.Join(b.insertColumns, ", "), // %[2]s + strings.Join(insertArgNames, ", "), // %[3]s + b.primaryKey, // %[4]s + strings.Join(onConflictSets, ",\n"), // %[5]s + ), true +} + +func (b *Builder) Exec(ctx context.Context, db *pgxpool.Pool) error { + q, ok := b.buildQuery() + if !ok { + return nil + } + if _, err := db.Exec(ctx, q, b.args); err != nil { + return err + } + return nil +} diff --git a/cmd/enterprise-portal/internal/database/internal/upsert/upsert_test.go b/cmd/enterprise-portal/internal/database/internal/upsert/upsert_test.go new file mode 100644 index 00000000000..f1bc8150281 --- /dev/null +++ b/cmd/enterprise-portal/internal/database/internal/upsert/upsert_test.go @@ -0,0 +1,113 @@ +package upsert + +import ( + "testing" + "time" + + "github.com/hexops/autogold/v2" + "github.com/hexops/valast" + "github.com/jackc/pgx/v5" + "github.com/stretchr/testify/assert" + + "github.com/sourcegraph/sourcegraph/lib/pointers" +) + +func TestBuilder(t *testing.T) { + mockTime := time.Date(2024, 7, 8, 16, 39, 16, 4277000, time.Local) + for _, tc := range []struct { + name string + + forceUpdate bool + upsertFields func(b *Builder) + + wantQuery autogold.Value + wantArgs autogold.Value + }{ + { + name: "not force-update", + forceUpdate: false, + upsertFields: func(b *Builder) { + Field[*string](b, "col1", nil) + + // WithIgnoreOnForceUpdate() does nothing because the we are not + // in a force-update. + Field(b, "col2", pointers.Ptr("value2"), WithIgnoreOnForceUpdate()) + + // WithColumnDefault() does nothing because the time is not zero + Field(b, "time", mockTime, WithColumnDefault()) + + // Do not set, it should use the default value. + Field(b, "should_be_ignored", "", WithColumnDefault()) + }, + wantQuery: autogold.Expect(` +INSERT INTO table +(col1, col2, time) +VALUES +(@col1, @col2, @time) +ON CONFLICT +(id) +DO UPDATE SET +col2 = EXCLUDED.col2, +time = EXCLUDED.time`), + wantArgs: autogold.Expect(pgx.NamedArgs{ + "col1": nil, "col2": valast.Ptr("value2"), + "time": time.Date(2024, + 7, + 8, + 16, + 39, + 16, + 4277000, + time.Local), + }), + }, + { + name: "force-update", + forceUpdate: true, + upsertFields: func(b *Builder) { + Field[*string](b, "col1", nil) + Field(b, "col2", pointers.Ptr("value2")) + Field(b, "time", mockTime) + + // Do not set, it cannot be updated in a force-update. + Field(b, "should_be_ignored", "", WithIgnoreOnForceUpdate()) + }, + wantQuery: autogold.Expect(` +INSERT INTO table +(col1, col2, time) +VALUES +(@col1, @col2, @time) +ON CONFLICT +(id) +DO UPDATE SET +col1 = EXCLUDED.col1, +col2 = EXCLUDED.col2, +time = EXCLUDED.time`), + wantArgs: autogold.Expect(pgx.NamedArgs{ + "col1": nil, "col2": valast.Ptr("value2"), + "time": time.Date(2024, + 7, + 8, + 16, + 39, + 16, + 4277000, + time.Local), + }), + }, + } { + t.Run(tc.name, func(t *testing.T) { + b := New("table", "id", tc.forceUpdate) + tc.upsertFields(b) + + q, ok := b.buildQuery() + if tc.wantQuery == nil && tc.wantArgs == nil { + assert.False(t, ok) + } else { + assert.True(t, ok) + tc.wantQuery.Equal(t, q) + tc.wantArgs.Equal(t, b.args) + } + }) + } +} diff --git a/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel b/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel index 7a26a644202..d73ea1939da 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/subscriptions/BUILD.bazel @@ -13,7 +13,9 @@ go_library( tags = [TAG_INFRA_CORESERVICES], visibility = ["//cmd/enterprise-portal:__subpackages__"], deps = [ + "//cmd/enterprise-portal/internal/database/internal/upsert", "//lib/errors", + "//lib/pointers", "@com_github_jackc_pgtype//:pgtype", "@com_github_jackc_pgx_v5//:pgx", "@com_github_jackc_pgx_v5//pgxpool", @@ -31,6 +33,7 @@ go_test( ":subscriptions", "//cmd/enterprise-portal/internal/database/databasetest", "//cmd/enterprise-portal/internal/database/internal/tables", + "//lib/pointers", "@com_github_google_uuid//:uuid", "@com_github_jackc_pgx_v5//:pgx", "@com_github_stretchr_testify//assert", diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go index 60666e2c694..94e8d8e0297 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions.go @@ -9,7 +9,9 @@ import ( "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert" "github.com/sourcegraph/sourcegraph/lib/errors" + "github.com/sourcegraph/sourcegraph/lib/pointers" ) // Subscription is an Enterprise subscription record. @@ -26,8 +28,12 @@ type Subscription struct { // DisplayName is the human-friendly name of this subscription, e.g. "Acme, Inc." // - // It must be unique across all currently un-archived subscriptions. - DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription';default:'Unnamed subscription'"` + // It must be unique across all currently un-archived subscriptions, unless + // it is not set. + // + // TODO: Clean up the database post-deploy and remove the 'Unnamed subscription' + // part of the constraint. + DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"` // Timestamps representing the latest timestamps of key conditions related // to this subscription. @@ -43,10 +49,50 @@ type Subscription struct { SalesforceOpportunityID *string } -func (s *Subscription) TableName() string { +func (s Subscription) TableName() string { return "enterprise_portal_subscriptions" } +// subscriptionTableColumns must match s.scan() values. +func subscriptionTableColumns() []string { + return []string{ + "id", + "instance_domain", + "display_name", + "created_at", + "updated_at", + "archived_at", + "salesforce_subscription_id", + "salesforce_opportunity_id", + } +} + +// scanSubscription matches s.columns() values. +func scanSubscription(row pgx.Row) (*Subscription, error) { + var s Subscription + err := row.Scan( + &s.ID, + &s.InstanceDomain, + &s.DisplayName, + &s.CreatedAt, + &s.UpdatedAt, + &s.ArchivedAt, + &s.SalesforceSubscriptionID, + &s.SalesforceOpportunityID, + ) + if err != nil { + return nil, err + } + + s.CreatedAt = s.CreatedAt.UTC() + s.UpdatedAt = s.UpdatedAt.UTC() + if s.ArchivedAt != nil { + s.ArchivedAt = pointers.Ptr(s.ArchivedAt.UTC()) + } + + return &s, nil +} + // Store is the storage layer for product subscriptions. type Store struct { db *pgxpool.Pool @@ -101,11 +147,11 @@ func (s *Store) List(ctx context.Context, opts ListEnterpriseSubscriptionsOption where, limit, namedArgs := opts.toQueryConditions() query := fmt.Sprintf(` SELECT - id, - instance_domain + %s FROM enterprise_portal_subscriptions WHERE %s %s`, + strings.Join(subscriptionTableColumns(), ", "), where, limit, ) rows, err := s.db.Query(ctx, query, namedArgs) @@ -116,17 +162,25 @@ WHERE %s var subscriptions []*Subscription for rows.Next() { - var subscription Subscription - if err = rows.Scan(&subscription.ID, &subscription.InstanceDomain); err != nil { + sub, err := scanSubscription(rows) + if err != nil { return nil, errors.Wrap(err, "scan row") } - subscriptions = append(subscriptions, &subscription) + subscriptions = append(subscriptions, sub) } return subscriptions, rows.Err() } type UpsertSubscriptionOptions struct { InstanceDomain string + DisplayName string + + CreatedAt time.Time + ArchivedAt *time.Time + + SalesforceSubscriptionID *string + SalesforceOpportunityID *string + // ForceUpdate indicates whether to force update all fields of the subscription // record. ForceUpdate bool @@ -134,40 +188,26 @@ type UpsertSubscriptionOptions struct { // toQuery returns the query based on the options. It returns an empty query if // nothing to update. -func (opts UpsertSubscriptionOptions) toQuery(id string) (query string, _ pgx.NamedArgs) { - const queryFmt = ` -INSERT INTO enterprise_portal_subscriptions (id, instance_domain) -VALUES (@id, @instanceDomain) -ON CONFLICT (id) -DO UPDATE SET - %s` - namedArgs := pgx.NamedArgs{ - "id": id, - "instanceDomain": opts.InstanceDomain, - } +func (opts UpsertSubscriptionOptions) Exec(ctx context.Context, db *pgxpool.Pool, id string) error { + b := upsert.New("enterprise_portal_subscriptions", "id", opts.ForceUpdate) + upsert.Field(b, "id", id) + upsert.Field(b, "instance_domain", opts.InstanceDomain) + upsert.Field(b, "display_name", opts.DisplayName) - var sets []string - if opts.ForceUpdate || opts.InstanceDomain != "" { - sets = append(sets, "instance_domain = excluded.instance_domain") - } - if len(sets) == 0 { - return "", nil - } - query = fmt.Sprintf( - queryFmt, - strings.Join(sets, ", "), - ) - return query, namedArgs + upsert.Field(b, "created_at", opts.CreatedAt, + upsert.WithColumnDefault(), + upsert.WithIgnoreOnForceUpdate()) + upsert.Field(b, "updated_at", time.Now()) // always updated now + upsert.Field(b, "archived_at", opts.ArchivedAt) + upsert.Field(b, "salesforce_subscription_id", opts.SalesforceSubscriptionID) + upsert.Field(b, "salesforce_opportunity_id", opts.SalesforceOpportunityID) + return b.Exec(ctx, db) } // Upsert upserts a subscription record based on the given options. func (s *Store) Upsert(ctx context.Context, subscriptionID string, opts UpsertSubscriptionOptions) (*Subscription, error) { - query, namedArgs := opts.toQuery(subscriptionID) - if query != "" { - _, err := s.db.Exec(ctx, query, namedArgs) - if err != nil { - return nil, errors.Wrap(err, "exec") - } + if err := opts.Exec(ctx, s.db, subscriptionID); err != nil { + return nil, errors.Wrap(err, "exec") } return s.Get(ctx, subscriptionID) } @@ -175,12 +215,18 @@ func (s *Store) Upsert(ctx context.Context, subscriptionID string, opts UpsertSu // Get returns a subscription record with the given subscription ID. It returns // pgx.ErrNoRows if no such subscription exists. func (s *Store) Get(ctx context.Context, subscriptionID string) (*Subscription, error) { - var subscription Subscription - query := `SELECT id, instance_domain FROM enterprise_portal_subscriptions WHERE id = @id` + query := fmt.Sprintf(`SELECT + %s + FROM + enterprise_portal_subscriptions + WHERE + id = @id`, + strings.Join(subscriptionTableColumns(), ", ")) namedArgs := pgx.NamedArgs{"id": subscriptionID} - err := s.db.QueryRow(ctx, query, namedArgs).Scan(&subscription.ID, &subscription.InstanceDomain) + + sub, err := scanSubscription(s.db.QueryRow(ctx, query, namedArgs)) if err != nil { return nil, err } - return &subscription, nil + return sub, nil } diff --git a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go index d12a817cda1..4aacc18c881 100644 --- a/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go +++ b/cmd/enterprise-portal/internal/database/subscriptions/subscriptions_test.go @@ -3,6 +3,7 @@ package subscriptions_test import ( "context" "testing" + "time" "github.com/google/uuid" "github.com/jackc/pgx/v5" @@ -12,6 +13,7 @@ import ( "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions" + "github.com/sourcegraph/sourcegraph/lib/pointers" ) func TestSubscriptionsStore(t *testing.T) { @@ -107,35 +109,98 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions. } func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscriptions.Store) { - // Create initial test record. - s1, err := s.Upsert( + // Create initial test record. The currentSubscription should be reassigned + // throughout various test cases to represent the current state of the test + // record, as the subtests are run in sequence. + currentSubscription, err := s.Upsert( ctx, uuid.New().String(), subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"}, ) require.NoError(t, err) - got, err := s.Get(ctx, s1.ID) + got, err := s.Get(ctx, currentSubscription.ID) require.NoError(t, err) - assert.Equal(t, s1.ID, got.ID) - assert.Equal(t, s1.InstanceDomain, got.InstanceDomain) + assert.Equal(t, currentSubscription.ID, got.ID) + assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Empty(t, got.DisplayName) + assert.NotZero(t, got.CreatedAt) + assert.NotZero(t, got.UpdatedAt) + assert.Nil(t, got.ArchivedAt) // not archived yet t.Run("noop", func(t *testing.T) { - got, err = s.Upsert(ctx, s1.ID, subscriptions.UpsertSubscriptionOptions{}) + t.Cleanup(func() { currentSubscription = got }) + + got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{}) require.NoError(t, err) - assert.Equal(t, s1.InstanceDomain, got.InstanceDomain) + assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) }) - t.Run("update", func(t *testing.T) { - got, err = s.Upsert(ctx, s1.ID, subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1-new.sourcegraph.com"}) + t.Run("update only domain", func(t *testing.T) { + t.Cleanup(func() { currentSubscription = got }) + + got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ + InstanceDomain: "s1-new.sourcegraph.com", + }) require.NoError(t, err) assert.Equal(t, "s1-new.sourcegraph.com", got.InstanceDomain) + assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) }) - t.Run("force update", func(t *testing.T) { - got, err = s.Upsert(ctx, s1.ID, subscriptions.UpsertSubscriptionOptions{ForceUpdate: true}) + t.Run("update only display name", func(t *testing.T) { + t.Cleanup(func() { currentSubscription = got }) + + got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ + DisplayName: "My New Display Name", + }) + require.NoError(t, err) + assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, "My New Display Name", got.DisplayName) + }) + + t.Run("update only created at", func(t *testing.T) { + t.Cleanup(func() { currentSubscription = got }) + + yesterday := time.Now().Add(-24 * time.Hour) + got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ + CreatedAt: yesterday, + }) + require.NoError(t, err) + assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) + // Round times to allow for some precision drift in CI + assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Round(time.Second)) + }) + + t.Run("update only archived at", func(t *testing.T) { + t.Cleanup(func() { currentSubscription = got }) + + yesterday := time.Now().Add(-24 * time.Hour) + got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ + ArchivedAt: pointers.Ptr(yesterday), + }) + require.NoError(t, err) + assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain) + assert.Equal(t, currentSubscription.DisplayName, got.DisplayName) + assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt) + // Round times to allow for some precision drift in CI + assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Round(time.Second)) + }) + + t.Run("force update to zero values", func(t *testing.T) { + t.Cleanup(func() { currentSubscription = got }) + + got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{ + ForceUpdate: true, + }) require.NoError(t, err) assert.Empty(t, got.InstanceDomain) + assert.Empty(t, got.DisplayName) + assert.Nil(t, got.ArchivedAt) + + // Some fields cannot be updated in a force-update. + assert.Equal(t, currentSubscription.ID, got.ID) + assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt) }) } diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go b/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go index 1bf84794513..52ed657b067 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/adapters.go @@ -65,11 +65,11 @@ func convertSubscriptionToProto(subscription *subscriptions.Subscription, attrs LastTransitionTime: timestamppb.New(*attrs.ArchivedAt), }) } - return &subscriptionsv1.EnterpriseSubscription{ Id: subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.ID, Conditions: conds, InstanceDomain: subscription.InstanceDomain, + DisplayName: subscription.DisplayName, } } diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/v1.go b/cmd/enterprise-portal/internal/subscriptionsservice/v1.go index c6e19d75c5d..7e4a16cbc95 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/v1.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/v1.go @@ -311,7 +311,8 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("subscription.id is required")) } - // Double check with the dotcom DB that the subscription ID is valid. + // TEMPORARY: Double check with the dotcom DB that the subscription ID is valid. + // This currently ensures we never actually create new subscriptions. subscriptionAttrs, err := s.store.ListDotcomEnterpriseSubscriptions(ctx, dotcomdb.ListEnterpriseSubscriptionsOptions{ SubscriptionIDs: []string{subscriptionID}, }) @@ -329,11 +330,16 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne if v := req.Msg.GetSubscription().GetInstanceDomain(); v != "" { opts.InstanceDomain = v } + if v := req.Msg.GetSubscription().GetDisplayName(); v != "" { + opts.DisplayName = v + } } else { for _, p := range fieldPaths { switch p { case "instance_domain": opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain() + case "display_name": + opts.DisplayName = req.Msg.GetSubscription().GetDisplayName() case "*": opts.ForceUpdate = true opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain() diff --git a/cmd/enterprise-portal/internal/subscriptionsservice/v1_test.go b/cmd/enterprise-portal/internal/subscriptionsservice/v1_test.go index 04c44ed0243..99f0a20c67c 100644 --- a/cmd/enterprise-portal/internal/subscriptionsservice/v1_test.go +++ b/cmd/enterprise-portal/internal/subscriptionsservice/v1_test.go @@ -243,6 +243,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) { Subscription: &subscriptionsv1.EnterpriseSubscription{ Id: "80ca12e2-54b4-448c-a61a-390b1a9c1224", InstanceDomain: "s1.sourcegraph.com", + DisplayName: "My Test Subscription", }, UpdateMask: nil, }) @@ -252,6 +253,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) { h.mockStore.ListDotcomEnterpriseSubscriptionsFunc.SetDefaultReturn([]*dotcomdb.SubscriptionAttributes{{ID: "80ca12e2-54b4-448c-a61a-390b1a9c1224"}}, nil) h.mockStore.UpsertEnterpriseSubscriptionFunc.SetDefaultHook(func(_ context.Context, _ string, opts subscriptions.UpsertSubscriptionOptions) (*subscriptions.Subscription, error) { assert.NotEmpty(t, opts.InstanceDomain) + assert.NotEmpty(t, opts.DisplayName) assert.False(t, opts.ForceUpdate) return &subscriptions.Subscription{}, nil }) @@ -265,6 +267,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) { Subscription: &subscriptionsv1.EnterpriseSubscription{ Id: "80ca12e2-54b4-448c-a61a-390b1a9c1224", InstanceDomain: "s1.sourcegraph.com", + DisplayName: "My Test Subscription", // should not be included }, UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"instance_domain"}}, }) @@ -274,6 +277,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) { h.mockStore.ListDotcomEnterpriseSubscriptionsFunc.SetDefaultReturn([]*dotcomdb.SubscriptionAttributes{{ID: "80ca12e2-54b4-448c-a61a-390b1a9c1224"}}, nil) h.mockStore.UpsertEnterpriseSubscriptionFunc.SetDefaultHook(func(_ context.Context, _ string, opts subscriptions.UpsertSubscriptionOptions) (*subscriptions.Subscription, error) { assert.NotEmpty(t, opts.InstanceDomain) + assert.Empty(t, opts.DisplayName) assert.False(t, opts.ForceUpdate) return &subscriptions.Subscription{}, nil }) diff --git a/lib/enterpriseportal/subscriptions/v1/subscriptions.pb.go b/lib/enterpriseportal/subscriptions/v1/subscriptions.pb.go index 8ef0935e7f7..bb28cf95a53 100644 --- a/lib/enterpriseportal/subscriptions/v1/subscriptions.pb.go +++ b/lib/enterpriseportal/subscriptions/v1/subscriptions.pb.go @@ -1538,6 +1538,7 @@ type UpdateEnterpriseSubscriptionRequest struct { // The list of fields to update, fields are specified relative to the EnterpriseSubscription. // Updatable fields are: // - instance_domain + // - display_name UpdateMask *fieldmaskpb.FieldMask `protobuf:"bytes,2,opt,name=update_mask,json=updateMask,proto3" json:"update_mask,omitempty"` } diff --git a/lib/enterpriseportal/subscriptions/v1/subscriptions.proto b/lib/enterpriseportal/subscriptions/v1/subscriptions.proto index b43693913af..3c5688d21e2 100644 --- a/lib/enterpriseportal/subscriptions/v1/subscriptions.proto +++ b/lib/enterpriseportal/subscriptions/v1/subscriptions.proto @@ -327,6 +327,7 @@ message UpdateEnterpriseSubscriptionRequest { // The list of fields to update, fields are specified relative to the EnterpriseSubscription. // Updatable fields are: // - instance_domain + // - display_name google.protobuf.FieldMask update_mask = 2; }