mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 16:51:55 +00:00
feat/enterpriseportal: database layer for subscriptions upsert (#63703)
Implements upsert for all the subscriptions fields in the DB client. As part of this I generalized the logic for building upsert DB interactions into a new `upsert` package, because this pattern is a common one we'll need to implement to maintain various AIP-update-compliant endpoints, which specifies various upsert behaviours: https://google.aip.dev/134 Part of CORE-216 Part of CORE-156 ## Test plan Integration tests against DB
This commit is contained in:
parent
d7ab268385
commit
28f797e866
@ -0,0 +1,26 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
load("//dev:go_defs.bzl", "go_test")
|
||||
|
||||
go_library(
|
||||
name = "upsert",
|
||||
srcs = ["upsert.go"],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert",
|
||||
visibility = ["//cmd/enterprise-portal:__subpackages__"],
|
||||
deps = [
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_jackc_pgx_v5//pgxpool",
|
||||
],
|
||||
)
|
||||
|
||||
go_test(
|
||||
name = "upsert_test",
|
||||
srcs = ["upsert_test.go"],
|
||||
embed = [":upsert"],
|
||||
deps = [
|
||||
"//lib/pointers",
|
||||
"@com_github_hexops_autogold_v2//:autogold",
|
||||
"@com_github_hexops_valast//:valast",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_stretchr_testify//assert",
|
||||
],
|
||||
)
|
||||
@ -0,0 +1,138 @@
|
||||
package upsert
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
// Trick to avoid user-provided string values - outside this package, this type
|
||||
// can only be fulfilled by a constant string value.
|
||||
type constString string
|
||||
|
||||
type Builder struct {
|
||||
table constString
|
||||
primaryKey constString
|
||||
|
||||
insertColumns []string
|
||||
args pgx.NamedArgs
|
||||
updateColumns []string
|
||||
|
||||
forceUpdate bool
|
||||
}
|
||||
|
||||
// New instantiates an upsert.Builder that can be used with `upsert.Field(b, ...)`
|
||||
// to implement the database layer for the upsert pattern common in gRPC 'update',
|
||||
// methods, per the AIP: https://google.aip.dev/134
|
||||
func New(table, primaryKey constString, forceUpdate bool) *Builder {
|
||||
return &Builder{
|
||||
table: table,
|
||||
primaryKey: primaryKey,
|
||||
forceUpdate: forceUpdate,
|
||||
args: pgx.NamedArgs{},
|
||||
}
|
||||
}
|
||||
|
||||
type fieldOptions struct {
|
||||
useColumnDefault bool
|
||||
ignoreOnForceUpdate bool
|
||||
}
|
||||
|
||||
type fieldOptionFn func(*fieldOptions)
|
||||
|
||||
func (fn fieldOptionFn) apply(opt *fieldOptions) { fn(opt) }
|
||||
|
||||
type FieldOption interface {
|
||||
apply(*fieldOptions)
|
||||
}
|
||||
|
||||
// WithColumnDefault indicates that the field should not be included in an upsert
|
||||
// if the field has a zero value, which allows the column default to be used.
|
||||
//
|
||||
// It does NOT apply in a force update.
|
||||
func WithColumnDefault() FieldOption {
|
||||
return fieldOptionFn(func(opt *fieldOptions) { opt.useColumnDefault = true })
|
||||
}
|
||||
|
||||
// WithIgnoreOnForceUpdate indicates that the field should not be updated when
|
||||
// performing a force update.
|
||||
func WithIgnoreOnForceUpdate() FieldOption {
|
||||
return fieldOptionFn(func(opt *fieldOptions) { opt.ignoreOnForceUpdate = true })
|
||||
}
|
||||
|
||||
// Field registers a field that can be set in the upsert to value T. If T is
|
||||
// a zero value, the field is not set on an update, UNLESS the `forceUpdate`
|
||||
// parameter was provided as `true` to upsert.New(...).
|
||||
func Field[T comparable](b *Builder, column constString, value T, opts ...FieldOption) {
|
||||
opt := fieldOptions{}
|
||||
for _, o := range opts {
|
||||
o.apply(&opt)
|
||||
}
|
||||
var zero T
|
||||
|
||||
// If upsert has a zero value, and we would prefer to use the column default,
|
||||
// do nothing, unless we are performing a force-update across all fields.
|
||||
if !b.forceUpdate && (zero == value && opt.useColumnDefault) {
|
||||
return
|
||||
}
|
||||
|
||||
// If we are force-updating, and the field is marked to be ignored, do nothing.
|
||||
if b.forceUpdate && opt.ignoreOnForceUpdate {
|
||||
return
|
||||
}
|
||||
|
||||
b.insertColumns = append(b.insertColumns, string(column))
|
||||
b.args[string(column)] = value
|
||||
|
||||
// If we are force-updating, or value is not zero, update the column in
|
||||
// existing rows (on conflict).
|
||||
if b.forceUpdate || value != zero {
|
||||
b.updateColumns = append(b.updateColumns, string(column))
|
||||
}
|
||||
}
|
||||
|
||||
func (b *Builder) buildQuery() (string, bool) {
|
||||
if len(b.updateColumns) == 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
onConflictSets := make([]string, len(b.updateColumns))
|
||||
for i, c := range b.updateColumns {
|
||||
onConflictSets[i] = fmt.Sprintf("%[1]s = EXCLUDED.%[1]s", c)
|
||||
}
|
||||
|
||||
insertArgNames := make([]string, len(b.insertColumns))
|
||||
for i, c := range b.insertColumns {
|
||||
insertArgNames[i] = fmt.Sprintf("@%s", c)
|
||||
}
|
||||
|
||||
return fmt.Sprintf(`
|
||||
INSERT INTO %[1]s
|
||||
(%[2]s)
|
||||
VALUES
|
||||
(%[3]s)
|
||||
ON CONFLICT
|
||||
(%[4]s)
|
||||
DO UPDATE SET
|
||||
%[5]s`,
|
||||
b.table, // %[1]s
|
||||
strings.Join(b.insertColumns, ", "), // %[2]s
|
||||
strings.Join(insertArgNames, ", "), // %[3]s
|
||||
b.primaryKey, // %[4]s
|
||||
strings.Join(onConflictSets, ",\n"), // %[5]s
|
||||
), true
|
||||
}
|
||||
|
||||
func (b *Builder) Exec(ctx context.Context, db *pgxpool.Pool) error {
|
||||
q, ok := b.buildQuery()
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if _, err := db.Exec(ctx, q, b.args); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,113 @@
|
||||
package upsert
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/hexops/autogold/v2"
|
||||
"github.com/hexops/valast"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
func TestBuilder(t *testing.T) {
|
||||
mockTime := time.Date(2024, 7, 8, 16, 39, 16, 4277000, time.Local)
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
|
||||
forceUpdate bool
|
||||
upsertFields func(b *Builder)
|
||||
|
||||
wantQuery autogold.Value
|
||||
wantArgs autogold.Value
|
||||
}{
|
||||
{
|
||||
name: "not force-update",
|
||||
forceUpdate: false,
|
||||
upsertFields: func(b *Builder) {
|
||||
Field[*string](b, "col1", nil)
|
||||
|
||||
// WithIgnoreOnForceUpdate() does nothing because the we are not
|
||||
// in a force-update.
|
||||
Field(b, "col2", pointers.Ptr("value2"), WithIgnoreOnForceUpdate())
|
||||
|
||||
// WithColumnDefault() does nothing because the time is not zero
|
||||
Field(b, "time", mockTime, WithColumnDefault())
|
||||
|
||||
// Do not set, it should use the default value.
|
||||
Field(b, "should_be_ignored", "", WithColumnDefault())
|
||||
},
|
||||
wantQuery: autogold.Expect(`
|
||||
INSERT INTO table
|
||||
(col1, col2, time)
|
||||
VALUES
|
||||
(@col1, @col2, @time)
|
||||
ON CONFLICT
|
||||
(id)
|
||||
DO UPDATE SET
|
||||
col2 = EXCLUDED.col2,
|
||||
time = EXCLUDED.time`),
|
||||
wantArgs: autogold.Expect(pgx.NamedArgs{
|
||||
"col1": nil, "col2": valast.Ptr("value2"),
|
||||
"time": time.Date(2024,
|
||||
7,
|
||||
8,
|
||||
16,
|
||||
39,
|
||||
16,
|
||||
4277000,
|
||||
time.Local),
|
||||
}),
|
||||
},
|
||||
{
|
||||
name: "force-update",
|
||||
forceUpdate: true,
|
||||
upsertFields: func(b *Builder) {
|
||||
Field[*string](b, "col1", nil)
|
||||
Field(b, "col2", pointers.Ptr("value2"))
|
||||
Field(b, "time", mockTime)
|
||||
|
||||
// Do not set, it cannot be updated in a force-update.
|
||||
Field(b, "should_be_ignored", "", WithIgnoreOnForceUpdate())
|
||||
},
|
||||
wantQuery: autogold.Expect(`
|
||||
INSERT INTO table
|
||||
(col1, col2, time)
|
||||
VALUES
|
||||
(@col1, @col2, @time)
|
||||
ON CONFLICT
|
||||
(id)
|
||||
DO UPDATE SET
|
||||
col1 = EXCLUDED.col1,
|
||||
col2 = EXCLUDED.col2,
|
||||
time = EXCLUDED.time`),
|
||||
wantArgs: autogold.Expect(pgx.NamedArgs{
|
||||
"col1": nil, "col2": valast.Ptr("value2"),
|
||||
"time": time.Date(2024,
|
||||
7,
|
||||
8,
|
||||
16,
|
||||
39,
|
||||
16,
|
||||
4277000,
|
||||
time.Local),
|
||||
}),
|
||||
},
|
||||
} {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
b := New("table", "id", tc.forceUpdate)
|
||||
tc.upsertFields(b)
|
||||
|
||||
q, ok := b.buildQuery()
|
||||
if tc.wantQuery == nil && tc.wantArgs == nil {
|
||||
assert.False(t, ok)
|
||||
} else {
|
||||
assert.True(t, ok)
|
||||
tc.wantQuery.Equal(t, q)
|
||||
tc.wantArgs.Equal(t, b.args)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -13,7 +13,9 @@ go_library(
|
||||
tags = [TAG_INFRA_CORESERVICES],
|
||||
visibility = ["//cmd/enterprise-portal:__subpackages__"],
|
||||
deps = [
|
||||
"//cmd/enterprise-portal/internal/database/internal/upsert",
|
||||
"//lib/errors",
|
||||
"//lib/pointers",
|
||||
"@com_github_jackc_pgtype//:pgtype",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_jackc_pgx_v5//pgxpool",
|
||||
@ -31,6 +33,7 @@ go_test(
|
||||
":subscriptions",
|
||||
"//cmd/enterprise-portal/internal/database/databasetest",
|
||||
"//cmd/enterprise-portal/internal/database/internal/tables",
|
||||
"//lib/pointers",
|
||||
"@com_github_google_uuid//:uuid",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
"@com_github_stretchr_testify//assert",
|
||||
|
||||
@ -9,7 +9,9 @@ import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/upsert"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
// Subscription is an Enterprise subscription record.
|
||||
@ -26,8 +28,12 @@ type Subscription struct {
|
||||
|
||||
// DisplayName is the human-friendly name of this subscription, e.g. "Acme, Inc."
|
||||
//
|
||||
// It must be unique across all currently un-archived subscriptions.
|
||||
DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription';default:'Unnamed subscription'"`
|
||||
// It must be unique across all currently un-archived subscriptions, unless
|
||||
// it is not set.
|
||||
//
|
||||
// TODO: Clean up the database post-deploy and remove the 'Unnamed subscription'
|
||||
// part of the constraint.
|
||||
DisplayName string `gorm:"size:256;not null;uniqueIndex:,where:archived_at IS NULL AND display_name != 'Unnamed subscription' AND display_name != ''"`
|
||||
|
||||
// Timestamps representing the latest timestamps of key conditions related
|
||||
// to this subscription.
|
||||
@ -43,10 +49,50 @@ type Subscription struct {
|
||||
SalesforceOpportunityID *string
|
||||
}
|
||||
|
||||
func (s *Subscription) TableName() string {
|
||||
func (s Subscription) TableName() string {
|
||||
return "enterprise_portal_subscriptions"
|
||||
}
|
||||
|
||||
// subscriptionTableColumns must match s.scan() values.
|
||||
func subscriptionTableColumns() []string {
|
||||
return []string{
|
||||
"id",
|
||||
"instance_domain",
|
||||
"display_name",
|
||||
"created_at",
|
||||
"updated_at",
|
||||
"archived_at",
|
||||
"salesforce_subscription_id",
|
||||
"salesforce_opportunity_id",
|
||||
}
|
||||
}
|
||||
|
||||
// scanSubscription matches s.columns() values.
|
||||
func scanSubscription(row pgx.Row) (*Subscription, error) {
|
||||
var s Subscription
|
||||
err := row.Scan(
|
||||
&s.ID,
|
||||
&s.InstanceDomain,
|
||||
&s.DisplayName,
|
||||
&s.CreatedAt,
|
||||
&s.UpdatedAt,
|
||||
&s.ArchivedAt,
|
||||
&s.SalesforceSubscriptionID,
|
||||
&s.SalesforceOpportunityID,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s.CreatedAt = s.CreatedAt.UTC()
|
||||
s.UpdatedAt = s.UpdatedAt.UTC()
|
||||
if s.ArchivedAt != nil {
|
||||
s.ArchivedAt = pointers.Ptr(s.ArchivedAt.UTC())
|
||||
}
|
||||
|
||||
return &s, nil
|
||||
}
|
||||
|
||||
// Store is the storage layer for product subscriptions.
|
||||
type Store struct {
|
||||
db *pgxpool.Pool
|
||||
@ -101,11 +147,11 @@ func (s *Store) List(ctx context.Context, opts ListEnterpriseSubscriptionsOption
|
||||
where, limit, namedArgs := opts.toQueryConditions()
|
||||
query := fmt.Sprintf(`
|
||||
SELECT
|
||||
id,
|
||||
instance_domain
|
||||
%s
|
||||
FROM enterprise_portal_subscriptions
|
||||
WHERE %s
|
||||
%s`,
|
||||
strings.Join(subscriptionTableColumns(), ", "),
|
||||
where, limit,
|
||||
)
|
||||
rows, err := s.db.Query(ctx, query, namedArgs)
|
||||
@ -116,17 +162,25 @@ WHERE %s
|
||||
|
||||
var subscriptions []*Subscription
|
||||
for rows.Next() {
|
||||
var subscription Subscription
|
||||
if err = rows.Scan(&subscription.ID, &subscription.InstanceDomain); err != nil {
|
||||
sub, err := scanSubscription(rows)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "scan row")
|
||||
}
|
||||
subscriptions = append(subscriptions, &subscription)
|
||||
subscriptions = append(subscriptions, sub)
|
||||
}
|
||||
return subscriptions, rows.Err()
|
||||
}
|
||||
|
||||
type UpsertSubscriptionOptions struct {
|
||||
InstanceDomain string
|
||||
DisplayName string
|
||||
|
||||
CreatedAt time.Time
|
||||
ArchivedAt *time.Time
|
||||
|
||||
SalesforceSubscriptionID *string
|
||||
SalesforceOpportunityID *string
|
||||
|
||||
// ForceUpdate indicates whether to force update all fields of the subscription
|
||||
// record.
|
||||
ForceUpdate bool
|
||||
@ -134,40 +188,26 @@ type UpsertSubscriptionOptions struct {
|
||||
|
||||
// toQuery returns the query based on the options. It returns an empty query if
|
||||
// nothing to update.
|
||||
func (opts UpsertSubscriptionOptions) toQuery(id string) (query string, _ pgx.NamedArgs) {
|
||||
const queryFmt = `
|
||||
INSERT INTO enterprise_portal_subscriptions (id, instance_domain)
|
||||
VALUES (@id, @instanceDomain)
|
||||
ON CONFLICT (id)
|
||||
DO UPDATE SET
|
||||
%s`
|
||||
namedArgs := pgx.NamedArgs{
|
||||
"id": id,
|
||||
"instanceDomain": opts.InstanceDomain,
|
||||
}
|
||||
func (opts UpsertSubscriptionOptions) Exec(ctx context.Context, db *pgxpool.Pool, id string) error {
|
||||
b := upsert.New("enterprise_portal_subscriptions", "id", opts.ForceUpdate)
|
||||
upsert.Field(b, "id", id)
|
||||
upsert.Field(b, "instance_domain", opts.InstanceDomain)
|
||||
upsert.Field(b, "display_name", opts.DisplayName)
|
||||
|
||||
var sets []string
|
||||
if opts.ForceUpdate || opts.InstanceDomain != "" {
|
||||
sets = append(sets, "instance_domain = excluded.instance_domain")
|
||||
}
|
||||
if len(sets) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
query = fmt.Sprintf(
|
||||
queryFmt,
|
||||
strings.Join(sets, ", "),
|
||||
)
|
||||
return query, namedArgs
|
||||
upsert.Field(b, "created_at", opts.CreatedAt,
|
||||
upsert.WithColumnDefault(),
|
||||
upsert.WithIgnoreOnForceUpdate())
|
||||
upsert.Field(b, "updated_at", time.Now()) // always updated now
|
||||
upsert.Field(b, "archived_at", opts.ArchivedAt)
|
||||
upsert.Field(b, "salesforce_subscription_id", opts.SalesforceSubscriptionID)
|
||||
upsert.Field(b, "salesforce_opportunity_id", opts.SalesforceOpportunityID)
|
||||
return b.Exec(ctx, db)
|
||||
}
|
||||
|
||||
// Upsert upserts a subscription record based on the given options.
|
||||
func (s *Store) Upsert(ctx context.Context, subscriptionID string, opts UpsertSubscriptionOptions) (*Subscription, error) {
|
||||
query, namedArgs := opts.toQuery(subscriptionID)
|
||||
if query != "" {
|
||||
_, err := s.db.Exec(ctx, query, namedArgs)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "exec")
|
||||
}
|
||||
if err := opts.Exec(ctx, s.db, subscriptionID); err != nil {
|
||||
return nil, errors.Wrap(err, "exec")
|
||||
}
|
||||
return s.Get(ctx, subscriptionID)
|
||||
}
|
||||
@ -175,12 +215,18 @@ func (s *Store) Upsert(ctx context.Context, subscriptionID string, opts UpsertSu
|
||||
// Get returns a subscription record with the given subscription ID. It returns
|
||||
// pgx.ErrNoRows if no such subscription exists.
|
||||
func (s *Store) Get(ctx context.Context, subscriptionID string) (*Subscription, error) {
|
||||
var subscription Subscription
|
||||
query := `SELECT id, instance_domain FROM enterprise_portal_subscriptions WHERE id = @id`
|
||||
query := fmt.Sprintf(`SELECT
|
||||
%s
|
||||
FROM
|
||||
enterprise_portal_subscriptions
|
||||
WHERE
|
||||
id = @id`,
|
||||
strings.Join(subscriptionTableColumns(), ", "))
|
||||
namedArgs := pgx.NamedArgs{"id": subscriptionID}
|
||||
err := s.db.QueryRow(ctx, query, namedArgs).Scan(&subscription.ID, &subscription.InstanceDomain)
|
||||
|
||||
sub, err := scanSubscription(s.db.QueryRow(ctx, query, namedArgs))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &subscription, nil
|
||||
return sub, nil
|
||||
}
|
||||
|
||||
@ -3,6 +3,7 @@ package subscriptions_test
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
@ -12,6 +13,7 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/subscriptions"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
func TestSubscriptionsStore(t *testing.T) {
|
||||
@ -107,35 +109,98 @@ func SubscriptionsStoreList(t *testing.T, ctx context.Context, s *subscriptions.
|
||||
}
|
||||
|
||||
func SubscriptionsStoreUpsert(t *testing.T, ctx context.Context, s *subscriptions.Store) {
|
||||
// Create initial test record.
|
||||
s1, err := s.Upsert(
|
||||
// Create initial test record. The currentSubscription should be reassigned
|
||||
// throughout various test cases to represent the current state of the test
|
||||
// record, as the subtests are run in sequence.
|
||||
currentSubscription, err := s.Upsert(
|
||||
ctx,
|
||||
uuid.New().String(),
|
||||
subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1.sourcegraph.com"},
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.Get(ctx, s1.ID)
|
||||
got, err := s.Get(ctx, currentSubscription.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, s1.ID, got.ID)
|
||||
assert.Equal(t, s1.InstanceDomain, got.InstanceDomain)
|
||||
assert.Equal(t, currentSubscription.ID, got.ID)
|
||||
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
|
||||
assert.Empty(t, got.DisplayName)
|
||||
assert.NotZero(t, got.CreatedAt)
|
||||
assert.NotZero(t, got.UpdatedAt)
|
||||
assert.Nil(t, got.ArchivedAt) // not archived yet
|
||||
|
||||
t.Run("noop", func(t *testing.T) {
|
||||
got, err = s.Upsert(ctx, s1.ID, subscriptions.UpsertSubscriptionOptions{})
|
||||
t.Cleanup(func() { currentSubscription = got })
|
||||
|
||||
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, s1.InstanceDomain, got.InstanceDomain)
|
||||
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
|
||||
})
|
||||
|
||||
t.Run("update", func(t *testing.T) {
|
||||
got, err = s.Upsert(ctx, s1.ID, subscriptions.UpsertSubscriptionOptions{InstanceDomain: "s1-new.sourcegraph.com"})
|
||||
t.Run("update only domain", func(t *testing.T) {
|
||||
t.Cleanup(func() { currentSubscription = got })
|
||||
|
||||
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
|
||||
InstanceDomain: "s1-new.sourcegraph.com",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "s1-new.sourcegraph.com", got.InstanceDomain)
|
||||
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
|
||||
})
|
||||
|
||||
t.Run("force update", func(t *testing.T) {
|
||||
got, err = s.Upsert(ctx, s1.ID, subscriptions.UpsertSubscriptionOptions{ForceUpdate: true})
|
||||
t.Run("update only display name", func(t *testing.T) {
|
||||
t.Cleanup(func() { currentSubscription = got })
|
||||
|
||||
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
|
||||
DisplayName: "My New Display Name",
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
|
||||
assert.Equal(t, "My New Display Name", got.DisplayName)
|
||||
})
|
||||
|
||||
t.Run("update only created at", func(t *testing.T) {
|
||||
t.Cleanup(func() { currentSubscription = got })
|
||||
|
||||
yesterday := time.Now().Add(-24 * time.Hour)
|
||||
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
|
||||
CreatedAt: yesterday,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
|
||||
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
|
||||
// Round times to allow for some precision drift in CI
|
||||
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.CreatedAt.Round(time.Second))
|
||||
})
|
||||
|
||||
t.Run("update only archived at", func(t *testing.T) {
|
||||
t.Cleanup(func() { currentSubscription = got })
|
||||
|
||||
yesterday := time.Now().Add(-24 * time.Hour)
|
||||
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
|
||||
ArchivedAt: pointers.Ptr(yesterday),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, currentSubscription.InstanceDomain, got.InstanceDomain)
|
||||
assert.Equal(t, currentSubscription.DisplayName, got.DisplayName)
|
||||
assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt)
|
||||
// Round times to allow for some precision drift in CI
|
||||
assert.Equal(t, yesterday.Round(time.Second).UTC(), got.ArchivedAt.Round(time.Second))
|
||||
})
|
||||
|
||||
t.Run("force update to zero values", func(t *testing.T) {
|
||||
t.Cleanup(func() { currentSubscription = got })
|
||||
|
||||
got, err = s.Upsert(ctx, currentSubscription.ID, subscriptions.UpsertSubscriptionOptions{
|
||||
ForceUpdate: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Empty(t, got.InstanceDomain)
|
||||
assert.Empty(t, got.DisplayName)
|
||||
assert.Nil(t, got.ArchivedAt)
|
||||
|
||||
// Some fields cannot be updated in a force-update.
|
||||
assert.Equal(t, currentSubscription.ID, got.ID)
|
||||
assert.Equal(t, currentSubscription.CreatedAt, got.CreatedAt)
|
||||
})
|
||||
}
|
||||
|
||||
|
||||
@ -65,11 +65,11 @@ func convertSubscriptionToProto(subscription *subscriptions.Subscription, attrs
|
||||
LastTransitionTime: timestamppb.New(*attrs.ArchivedAt),
|
||||
})
|
||||
}
|
||||
|
||||
return &subscriptionsv1.EnterpriseSubscription{
|
||||
Id: subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.ID,
|
||||
Conditions: conds,
|
||||
InstanceDomain: subscription.InstanceDomain,
|
||||
DisplayName: subscription.DisplayName,
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -311,7 +311,8 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("subscription.id is required"))
|
||||
}
|
||||
|
||||
// Double check with the dotcom DB that the subscription ID is valid.
|
||||
// TEMPORARY: Double check with the dotcom DB that the subscription ID is valid.
|
||||
// This currently ensures we never actually create new subscriptions.
|
||||
subscriptionAttrs, err := s.store.ListDotcomEnterpriseSubscriptions(ctx, dotcomdb.ListEnterpriseSubscriptionsOptions{
|
||||
SubscriptionIDs: []string{subscriptionID},
|
||||
})
|
||||
@ -329,11 +330,16 @@ func (s *handlerV1) UpdateEnterpriseSubscription(ctx context.Context, req *conne
|
||||
if v := req.Msg.GetSubscription().GetInstanceDomain(); v != "" {
|
||||
opts.InstanceDomain = v
|
||||
}
|
||||
if v := req.Msg.GetSubscription().GetDisplayName(); v != "" {
|
||||
opts.DisplayName = v
|
||||
}
|
||||
} else {
|
||||
for _, p := range fieldPaths {
|
||||
switch p {
|
||||
case "instance_domain":
|
||||
opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain()
|
||||
case "display_name":
|
||||
opts.DisplayName = req.Msg.GetSubscription().GetDisplayName()
|
||||
case "*":
|
||||
opts.ForceUpdate = true
|
||||
opts.InstanceDomain = req.Msg.GetSubscription().GetInstanceDomain()
|
||||
|
||||
@ -243,6 +243,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
|
||||
Subscription: &subscriptionsv1.EnterpriseSubscription{
|
||||
Id: "80ca12e2-54b4-448c-a61a-390b1a9c1224",
|
||||
InstanceDomain: "s1.sourcegraph.com",
|
||||
DisplayName: "My Test Subscription",
|
||||
},
|
||||
UpdateMask: nil,
|
||||
})
|
||||
@ -252,6 +253,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
|
||||
h.mockStore.ListDotcomEnterpriseSubscriptionsFunc.SetDefaultReturn([]*dotcomdb.SubscriptionAttributes{{ID: "80ca12e2-54b4-448c-a61a-390b1a9c1224"}}, nil)
|
||||
h.mockStore.UpsertEnterpriseSubscriptionFunc.SetDefaultHook(func(_ context.Context, _ string, opts subscriptions.UpsertSubscriptionOptions) (*subscriptions.Subscription, error) {
|
||||
assert.NotEmpty(t, opts.InstanceDomain)
|
||||
assert.NotEmpty(t, opts.DisplayName)
|
||||
assert.False(t, opts.ForceUpdate)
|
||||
return &subscriptions.Subscription{}, nil
|
||||
})
|
||||
@ -265,6 +267,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
|
||||
Subscription: &subscriptionsv1.EnterpriseSubscription{
|
||||
Id: "80ca12e2-54b4-448c-a61a-390b1a9c1224",
|
||||
InstanceDomain: "s1.sourcegraph.com",
|
||||
DisplayName: "My Test Subscription", // should not be included
|
||||
},
|
||||
UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"instance_domain"}},
|
||||
})
|
||||
@ -274,6 +277,7 @@ func TestHandlerV1_UpdateEnterpriseSubscription(t *testing.T) {
|
||||
h.mockStore.ListDotcomEnterpriseSubscriptionsFunc.SetDefaultReturn([]*dotcomdb.SubscriptionAttributes{{ID: "80ca12e2-54b4-448c-a61a-390b1a9c1224"}}, nil)
|
||||
h.mockStore.UpsertEnterpriseSubscriptionFunc.SetDefaultHook(func(_ context.Context, _ string, opts subscriptions.UpsertSubscriptionOptions) (*subscriptions.Subscription, error) {
|
||||
assert.NotEmpty(t, opts.InstanceDomain)
|
||||
assert.Empty(t, opts.DisplayName)
|
||||
assert.False(t, opts.ForceUpdate)
|
||||
return &subscriptions.Subscription{}, nil
|
||||
})
|
||||
|
||||
@ -1538,6 +1538,7 @@ type UpdateEnterpriseSubscriptionRequest struct {
|
||||
// The list of fields to update, fields are specified relative to the EnterpriseSubscription.
|
||||
// Updatable fields are:
|
||||
// - instance_domain
|
||||
// - display_name
|
||||
UpdateMask *fieldmaskpb.FieldMask `protobuf:"bytes,2,opt,name=update_mask,json=updateMask,proto3" json:"update_mask,omitempty"`
|
||||
}
|
||||
|
||||
|
||||
@ -327,6 +327,7 @@ message UpdateEnterpriseSubscriptionRequest {
|
||||
// The list of fields to update, fields are specified relative to the EnterpriseSubscription.
|
||||
// Updatable fields are:
|
||||
// - instance_domain
|
||||
// - display_name
|
||||
google.protobuf.FieldMask update_mask = 2;
|
||||
}
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user