feat/enterpriseportal: database layer for subscriptions upsert (#63703)

Implements upsert for all the subscriptions fields in the DB client. As
part of this I generalized the logic for building upsert DB interactions
into a new `upsert` package, because this pattern is a common one we'll
need to implement to maintain various AIP-update-compliant endpoints,
which specifies various upsert behaviours: https://google.aip.dev/134

Part of CORE-216
Part of CORE-156

## Test plan

Integration tests against DB
This commit is contained in:
Robert Lin 2024-07-09 14:35:00 -07:00 committed by GitHub
parent d7ab268385
commit 28f797e866
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 457 additions and 54 deletions

View File

@ -0,0 +1,26 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
load("//dev:go_defs.bzl", "go_test")
go_library(
name = "upsert",
srcs = ["upsert.go"],
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert",
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
],
)
go_test(
name = "upsert_test",
srcs = ["upsert_test.go"],
embed = [":upsert"],
deps = [
"//lib/pointers",
"@com_github_hexops_autogold_v2//:autogold",
"@com_github_hexops_valast//:valast",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_stretchr_testify//assert",
],
)

View File

@ -0,0 +1,138 @@
package upsert
import (
"context"
"fmt"
"strings"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
)
// Trick to avoid user-provided string values - outside this package, this type
// can only be fulfilled by a constant string value.
type constString string
type Builder struct {
table constString
primaryKey constString
insertColumns []string
args pgx.NamedArgs
updateColumns []string
forceUpdate bool
}
// New instantiates an upsert.Builder that can be used with `upsert.Field(b, ...)`
// to implement the database layer for the upsert pattern common in gRPC 'update',
// methods, per the AIP: https://google.aip.dev/134
func New(table, primaryKey constString, forceUpdate bool) *Builder {
return &Builder{
table: table,
primaryKey: primaryKey,
forceUpdate: forceUpdate,
args: pgx.NamedArgs{},
}
}
type fieldOptions struct {
useColumnDefault bool
ignoreOnForceUpdate bool
}
type fieldOptionFn func(*fieldOptions)
func (fn fieldOptionFn) apply(opt *fieldOptions) { fn(opt) }
type FieldOption interface {
apply(*fieldOptions)
}
// WithColumnDefault indicates that the field should not be included in an upsert
// if the field has a zero value, which allows the column default to be used.
//
// It does NOT apply in a force update.
func WithColumnDefault() FieldOption {
return fieldOptionFn(func(opt *fieldOptions) { opt.useColumnDefault = true })
}
// WithIgnoreOnForceUpdate indicates that the field should not be updated when
// performing a force update.
func WithIgnoreOnForceUpdate() FieldOption {
return fieldOptionFn(func(opt *fieldOptions) { opt.ignoreOnForceUpdate = true })
}
// Field registers a field that can be set in the upsert to value T. If T is
// a zero value, the field is not set on an update, UNLESS the `forceUpdate`
// parameter was provided as `true` to upsert.New(...).
func Field[T comparable](b *Builder, column constString, value T, opts ...FieldOption) {
opt := fieldOptions{}
for _, o := range opts {
o.apply(&opt)
}
var zero T
// If upsert has a zero value, and we would prefer to use the column default,
// do nothing, unless we are performing a force-update across all fields.
if !b.forceUpdate && (zero == value && opt.useColumnDefault) {
return
}
// If we are force-updating, and the field is marked to be ignored, do nothing.
if b.forceUpdate && opt.ignoreOnForceUpdate {
return
}
b.insertColumns = append(b.insertColumns, string(column))
b.args[string(column)] = value
// If we are force-updating, or value is not zero, update the column in
// existing rows (on conflict).
if b.forceUpdate || value != zero {
b.updateColumns = append(b.updateColumns, string(column))
}
}
func (b *Builder) buildQuery() (string, bool) {
if len(b.updateColumns) == 0 {
return "", false
}
onConflictSets := make([]string, len(b.updateColumns))
for i, c := range b.updateColumns {
onConflictSets[i] = fmt.Sprintf("%[1]s = EXCLUDED.%[1]s", c)
}
insertArgNames := make([]string, len(b.insertColumns))
for i, c := range b.insertColumns {
insertArgNames[i] = fmt.Sprintf("@%s", c)
}
return fmt.Sprintf(`
INSERT INTO %[1]s
(%[2]s)
VALUES
(%[3]s)
ON CONFLICT
(%[4]s)
DO UPDATE SET
%[5]s`,
b.table, // %[1]s
strings.Join(b.insertColumns, ", "), // %[2]s
strings.Join(insertArgNames, ", "), // %[3]s
b.primaryKey, // %[4]s
strings.Join(onConflictSets, ",\n"), // %[5]s
), true
}
func (b *Builder) Exec(ctx context.Context, db *pgxpool.Pool) error {
q, ok := b.buildQuery()
if !ok {
return nil
}
if _, err := db.Exec(ctx, q, b.args); err != nil {
return err
}
return nil
}

