[license-checks] add new fields to product_licenses table (#52761)

This is the first PR to enable [automated
billing](https://docs.google.com/document/d/1aJniLvsyAuomsnBrdT0_kpVa9P2tbSYZEwS8bmfGJRg/edit#heading=h.7h9icmy9s9qp)
and automated license checks.

Adding a few fields with this migration
```
ALTER TABLE IF EXISTS product_licenses
    ADD COLUMN IF NOT EXISTS site_id UUID NULL,
    ADD COLUMN IF NOT EXISTS license_check_token bytea NULL,
    ADD COLUMN IF NOT EXISTS revoked_at timestamptz NULL,
    ADD COLUMN IF NOT EXISTS salesforce_sub_id text NULL,
    ADD COLUMN IF NOT EXISTS salesforce_opp_id text NULL;
```

Also changed the store to be able to use the salesforce fields to create
licenses of version 2.
Finally changed the resolver to also accept new salesforce fields,
resulting in license version 2 creation.

Related issue: 
- #52508

## Test plan

Some unit tests modified/added. To manually test, first create a product
subscription locally and then fire a graphql request like below:
```
mutation GenerateProductLicenseForSubscription($productSubscriptionID: ID!, $license: ProductLicenseInput!) {
  dotcom {
    generateProductLicenseForSubscription(
      productSubscriptionID: $productSubscriptionID
      license: $license
    ) {
      id
      __typename
    }
    __typename
  }
}
```
With variables:
```
{
  "productSubscriptionID": "UHJvZHVjdFN1YnNjcmlwdGlvbjoiYTMyMDU2ZGEtYjY4Zi00YTI4LWFkY2YtZGJhMTVhMmRjODg2Ig==", 
  "license": {
      "tags": [
        "customer:milan",
        "plan:enterprise-1"
      ],
      "userCount": 1,
      "expiresAt": 1717279199,
    	"salesforceSubscriptionID":"1234",
      "salesforceOpportunityID": "5678"
    }
}
```
This commit is contained in:
Milan Freml 2023-06-02 09:16:55 +02:00 committed by GitHub
parent 25d754a04b
commit 65ba7fc65f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 253 additions and 50 deletions

View File

@ -98,9 +98,11 @@ type ProductLicense interface {
// ProductLicenseInput implements the GraphQL type ProductLicenseInput.
type ProductLicenseInput struct {
Tags []string
UserCount int32
ExpiresAt int32
Tags []string
UserCount int32
ExpiresAt int32
SalesforceSubscriptionID *string
SalesforceOpportunityID *string
}
type ProductLicensesArgs struct {

View File

@ -178,6 +178,14 @@ input ProductLicenseInput {
The expiration date of this product license, expressed as the number of seconds since the epoch.
"""
expiresAt: Int!
"""
The Salesforce subscription ID associated with this product license.
"""
salesforceSubscriptionID: String
"""
The Salesforce opportunity ID associated with this product license.
"""
salesforceOpportunityID: String
}
"""

View File

@ -2,6 +2,7 @@ package productsubscription
import (
"context"
"crypto/sha256"
"time"
"github.com/google/uuid"
@ -11,6 +12,7 @@ import (
"github.com/sourcegraph/sourcegraph/enterprise/internal/license"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/database/dbutil"
"github.com/sourcegraph/sourcegraph/internal/hashutil"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
@ -36,6 +38,11 @@ type dbLicenses struct {
db database.DB
}
const createLicenseQuery = `
INSERT INTO product_licenses(id, product_subscription_id, license_key, license_version, license_tags, license_user_count, license_expires_at, license_check_token, salesforce_sub_id, salesforce_opp_id)
VALUES($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) RETURNING id
`
// Create creates a new product license entry for the given subscription.
func (s dbLicenses) Create(ctx context.Context, subscriptionID, licenseKey string, version int, info license.Info) (id string, err error) {
if mocks.licenses.Create != nil {
@ -47,14 +54,13 @@ func (s dbLicenses) Create(ctx context.Context, subscriptionID, licenseKey strin
return "", errors.Wrap(err, "new UUID")
}
keyHash := sha256.Sum256([]byte(licenseKey))
var expiresAt *time.Time
if !info.ExpiresAt.IsZero() {
expiresAt = &info.ExpiresAt
}
if err = s.db.QueryRowContext(ctx, `
INSERT INTO product_licenses(id, product_subscription_id, license_key, license_version, license_tags, license_user_count, license_expires_at)
VALUES($1, $2, $3, $4, $5, $6, $7) RETURNING id
`,
if err = s.db.QueryRowContext(ctx, createLicenseQuery,
newUUID,
subscriptionID,
licenseKey,
@ -62,6 +68,9 @@ VALUES($1, $2, $3, $4, $5, $6, $7) RETURNING id
pq.Array(info.Tags),
dbutil.NewNullInt64(int64(info.UserCount)),
dbutil.NullTime{Time: expiresAt},
hashutil.ToSHA256Bytes(keyHash[:]),
info.SalesforceSubscriptionID,
info.SalesforceOpportunityID,
).Scan(&id); err != nil {
return "", errors.Wrap(err, "insert")
}

View File

@ -2,6 +2,7 @@ package productsubscription
import (
"context"
"fmt"
"testing"
"github.com/sourcegraph/log/logtest"
@ -42,36 +43,56 @@ func TestProductLicenses_Create(t *testing.T) {
require.NoError(t, err)
now := timeutil.Now()
info := license.Info{
licenseV1 := license.Info{
Tags: []string{"true-up"},
UserCount: 10,
ExpiresAt: now,
}
pl, err := dbLicenses{db: db}.Create(ctx, ps, "k2", 1, info)
require.NoError(t, err)
got, err := dbLicenses{db: db}.GetByID(ctx, pl)
require.NoError(t, err)
assert.Equal(t, pl, got.ID)
assert.Equal(t, ps, got.ProductSubscriptionID)
assert.Equal(t, "k2", got.LicenseKey)
sfSubID := "AE9108431908421"
sfOpID := "0A8908908A800F"
require.NotNil(t, got.LicenseVersion)
assert.Equal(t, 1, *got.LicenseVersion)
require.NotNil(t, got.LicenseTags)
assert.Equal(t, info.Tags, got.LicenseTags)
require.NotNil(t, got.LicenseUserCount)
assert.Equal(t, int(info.UserCount), *got.LicenseUserCount)
require.NotNil(t, got.LicenseExpiresAt)
assert.Equal(t, info.ExpiresAt, *got.LicenseExpiresAt)
licenseV2 := license.Info{
Tags: []string{"true-up"},
UserCount: 10,
ExpiresAt: now,
SalesforceSubscriptionID: &sfSubID,
SalesforceOpportunityID: &sfOpID,
}
ts, err := dbLicenses{db: db}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: ps})
require.NoError(t, err)
assert.Len(t, ts, 1)
for v, info := range []license.Info{licenseV1, licenseV2} {
t.Run(fmt.Sprintf("Test v%d", v+1), func(t *testing.T) {
version := v + 1
key := fmt.Sprintf("key%d", version)
pl, err := dbLicenses{db: db}.Create(ctx, ps, key, version, info)
require.NoError(t, err)
got, err := dbLicenses{db: db}.GetByID(ctx, pl)
require.NoError(t, err)
assert.Equal(t, pl, got.ID)
assert.Equal(t, ps, got.ProductSubscriptionID)
assert.Equal(t, key, got.LicenseKey)
require.NotNil(t, got.LicenseVersion)
assert.Equal(t, version, *got.LicenseVersion)
require.NotNil(t, got.LicenseTags)
assert.Equal(t, info.Tags, got.LicenseTags)
require.NotNil(t, got.LicenseUserCount)
assert.Equal(t, int(info.UserCount), *got.LicenseUserCount)
require.NotNil(t, got.LicenseExpiresAt)
assert.Equal(t, info.ExpiresAt, *got.LicenseExpiresAt)
ts, err := dbLicenses{db: db}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: ps})
require.NoError(t, err)
assert.Len(t, ts, version)
// Invalid subscription ID.
ts, err = dbLicenses{db: db}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: "69da12d5-323c-4e42-9d44-cc7951639bca"})
require.NoError(t, err)
assert.Len(t, ts, 0)
})
}
ts, err = dbLicenses{db: db}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: "69da12d5-323c-4e42-9d44-cc7951639bca" /* invalid */})
require.NoError(t, err)
assert.Len(t, ts, 0)
}
func TestProductLicenses_List(t *testing.T) {

View File

@ -101,9 +101,11 @@ func (r *productLicense) CreatedAt() gqlutil.DateTime {
func generateProductLicenseForSubscription(ctx context.Context, db database.DB, subscriptionID string, input *graphqlbackend.ProductLicenseInput) (id string, err error) {
info := license.Info{
Tags: license.SanitizeTagsList(input.Tags),
UserCount: uint(input.UserCount),
ExpiresAt: time.Unix(int64(input.ExpiresAt), 0),
Tags: license.SanitizeTagsList(input.Tags),
UserCount: uint(input.UserCount),
ExpiresAt: time.Unix(int64(input.ExpiresAt), 0),
SalesforceSubscriptionID: input.SalesforceSubscriptionID,
SalesforceOpportunityID: input.SalesforceOpportunityID,
}
licenseKey, version, err := licensing.GenerateProductLicenseKey(info)
if err != nil {

View File

@ -27,12 +27,21 @@ import (
//
// NOTE: If you change these fields, you MUST handle backward compatibility. Existing licenses that
// were generated with the old fields must still work until all customers have added the new
// license. Increment (encodedInfo).Version and formatVersion when you make backward-incompatbile
// changes.
// license. Increment (encodedInfo).Version and modify version() implementation when you make
// backward-incompatbile changes.
type Info struct {
Tags []string `json:"t"` // tags that denote features/restrictions (e.g., "starter" or "dev")
UserCount uint `json:"u"` // the number of users that this license is valid for
ExpiresAt time.Time `json:"e"` // the date when this license expires
// Tags denote features/restrictions (e.g., "starter" or "dev")
Tags []string `json:"t"`
// UserCount is the number of users that this license is valid for
UserCount uint `json:"u"`
// ExpiresAt is the date when this license expires
ExpiresAt time.Time `json:"e"`
// SalesforceSubscriptionID is the optional Salesforce subscription ID to link licenses
// to Salesforce subscriptions
SalesforceSubscriptionID *string `json:"sf_sub_id,omitempty"`
// SalesforceOpportunityID is the optional Salesforce opportunity ID to link licenses
// to Salesforce opportunities
SalesforceOpportunityID *string `json:"sf_opp_id,omitempty"`
}
// IsExpired reports whether the license has expired.
@ -92,10 +101,15 @@ type encodedInfo struct {
Info
}
const formatVersion = 1 // (encodedInfo).Version value
func (l Info) version() int {
if l.SalesforceSubscriptionID == nil {
return 1
}
return 2
}
func (l Info) encode() ([]byte, error) {
e := encodedInfo{Version: formatVersion, Info: l}
e := encodedInfo{Version: l.version(), Info: l}
if _, err := rand.Read(e.Nonce[:8]); err != nil {
return nil, err
}
@ -107,8 +121,8 @@ func (l *Info) decode(data []byte) error {
if err := json.Unmarshal(data, &e); err != nil {
return err
}
if e.Version != formatVersion {
return errors.Errorf("license key format is version %d, expected version %d", e.Version, formatVersion)
if e.Version != e.Info.version() {
return errors.Errorf("license key format is version %d, expected version %d", e.Version, e.Info.version())
}
*l = e.Info
return nil
@ -134,7 +148,7 @@ func GenerateSignedKey(info Info, privateKey ssh.Signer) (licenseKey string, ver
if err != nil {
return "", 0, errors.Wrap(err, "marshal")
}
return base64.RawURLEncoding.EncodeToString(signedKeyData), formatVersion, nil
return base64.RawURLEncoding.EncodeToString(signedKeyData), info.version(), nil
}
// ParseSignedKey parses and verifies the signed license key. If parsing or verification fails, a

View File

@ -71,13 +71,34 @@ mSXt7lUbEmiQep700eM7YlgrOxUVqHsjf1QMrNfq05Ajr8uDfHim
}
var (
timeFixture = time.Date(2018, time.September, 22, 21, 33, 44, 0, time.UTC)
infoFixture = Info{Tags: []string{"a"}, UserCount: 123, ExpiresAt: timeFixture}
timeFixture = time.Date(2018, time.September, 22, 21, 33, 44, 0, time.UTC)
infoV1Fixture = Info{Tags: []string{"a"}, UserCount: 123, ExpiresAt: timeFixture}
sfSubID = "AE0002412312"
sfOpID = "EA890000813"
infoV2Fixture = Info{Tags: []string{"a"}, UserCount: 123, ExpiresAt: timeFixture, SalesforceSubscriptionID: &sfSubID, SalesforceOpportunityID: &sfOpID}
)
func TestInfo_EncodeDecode(t *testing.T) {
t.Run("ok", func(t *testing.T) {
want := infoFixture
t.Run("v1 ok", func(t *testing.T) {
want := infoV1Fixture
data, err := want.encode()
if err != nil {
t.Fatal(err)
}
var got Info
if err := got.decode(data); err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, want) {
t.Errorf("got %+v, want %+v", got, want)
}
})
t.Run("v2 ok", func(t *testing.T) {
want := infoV2Fixture
data, err := want.encode()
if err != nil {
t.Fatal(err)
@ -102,8 +123,25 @@ func TestInfo_EncodeDecode(t *testing.T) {
}
func TestGenerateParseSignedKey(t *testing.T) {
t.Run("ok", func(t *testing.T) {
want := infoFixture
t.Run("v1 ok", func(t *testing.T) {
want := infoV1Fixture
text, _, err := GenerateSignedKey(want, privateKey)
if err != nil {
t.Fatal(err)
}
got, _, err := ParseSignedKey(text, publicKey)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, &want) {
t.Errorf("got %+v, want %+v", got, &want)
}
})
t.Run("v2 ok", func(t *testing.T) {
want := infoV2Fixture
text, _, err := GenerateSignedKey(want, privateKey)
if err != nil {
t.Fatal(err)
@ -120,7 +158,7 @@ func TestGenerateParseSignedKey(t *testing.T) {
})
t.Run("ignores whitespace", func(t *testing.T) {
want := infoFixture
want := infoV1Fixture
text, _, err := GenerateSignedKey(want, privateKey)
if err != nil {
t.Fatal(err)

View File

@ -20728,6 +20728,19 @@
"GenerationExpression": "",
"Comment": ""
},
{
"Name": "license_check_token",
"Index": 11,
"TypeName": "bytea",
"IsNullable": true,
"Default": "",
"CharacterMaximumLength": 0,
"IsIdentity": false,
"IdentityGeneration": "",
"IsGenerated": "NEVER",
"GenerationExpression": "",
"Comment": ""
},
{
"Name": "license_expires_at",
"Index": 8,
@ -20805,9 +20818,71 @@
"IsGenerated": "NEVER",
"GenerationExpression": "",
"Comment": ""
},
{
"Name": "revoked_at",
"Index": 12,
"TypeName": "timestamp with time zone",
"IsNullable": true,
"Default": "",
"CharacterMaximumLength": 0,
"IsIdentity": false,
"IdentityGeneration": "",
"IsGenerated": "NEVER",
"GenerationExpression": "",
"Comment": ""
},
{
"Name": "salesforce_opp_id",
"Index": 14,
"TypeName": "text",
"IsNullable": true,
"Default": "",
"CharacterMaximumLength": 0,
"IsIdentity": false,
"IdentityGeneration": "",
"IsGenerated": "NEVER",
"GenerationExpression": "",
"Comment": ""
},
{
"Name": "salesforce_sub_id",
"Index": 13,
"TypeName": "text",
"IsNullable": true,
"Default": "",
"CharacterMaximumLength": 0,
"IsIdentity": false,
"IdentityGeneration": "",
"IsGenerated": "NEVER",
"GenerationExpression": "",
"Comment": ""
},
{
"Name": "site_id",
"Index": 10,
"TypeName": "uuid",
"IsNullable": true,
"Default": "",
"CharacterMaximumLength": 0,
"IsIdentity": false,
"IdentityGeneration": "",
"IsGenerated": "NEVER",
"GenerationExpression": "",
"Comment": ""
}
],
"Indexes": [
{
"Name": "product_licenses_license_check_token_idx",
"IsPrimaryKey": false,
"IsUnique": true,
"IsExclusion": false,
"IsDeferrable": false,
"IndexDefinition": "CREATE UNIQUE INDEX product_licenses_license_check_token_idx ON product_licenses USING btree (license_check_token)",
"ConstraintType": "",
"ConstraintDefinition": ""
},
{
"Name": "product_licenses_pkey",
"IsPrimaryKey": true,

View File

@ -3124,8 +3124,14 @@ Indexes:
license_user_count | integer | | |
license_expires_at | timestamp with time zone | | |
access_token_enabled | boolean | | not null | true
site_id | uuid | | |
license_check_token | bytea | | |
revoked_at | timestamp with time zone | | |
salesforce_sub_id | text | | |
salesforce_opp_id | text | | |
Indexes:
"product_licenses_pkey" PRIMARY KEY, btree (id)
"product_licenses_license_check_token_idx" UNIQUE, btree (license_check_token)
Foreign-key constraints:
"product_licenses_product_subscription_id_fkey" FOREIGN KEY (product_subscription_id) REFERENCES product_subscriptions(id)

View File

@ -1006,6 +1006,9 @@ go_library(
"frontend/1685562535_add_missing_ranking_index/down.sql",
"frontend/1685562535_add_missing_ranking_index/metadata.yaml",
"frontend/1685562535_add_missing_ranking_index/up.sql",
"frontend/1685525992_add_license_fields_to_support_auto_billing/down.sql",
"frontend/1685525992_add_license_fields_to_support_auto_billing/metadata.yaml",
"frontend/1685525992_add_license_fields_to_support_auto_billing/up.sql",
"frontend/1685570436_add_ranking_graph_key_table/down.sql",
"frontend/1685570436_add_ranking_graph_key_table/metadata.yaml",
"frontend/1685570436_add_ranking_graph_key_table/up.sql",

View File

@ -0,0 +1,8 @@
DROP INDEX IF EXISTS product_licenses_license_check_token_idx;
ALTER TABLE IF EXISTS product_licenses
DROP COLUMN IF EXISTS site_id,
DROP COLUMN IF EXISTS license_check_token,
DROP COLUMN IF EXISTS revoked_at,
DROP COLUMN IF EXISTS salesforce_sub_id,
DROP COLUMN IF EXISTS salesforce_opp_id;

View File

@ -0,0 +1,2 @@
name: add license fields to support auto billing
parents: [1684854389, 1685105270]

View File

@ -0,0 +1,8 @@
ALTER TABLE IF EXISTS product_licenses
ADD COLUMN IF NOT EXISTS site_id UUID,
ADD COLUMN IF NOT EXISTS license_check_token bytea,
ADD COLUMN IF NOT EXISTS revoked_at timestamptz,
ADD COLUMN IF NOT EXISTS salesforce_sub_id text,
ADD COLUMN IF NOT EXISTS salesforce_opp_id text;
CREATE UNIQUE INDEX IF NOT EXISTS product_licenses_license_check_token_idx ON product_licenses(license_check_token);

View File

@ -3898,7 +3898,12 @@ CREATE TABLE product_licenses (
license_tags text[],
license_user_count integer,
license_expires_at timestamp with time zone,
access_token_enabled boolean DEFAULT true NOT NULL
access_token_enabled boolean DEFAULT true NOT NULL,
site_id uuid,
license_check_token bytea,
revoked_at timestamp with time zone,
salesforce_sub_id text,
salesforce_opp_id text
);
COMMENT ON COLUMN product_licenses.access_token_enabled IS 'Whether this license key can be used as an access token to authenticate API requests';
@ -5851,6 +5856,8 @@ CREATE UNIQUE INDEX permissions_unique_namespace_action ON permissions USING btr
CREATE INDEX process_after_insights_query_runner_jobs_idx ON insights_query_runner_jobs USING btree (process_after);
CREATE UNIQUE INDEX product_licenses_license_check_token_idx ON product_licenses USING btree (license_check_token);
CREATE INDEX registry_extension_releases_registry_extension_id ON registry_extension_releases USING btree (registry_extension_id, release_tag, created_at DESC) WHERE (deleted_at IS NULL);
CREATE INDEX registry_extension_releases_registry_extension_id_created_at ON registry_extension_releases USING btree (registry_extension_id, created_at) WHERE (deleted_at IS NULL);