View File

@ -0,0 +1,113 @@
package upsert
import (
"testing"
"time"
"github.com/hexops/autogold/v2"
"github.com/hexops/valast"
"github.com/jackc/pgx/v5"
"github.com/stretchr/testify/assert"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
func TestBuilder(t *testing.T) {
mockTime := time.Date(2024, 7, 8, 16, 39, 16, 4277000, time.Local)
for _, tc := range []struct {
name string
forceUpdate bool
upsertFields func(b *Builder)
wantQuery autogold.Value
wantArgs autogold.Value
}{
{
name: "not force-update",
forceUpdate: false,
upsertFields: func(b *Builder) {
Field[*string](b, "col1", nil)
// WithIgnoreOnForceUpdate() does nothing because the we are not
// in a force-update.
Field(b, "col2", pointers.Ptr("value2"), WithIgnoreOnForceUpdate())
// WithColumnDefault() does nothing because the time is not zero
Field(b, "time", mockTime, WithColumnDefault())
// Do not set, it should use the default value.
Field(b, "should_be_ignored", "", WithColumnDefault())
},
wantQuery: autogold.Expect(`
INSERT INTO table
(col1, col2, time)
VALUES
(@col1, @col2, @time)
ON CONFLICT
(id)
DO UPDATE SET
col2 = EXCLUDED.col2,
time = EXCLUDED.time`),
wantArgs: autogold.Expect(pgx.NamedArgs{
"col1": nil, "col2": valast.Ptr("value2"),
"time": time.Date(2024,
7,
8,
16,
39,
16,
4277000,
time.Local),
}),
},
{
name: "force-update",
forceUpdate: true,
upsertFields: func(b *Builder) {
Field[*string](b, "col1", nil)
Field(b, "col2", pointers.Ptr("value2"))
Field(b, "time", mockTime)
// Do not set, it cannot be updated in a force-update.
Field(b, "should_be_ignored", "", WithIgnoreOnForceUpdate())
},
wantQuery: autogold.Expect(`
INSERT INTO table
(col1, col2, time)
VALUES
(@col1, @col2, @time)
ON CONFLICT
(id)
DO UPDATE SET
col1 = EXCLUDED.col1,
col2 = EXCLUDED.col2,
time = EXCLUDED.time`),
wantArgs: autogold.Expect(pgx.NamedArgs{
"col1": nil, "col2": valast.Ptr("value2"),
"time": time.Date(2024,
7,
8,
16,
39,
16,
4277000,
time.Local),
}),
},
} {
t.Run(tc.name, func(t *testing.T) {
b := New("table", "id", tc.forceUpdate)
tc.upsertFields(b)
q, ok := b.buildQuery()
if tc.wantQuery == nil && tc.wantArgs == nil {
assert.False(t, ok)
} else {
assert.True(t, ok)
tc.wantQuery.Equal(t, q)
tc.wantArgs.Equal(t, b.args)
}
})
}
}

View File

@ -13,7 +13,9 @@ go_library(
tags = [TAG_INFRA_CORESERVICES],
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//cmd/enterprise-portal/internal/database/internal/upsert",
"//lib/errors",
"//lib/pointers",
"@com_github_jackc_pgtype//:pgtype",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_jackc_pgx_v5//pgxpool",
@ -31,6 +33,7 @@ go_test(
":subscriptions",
"//cmd/enterprise-portal/internal/database/databasetest",
"//cmd/enterprise-portal/internal/database/internal/tables",
"//lib/pointers",
"@com_github_google_uuid//:uuid",
"@com_github_jackc_pgx_v5//:pgx",
"@com_github_stretchr_testify//assert",

View File

@ -9,7 +9,9 @@ import (
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
// Subscription is an Enterprise subscription record.
@ -26,8 +28,12 @@ type Subscription struct {
// DisplayName is the human-friendly name of this subscription, e.g. "Acme, Inc."
//
// It must be unique across all currently un-archived subscriptions.
DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription';default:'Unnamed subscription'"`
// It must be unique across all currently un-archived subscriptions, unless
// it is not set.
//
// TODO: Clean up the database post-deploy and remove the 'Unnamed subscription'
// part of the constraint.
DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"`
// Timestamps representing the latest timestamps of key conditions related
// to this subscription.
@ -43,10 +49,50 @@ type Subscription struct {
SalesforceOpportunityID *string
}
func (s *Subscription) TableName() string {
func (s Subscription) TableName() string {
return "enterprise_portal_subscriptions"
}
// subscriptionTableColumns must match s.scan() values.
func subscriptionTableColumns() []string {
return []string{
"id",
"instance_domain",
"display_name",
"created_at",
"updated_at",
"archived_at",
"salesforce_subscription_id",
"salesforce_opportunity_id",
}
}
// scanSubscription matches s.columns() values.
func scanSubscription(row pgx.Row) (*Subscription, error) {
var s Subscription
err := row.Scan(
&s.ID,
&s.InstanceDomain,
&s.DisplayName,
&s.CreatedAt,
&s.UpdatedAt,
&s.ArchivedAt,
&s.SalesforceSubscriptionID,
&s.SalesforceOpportunityID,
)
if err != nil {
return nil, err
}
s.CreatedAt = s.CreatedAt.UTC()
s.UpdatedAt = s.UpdatedAt.UTC()
if s.ArchivedAt != nil {
s.ArchivedAt = pointers.Ptr(s.ArchivedAt.UTC())
}
return &s, nil
}
// Store is the storage layer for product subscriptions.
type Store struct {
db *pgxpool.Pool
@ -101,11 +147,11 @@ func (s *Store) List(ctx context.Context, opts ListEnterpriseSubscriptionsOption
where, limit, namedArgs := opts.toQueryConditions()
query := fmt.Sprintf(`
SELECT
id,
instance_domain
%s
FROM enterprise_portal_subscriptions
WHERE %s
%s`,
strings.Join(subscriptionTableColumns(), ", "),
where, limit,
)
rows, err := s.db.Query(ctx, query, namedArgs)
@ -116,17 +162,25 @@ WHERE %s
var subscriptions []*Subscription
for rows.Next() {
var subscription Subscription
if err = rows.Scan(&subscription.ID, &subscription.InstanceDomain); err != nil {
sub, err := scanSubscription(rows)
if err != nil {
return nil, errors.Wrap(err, "scan row")
}
subscriptions = append(subscriptions, &subscription)
subscriptions = append(subscriptions, sub)
}
return subscriptions, rows.Err()
}
type UpsertSubscriptionOptions struct {
InstanceDomain string
DisplayName string
CreatedAt time.Time
ArchivedAt *time.Time
SalesforceSubscriptionID *string
SalesforceOpportunityID *string
// ForceUpdate indicates whether to force update all fields of the subscription
// record.
ForceUpdate bool
@ -134,40 +188,26 @@ type UpsertSubscriptionOptions struct {
// toQuery returns the query based on the options. It returns an empty query if
// nothing to update.
func (opts UpsertSubscriptionOptions) toQuery(id string) (query string, _ pgx.NamedArgs) {
const queryFmt = `
INSERT INTO enterprise_portal_subscriptions (id, instance_domain)
VALUES (@id, @instanceDomain)
ON CONFLICT (id)
DO UPDATE SET
%s`
namedArgs := pgx.NamedArgs{
"id": id,
"instanceDomain": opts.InstanceDomain,
}
func (opts UpsertSubscriptionOptions) Exec(ctx context.Context, db *pgxpool.Pool, id string) error {
b := upsert.New("enterprise_portal_subscriptions", "id", opts.ForceUpdate)
upsert.Field(b, "id", id)
upsert.Field(b, "instance_domain", opts.InstanceDomain)
upsert.Field(b, "display_name", opts.DisplayName)
var sets []string
if opts.ForceUpdate || opts.InstanceDomain != "" {
sets = append(sets, "instance_domain = excluded.instance_domain")
}
if len(sets) == 0 {
return "", nil
}
query = fmt.Sprintf(
queryFmt,
strings.Join(sets, ", "),
)
return query, namedArgs
upsert.Field(b, "created_at", opts.CreatedAt,
upsert.WithColumnDefault(),
upsert.WithIgnoreOnForceUpdate())
upsert.Field(b, "updated_at", time.Now()) // always updated now
upsert.Field(b, "archived_at", opts.ArchivedAt)
upsert.Field(b, "salesforce_subscription_id", opts.SalesforceSubscriptionID)
upsert.Field(b, "salesforce_opportunity_id", opts.SalesforceOpportunityID)
return b.Exec(ctx, db)
}
// Upsert upserts a subscription record based on the given options.
func (s *Store) Upsert(ctx context.Context, subscriptionID string, opts UpsertSubscriptionOptions) (*Subscription, error) {
query, namedArgs := opts.toQuery(subscriptionID)
if query != "" {
_, err := s.db.Exec(ctx, query, namedArgs)
if err != nil {
return nil, errors.Wrap(err, "exec")
}
if err := opts.Exec(ctx, s.db, subscriptionID); err != nil {
return nil, errors.Wrap(err, "exec")
}
return s.Get(ctx, subscriptionID)
}
@ -175,12 +215,18 @@ func (s *Store) Upsert(ctx context.Context, subscriptionID string, opts UpsertSu
// Get returns a subscription record with the given subscription ID. It returns
// pgx.ErrNoRows if no such subscription exists.
func (s *Store) Get(ctx context.Context, subscriptionID string) (*Subscription, error) {
var subscription Subscription
query := `SELECT id, instance_domain FROM enterprise_portal_subscriptions WHERE id = @id`
query := fmt.Sprintf(`SELECT
%s
FROM
enterprise_portal_subscriptions
WHERE
id = @id`,
strings.Join(subscriptionTableColumns(), ", "))
namedArgs := pgx.NamedArgs{"id": subscriptionID}
err := s.db.QueryRow(ctx, query, namedArgs).Scan(&subscription.ID, &subscription.InstanceDomain)
sub, err := scanSubscription(s.db.QueryRow(ctx, query, namedArgs))
if err != nil {
return nil, err
}
return &subscription, nil
return sub, nil
}

View File

@ -3,6 +3,7 @@ package subscriptions_test
import (
"context"
"testing"
"time"
"github.com/google/uuid"
"github.com/jackc/pgx/v5"
@ -12,6 +13,7 @@ import (
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
func TestSubscriptionsStore(t *testing.T) {
@ -107,35 +109,98 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
}
func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscriptions.Store) {
// Create initial test record.
s1, err := s.Upsert(
// Create initial test record. The currentSubscription should be reassigned
// throughout various test cases to represent the current state of the test
// record, as the subtests are run in sequence.
currentSubscription, err := s.Upsert(
ctx,
uuid.New().String(),
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"},
)
require.NoError(t, err)
got, err := s.Get(ctx, s1.ID)
got, err := s.Get(ctx, currentSubscription.ID)
require.NoError(t, err)
assert.Equal(t, s1.ID, got.ID)
assert.Equal(t, s1.InstanceDomain, got.InstanceDomain)
assert.Equal(t, currentSubscription.ID, got.ID)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Empty(t, got.DisplayName)
assert.NotZero(t, got.CreatedAt)
assert.NotZero(t, got.UpdatedAt)
assert.Nil(t, got.ArchivedAt) // not archived yet
t.Run("noop", func(t *testing.T) {
got, err = s.Upsert(ctx, s1.ID, subscriptions.UpsertSubscriptionOptions{})
t.Cleanup(func() { currentSubscription = got })
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{})
require.NoError(t, err)
assert.Equal(t, s1.InstanceDomain, got.InstanceDomain)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
})
t.Run("update", func(t *testing.T) {
got, err = s.Upsert(ctx, s1.ID, subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1-new.sourcegraph.com"})
t.Run("update only domain", func(t *testing.T) {
t.Cleanup(func() { currentSubscription = got })
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
InstanceDomain: "s1-new.sourcegraph.com",
})
require.NoError(t, err)
assert.Equal(t, "s1-new.sourcegraph.com", got.InstanceDomain)
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
})
t.Run("force update", func(t *testing.T) {
got, err = s.Upsert(ctx, s1.ID, subscriptions.UpsertSubscriptionOptions{ForceUpdate: true})
t.Run("update only display name", func(t *testing.T) {
t.Cleanup(func() { currentSubscription = got })
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
DisplayName: "My New Display Name",
})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t, "My New Display Name", got.DisplayName)
})
t.Run("update only created at", func(t *testing.T) {
t.Cleanup(func() { currentSubscription = got })
yesterday := time.Now().Add(-24 * time.Hour)
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
CreatedAt: yesterday,
})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
// Round times to allow for some precision drift in CI
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Round(time.Second))
})
t.Run("update only archived at", func(t *testing.T) {
t.Cleanup(func() { currentSubscription = got })
yesterday := time.Now().Add(-24 * time.Hour)
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
ArchivedAt: pointers.Ptr(yesterday),
})
require.NoError(t, err)
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt)
// Round times to allow for some precision drift in CI
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Round(time.Second))
})
t.Run("force update to zero values", func(t *testing.T) {
t.Cleanup(func() { currentSubscription = got })
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
ForceUpdate: true,
})
require.NoError(t, err)
assert.Empty(t, got.InstanceDomain)
assert.Empty(t, got.DisplayName)
assert.Nil(t, got.ArchivedAt)
// Some fields cannot be updated in a force-update.
assert.Equal(t, currentSubscription.ID, got.ID)
assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt)
})
}

View File

@ -65,11 +65,11 @@ func convertSubscriptionToProto(subscription *subscriptions.Subscription, attrs
LastTransitionTime: timestamppb.New(*attrs.ArchivedAt),
})
}
return &subscriptionsv1.EnterpriseSubscription{
Id: subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.ID,
Conditions: conds,
InstanceDomain: subscription.InstanceDomain,
DisplayName: subscription.DisplayName,
}
}

View File

@ -311,7 +311,8 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("subscription.id is required"))
}
// Double check with the dotcom DB that the subscription ID is valid.
// TEMPORARY: Double check with the dotcom DB that the subscription ID is valid.
// This currently ensures we never actually create new subscriptions.
subscriptionAttrs, err := s.store.ListDotcomEnterpriseSubscriptions(ctx, dotcomdb.ListEnterpriseSubscriptionsOptions{
SubscriptionIDs: []string{subscriptionID},
})
@ -329,11 +330,16 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne
if v := req.Msg.GetSubscription().GetInstanceDomain(); v != "" {
opts.InstanceDomain = v
}
if v := req.Msg.GetSubscription().GetDisplayName(); v != "" {
opts.DisplayName = v
}
} else {
for _, p := range fieldPaths {
switch p {
case "instance_domain":
opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain()
case "display_name":
opts.DisplayName = req.Msg.GetSubscription().GetDisplayName()
case "*":
opts.ForceUpdate = true
opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain()

View File

@ -243,6 +243,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
Subscription: &subscriptionsv1.EnterpriseSubscription{
Id: "80ca12e2-54b4-448c-a61a-390b1a9c1224",
InstanceDomain: "s1.sourcegraph.com",
DisplayName: "My Test Subscription",
},
UpdateMask: nil,
})
@ -252,6 +253,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
h.mockStore.ListDotcomEnterpriseSubscriptionsFunc.SetDefaultReturn([]*dotcomdb.SubscriptionAttributes{{ID: "80ca12e2-54b4-448c-a61a-390b1a9c1224"}}, nil)
h.mockStore.UpsertEnterpriseSubscriptionFunc.SetDefaultHook(func(_ context.Context, _ string, opts subscriptions.UpsertSubscriptionOptions) (*subscriptions.Subscription, error) {
assert.NotEmpty(t, opts.InstanceDomain)
assert.NotEmpty(t, opts.DisplayName)
assert.False(t, opts.ForceUpdate)
return &subscriptions.Subscription{}, nil
})
@ -265,6 +267,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
Subscription: &subscriptionsv1.EnterpriseSubscription{
Id: "80ca12e2-54b4-448c-a61a-390b1a9c1224",
InstanceDomain: "s1.sourcegraph.com",
DisplayName: "My Test Subscription", // should not be included
},
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"instance_domain"}},
})
@ -274,6 +277,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
h.mockStore.ListDotcomEnterpriseSubscriptionsFunc.SetDefaultReturn([]*dotcomdb.SubscriptionAttributes{{ID: "80ca12e2-54b4-448c-a61a-390b1a9c1224"}}, nil)
h.mockStore.UpsertEnterpriseSubscriptionFunc.SetDefaultHook(func(_ context.Context, _ string, opts subscriptions.UpsertSubscriptionOptions) (*subscriptions.Subscription, error) {
assert.NotEmpty(t, opts.InstanceDomain)
assert.Empty(t, opts.DisplayName)
assert.False(t, opts.ForceUpdate)
return &subscriptions.Subscription{}, nil
})

View File

@ -1538,6 +1538,7 @@ type UpdateEnterpriseSubscriptionRequest struct {
// The list of fields to update, fields are specified relative to the EnterpriseSubscription.
// Updatable fields are:
// - instance_domain
// - display_name
UpdateMask *fieldmaskpb.FieldMask `protobuf:"bytes,2,opt,name=update_mask,json=updateMask,proto3" json:"update_mask,omitempty"`
}

View File

@ -327,6 +327,7 @@ message UpdateEnterpriseSubscriptionRequest {
// The list of fields to update, fields are specified relative to the EnterpriseSubscription.
// Updatable fields are:
// - instance_domain
// - display_name
google.protobuf.FieldMask update_mask = 2;
}