dotcom: add llm-proxy access state (#51242)

Implements tracking for LLM-proxy access state for #50726:

1. New database fields for tracking access enabled, rate limit, and rate
limit interval
2. GraphQL queries and mutations for LLM-proxy access state
3. Ability to hardcode defaults for rate limits based on active license
plan
4. Super-simple UI for showing LLM-proxy access state, behind feature
flag `llm-proxy-management-ui` (the access token component from #51074
are now behind this flag too)

Supsersedes #51075 

## Test plan

Some lightweight tests, and some manual testing:

<img width="1197" alt="image"
src="https://user-images.githubusercontent.com/23356519/235025565-17655052-d336-4d87-a833-724c45beaee3.png">

![image](https://user-images.githubusercontent.com/23356519/235026164-01358b66-5b7f-43a4-9da9-26a808cd2cb1.png)
This commit is contained in:
Robert Lin 2023-05-01 13:08:01 -07:00 committed by GitHub
parent a946476a6c
commit e84d7c5605
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
24 changed files with 876 additions and 599 deletions

View File

@ -53,6 +53,11 @@ describe('SiteAdminProductSubscriptionPage', () => {
pageInfo: { hasNextPage: false },
},
activeLicense: null,
llmProxyAccess: {
__typename: 'LLMProxyAccess',
enabled: false,
rateLimit: null,
},
})
}
_queryProductLicenses={() =>

View File

@ -22,12 +22,15 @@ import {
H2,
ErrorAlert,
Text,
Checkbox,
H3,
} from '@sourcegraph/wildcard'
import { queryGraphQL, requestGraphQL } from '../../../../backend/graphql'
import { CopyableText } from '../../../../components/CopyableText'
import { FilteredConnection } from '../../../../components/FilteredConnection'
import { PageTitle } from '../../../../components/PageTitle'
import { useFeatureFlag } from '../../../../featureFlags/useFeatureFlag'
import {
ArchiveProductSubscriptionResult,
ArchiveProductSubscriptionVariables,
@ -136,6 +139,9 @@ export const SiteAdminProductSubscriptionPage: React.FunctionComponent<React.Pro
GENERATE_ACCESS_TOKEN_GQL
)
// Feature flag only used as this is under development - will be enabled by default
const [llmProxyManagementUI] = useFeatureFlag('llm-proxy-management-ui')
const nodeProps: Pick<SiteAdminProductLicenseNodeProps, 'showSubscription'> = {
showSubscription: false,
}
@ -199,7 +205,7 @@ export const SiteAdminProductSubscriptionPage: React.FunctionComponent<React.Pro
</tbody>
</table>
</Card>
<Card className="mt-3">
<Card className="mt-3" hidden={!llmProxyManagementUI}>
<CardHeader className="d-flex align-items-center justify-content-between">
Access token
<Button
@ -230,6 +236,20 @@ export const SiteAdminProductSubscriptionPage: React.FunctionComponent<React.Pro
)}
</CardBody>
</Card>
<Card className="mt-3" hidden={!llmProxyManagementUI}>
<CardHeader>Cody services</CardHeader>
<CardBody hidden={!productSubscription.llmProxyAccess.enabled}>
<H3>Completions</H3>
<Checkbox
id="llm-proxy-enabled"
checked={productSubscription.llmProxyAccess.enabled}
disabled={true}
label="Enable access to hosted completions (LLM-proxy)"
className="mb-2"
/>
<Text>Rate limits: {JSON.stringify(productSubscription.llmProxyAccess.rateLimit)}</Text>
</CardBody>
</Card>
<LicenseGenerationKeyWarning className="mt-3" />
<Card className="mt-1">
<CardHeader className="d-flex align-items-center justify-content-between">
@ -315,6 +335,13 @@ function queryProductSubscription(
licenseKey
createdAt
}
llmProxyAccess {
enabled
rateLimit {
limit
intervalSeconds
}
}
createdAt
isArchived
url

View File

@ -140,6 +140,7 @@ exports[`SiteAdminProductSubscriptionPage renders 1`] = `
</div>
<div
class="card mt-3"
hidden=""
>
<div
class="cardHeader d-flex align-items-center justify-content-between"
@ -176,6 +177,48 @@ exports[`SiteAdminProductSubscriptionPage renders 1`] = `
</p>
</div>
</div>
<div
class="card mt-3"
hidden=""
>
<div
class="cardHeader"
>
Cody services
</div>
<div
class="cardBody"
hidden=""
>
<h3
class="h3"
>
Completions
</h3>
<div
class="form-check"
>
<input
class="form-check-input mb-2"
disabled=""
id="llm-proxy-enabled"
label="Enable access to hosted completions (LLM-proxy)"
type="checkbox"
/>
<label
class="label form-check-label label"
for="llm-proxy-enabled"
>
Enable access to hosted completions (LLM-proxy)
</label>
</div>
<p
class=""
>
Rate limits: null
</p>
</div>
</div>
<div
class="card mt-1"
>

View File

@ -26,6 +26,7 @@ export type FeatureFlagName =
| 'clone-progress-logging'
| 'sourcegraph-operator-site-admin-hide-maintenance'
| 'repository-metadata'
| 'llm-proxy-management-ui'
interface OrgFlagOverride {
orgID: string

View File

@ -19,6 +19,7 @@ type DotcomRootResolver interface {
type DotcomResolver interface {
// DotcomMutation
CreateProductSubscription(context.Context, *CreateProductSubscriptionArgs) (ProductSubscription, error)
UpdateProductSubscription(context.Context, *UpdateProductSubscriptionArgs) (*EmptyResponse, error)
GenerateProductLicenseForSubscription(context.Context, *GenerateProductLicenseForSubscriptionArgs) (ProductLicense, error)
GenerateAccessTokenForSubscription(context.Context, *GenerateAccessTokenForSubscriptionArgs) (ProductSubscriptionAccessToken, error)
ArchiveProductSubscription(context.Context, *ArchiveProductSubscriptionArgs) (*EmptyResponse, error)
@ -40,6 +41,7 @@ type ProductSubscription interface {
Account(context.Context) (*UserResolver, error)
ActiveLicense(context.Context) (ProductLicense, error)
ProductLicenses(context.Context, *graphqlutil.ConnectionArgs) (ProductLicenseConnection, error)
LLMProxyAccess() LLMProxyAccess
CreatedAt() gqlutil.DateTime
IsArchived() bool
URL(context.Context) (string, error)
@ -116,3 +118,28 @@ type ProductLicenseConnection interface {
type ProductSubscriptionByAccessTokenArgs struct {
AccessToken string
}
type UpdateProductSubscriptionArgs struct {
ID graphql.ID
Update UpdateProductSubscriptionInput
}
type UpdateProductSubscriptionInput struct {
LLMProxyAccess *UpdateLLMProxyAccessInput
}
type UpdateLLMProxyAccessInput struct {
Enabled *bool
RateLimit *int32
RateLimitIntervalSeconds *int32
}
type LLMProxyAccess interface {
Enabled() bool
RateLimit(context.Context) (LLMProxyRateLimit, error)
}
type LLMProxyRateLimit interface {
Limit() int32
IntervalSeconds() int32
}

View File

@ -47,12 +47,33 @@ type DotcomMutation {
productSubscriptionID: ID!
): ProductSubscriptionAccessToken!
"""
Applies a partial update to a product subscription.
Only Sourcegraph.com site admins may perform this mutation.
FOR INTERNAL USE ONLY.
"""
updateProductSubscription(
"""
The product subscription to update.
"""
id: ID!
"""
Partial update to apply.
"""
update: UpdateProductSubscriptionInput!
): EmptyResponse!
"""
Archives an existing product subscription.
Only Sourcegraph.com site admins may perform this mutation.
FOR INTERNAL USE ONLY.
"""
archiveProductSubscription(id: ID!): EmptyResponse!
archiveProductSubscription(
"""
The product subscription to archive.
"""
id: ID!
): EmptyResponse!
}
"""
@ -260,6 +281,10 @@ type ProductSubscription implements Node {
first: Int
): ProductLicenseConnection!
"""
LLM-proxy access granted to this subscription. Properties may be inferred from the active license, or be defined in overrides.
"""
llmProxyAccess: LLMProxyAccess!
"""
The date when this product subscription was created.
"""
createdAt: DateTime!
@ -288,3 +313,68 @@ type ProductSubscriptionAccessToken {
"""
accessToken: String!
}
"""
LLM-proxy access granted to a subscription.
FOR INTERNAL USE ONLY.
"""
type LLMProxyAccess {
"""
Whether or not a subscription has LLM-proxy access.
It may be true, even if a subscription is archived, as a historical record. However,
archived subscriptions should not be treated as having access to LLM-proxy.
"""
enabled: Boolean!
"""
Rate limits for LLM-proxy access, if access is enabled.
"""
rateLimit: LLMProxyRateLimit
}
"""
LLM-proxy access rate limits for a subscription.
FOR INTERNAL USE ONLY.
"""
type LLMProxyRateLimit {
"""
Requests per time interval.
"""
limit: Int!
"""
Interval for rate limiting.
"""
intervalSeconds: Int!
}
"""
Partial update to apply to a subscription. Omitted fields are not applied.
"""
input UpdateProductSubscriptionInput {
"""
Partial update to LLM-proxy access granted to this subscription.
"""
llmProxyAccess: UpdateLLMProxyAccessInput
}
"""
Partial update to apply to a subscription's LLM-proxy access. Omitted fields are not applied.
"""
input UpdateLLMProxyAccessInput {
"""
Enable or disable LLM-proxy access.
"""
enabled: Boolean
"""
Override default requests per time interval.
Set to 0 to remove the override.
"""
rateLimit: Int
"""
Override default interval for rate limiting.
Set to 0 to remove the override.
"""
rateLimitIntervalSeconds: Int
}

View File

@ -8,6 +8,7 @@ go_library(
"license_expiration.go",
"licenses_db.go",
"licenses_graphql.go",
"llmproxy_graphql.go",
"mock_db.go",
"service_account.go",
"subscriptions_db.go",
@ -50,6 +51,7 @@ go_test(
srcs = [
"license_expiration_test.go",
"licenses_db_test.go",
"llmproxy_graphql_test.go",
"service_account_test.go",
"subscriptions_db_test.go",
"subscriptions_graphql_test.go",
@ -61,6 +63,7 @@ go_test(
"requires-network",
],
deps = [
"//cmd/frontend/graphqlbackend",
"//enterprise/internal/license",
"//enterprise/internal/licensing",
"//internal/actor",
@ -76,6 +79,7 @@ go_test(
"@com_github_derision_test_glock//:glock",
"@com_github_google_go_cmp//cmp",
"@com_github_hexops_autogold_v2//:autogold",
"@com_github_hexops_valast//:valast",
"@com_github_sourcegraph_log//logtest",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",

View File

@ -0,0 +1,48 @@
package productsubscription
import (
"context"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/enterprise/internal/licensing"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
type llmProxyAccessResolver struct{ sub *productSubscription }
func (l llmProxyAccessResolver) Enabled() bool { return l.sub.v.LLMProxyAccess.Enabled }
func (l llmProxyAccessResolver) RateLimit(ctx context.Context) (graphqlbackend.LLMProxyRateLimit, error) {
if !l.sub.v.LLMProxyAccess.Enabled {
return nil, nil
}
var rateLimit licensing.LLMProxyRateLimit
// Get default access from active license. Call hydrate and access field directly to
// avoid parsing license key which is done in (*productLicense).Info(), instead just
// relying on what we know in DB.
l.sub.hydrateActiveLicense(ctx)
if l.sub.activeLicenseErr != nil {
return nil, errors.Wrap(l.sub.activeLicenseErr, "could not get active license")
}
if l.sub.activeLicense != nil {
rateLimit = licensing.NewLLMProxyRateLimit(licensing.PlanFromTags(l.sub.activeLicense.LicenseTags))
}
// Apply overrides
rateLimitOverrides := l.sub.v.LLMProxyAccess
if rateLimitOverrides.RateLimit != nil {
rateLimit.Limit = *rateLimitOverrides.RateLimit
}
if rateLimitOverrides.RateIntervalSeconds != nil {
rateLimit.IntervalSeconds = *rateLimitOverrides.RateIntervalSeconds
}
return &llmProxyRateLimitResolver{v: rateLimit}, nil
}
type llmProxyRateLimitResolver struct{ v licensing.LLMProxyRateLimit }
func (l *llmProxyRateLimitResolver) Limit() int32 { return l.v.Limit }
func (l *llmProxyRateLimitResolver) IntervalSeconds() int32 { return l.v.IntervalSeconds }

View File

@ -0,0 +1,70 @@
package productsubscription
import (
"context"
"fmt"
"testing"
"time"
"github.com/sourcegraph/log/logtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/enterprise/internal/license"
"github.com/sourcegraph/sourcegraph/enterprise/internal/licensing"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
"github.com/sourcegraph/sourcegraph/internal/timeutil"
)
func TestLLMProxyAccessResolverRateLimit(t *testing.T) {
logger := logtest.Scoped(t)
db := database.NewDB(logger, dbtest.NewDB(logger, t))
ctx := context.Background()
u, err := db.Users().Create(ctx, database.NewUser{Username: "u"})
require.NoError(t, err)
subID, err := dbSubscriptions{db: db}.Create(ctx, u.ID, "")
require.NoError(t, err)
info := license.Info{
Tags: []string{fmt.Sprintf("plan:%s", licensing.PlanEnterprise1)},
UserCount: 10,
ExpiresAt: timeutil.Now().Add(time.Minute),
}
_, err = dbLicenses{db: db}.Create(ctx, subID, "k2", 1, info)
require.NoError(t, err)
t.Run("default rate limit for a plan", func(t *testing.T) {
sub, err := dbSubscriptions{db: db}.GetByID(ctx, subID)
require.NoError(t, err)
r := llmProxyAccessResolver{sub: &productSubscription{v: sub, db: db}}
wantRateLimit := licensing.NewLLMProxyRateLimit(licensing.PlanEnterprise1)
rateLimit, err := r.RateLimit(ctx)
require.NoError(t, err)
assert.Equal(t, wantRateLimit.Limit, rateLimit.Limit())
assert.Equal(t, wantRateLimit.IntervalSeconds, rateLimit.IntervalSeconds())
})
t.Run("override default rate limit for a plan", func(t *testing.T) {
dbSubscriptions{db: db}.Update(ctx, subID, dbSubscriptionUpdate{
llmProxyAccess: &graphqlbackend.UpdateLLMProxyAccessInput{
RateLimit: pointify(int32(123456)),
},
})
sub, err := dbSubscriptions{db: db}.GetByID(ctx, subID)
require.NoError(t, err)
r := llmProxyAccessResolver{sub: &productSubscription{v: sub, db: db}}
defaultRateLimit := licensing.NewLLMProxyRateLimit(licensing.PlanEnterprise1)
rateLimit, err := r.RateLimit(ctx)
require.NoError(t, err)
assert.Equal(t, int32(123456), rateLimit.Limit())
assert.Equal(t, defaultRateLimit.IntervalSeconds, rateLimit.IntervalSeconds())
})
}

View File

@ -6,3 +6,5 @@ type dbMocks struct {
}
var mocks dbMocks
func pointify[T any](v T) *T { return &v }

View File

@ -9,10 +9,17 @@ import (
"github.com/google/uuid"
"github.com/keegancsmith/sqlf"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
type dbLLMProxyAccess struct {
Enabled bool
RateLimit *int32
RateIntervalSeconds *int32
}
// dbSubscription describes an product subscription row in the product_subscriptions DB
// table.
type dbSubscription struct {
@ -22,6 +29,7 @@ type dbSubscription struct {
CreatedAt time.Time
ArchivedAt *time.Time
AccountNumber *string
LLMProxyAccess dbLLMProxyAccess
}
var emailQueries = sqlf.Sprintf(`all_primary_emails AS (
@ -124,7 +132,10 @@ SELECT
billing_subscription_id,
product_subscriptions.created_at,
product_subscriptions.archived_at,
product_subscriptions.account_number
product_subscriptions.account_number,
product_subscriptions.llm_proxy_enabled,
product_subscriptions.llm_proxy_rate_limit,
product_subscriptions.llm_proxy_rate_interval_seconds
FROM product_subscriptions
LEFT OUTER JOIN users ON product_subscriptions.user_id = users.id
LEFT OUTER JOIN primary_emails ON users.id = primary_emails.user_id
@ -145,7 +156,17 @@ ORDER BY archived_at DESC NULLS FIRST, created_at DESC
var results []*dbSubscription
for rows.Next() {
var v dbSubscription
if err := rows.Scan(&v.ID, &v.UserID, &v.BillingSubscriptionID, &v.CreatedAt, &v.ArchivedAt, &v.AccountNumber); err != nil {
if err := rows.Scan(
&v.ID,
&v.UserID,
&v.BillingSubscriptionID,
&v.CreatedAt,
&v.ArchivedAt,
&v.AccountNumber,
&v.LLMProxyAccess.Enabled,
&v.LLMProxyAccess.RateLimit,
&v.LLMProxyAccess.RateIntervalSeconds,
); err != nil {
return nil, err
}
results = append(results, &v)
@ -174,6 +195,7 @@ WHERE (%s)`, emailQueries, sqlf.Join(opt.sqlConditions(), ") AND ("))
// value is nil, the field remains unchanged in the database.
type dbSubscriptionUpdate struct {
billingSubscriptionID *sql.NullString
llmProxyAccess *graphqlbackend.UpdateLLMProxyAccessInput
}
// Update updates a product subscription.
@ -184,8 +206,26 @@ func (s dbSubscriptions) Update(ctx context.Context, id string, update dbSubscri
if v := update.billingSubscriptionID; v != nil {
fieldUpdates = append(fieldUpdates, sqlf.Sprintf("billing_subscription_id=%s", *v))
}
if access := update.llmProxyAccess; access != nil {
if v := access.Enabled; v != nil {
fieldUpdates = append(fieldUpdates, sqlf.Sprintf("llm_proxy_enabled=%s", *v))
}
if v := access.RateLimit; v != nil {
fieldUpdates = append(fieldUpdates, sqlf.Sprintf("llm_proxy_rate_limit=%s", sql.NullInt64{
Int64: int64(*v),
Valid: *v != 0,
}))
}
if v := access.RateLimitIntervalSeconds; v != nil {
fieldUpdates = append(fieldUpdates, sqlf.Sprintf("llm_proxy_rate_interval_seconds=%s", sql.NullInt64{
Int64: int64(*v),
Valid: *v != 0,
}))
}
}
query := sqlf.Sprintf("UPDATE product_subscriptions SET %s WHERE id=%s", sqlf.Join(fieldUpdates, ", "), id)
query := sqlf.Sprintf("UPDATE product_subscriptions SET %s WHERE id=%s",
sqlf.Join(fieldUpdates, ", "), id)
res, err := s.db.ExecContext(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...)
if err != nil {
return err

View File

@ -4,11 +4,15 @@ import (
"context"
"database/sql"
"testing"
"time"
"github.com/hexops/autogold/v2"
"github.com/hexops/valast"
"github.com/sourcegraph/log/logtest"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
)
@ -18,14 +22,16 @@ func TestProductSubscriptions_Create(t *testing.T) {
db := database.NewDB(logger, dbtest.NewDB(logger, t))
ctx := context.Background()
subscriptions := dbSubscriptions{db: db}
t.Run("no account number", func(t *testing.T) {
u, err := db.Users().Create(ctx, database.NewUser{Username: "u"})
require.NoError(t, err)
sub, err := dbSubscriptions{db: db}.Create(ctx, u.ID, u.Username)
sub, err := subscriptions.Create(ctx, u.ID, u.Username)
require.NoError(t, err)
got, err := dbSubscriptions{db: db}.GetByID(ctx, sub)
got, err := subscriptions.GetByID(ctx, sub)
require.NoError(t, err)
assert.Equal(t, sub, got.ID)
assert.Equal(t, u.ID, got.UserID)
@ -37,10 +43,10 @@ func TestProductSubscriptions_Create(t *testing.T) {
u, err := db.Users().Create(ctx, database.NewUser{Username: "u-11223344"})
require.NoError(t, err)
sub, err := dbSubscriptions{db: db}.Create(ctx, u.ID, u.Username)
sub, err := subscriptions.Create(ctx, u.ID, u.Username)
require.NoError(t, err)
got, err := dbSubscriptions{db: db}.GetByID(ctx, sub)
got, err := subscriptions.GetByID(ctx, sub)
require.NoError(t, err)
assert.Equal(t, sub, got.ID)
assert.Equal(t, u.ID, got.UserID)
@ -49,11 +55,11 @@ func TestProductSubscriptions_Create(t *testing.T) {
require.NotNil(t, got.AccountNumber)
assert.Equal(t, "11223344", *got.AccountNumber)
ts, err := dbSubscriptions{db: db}.List(ctx, dbSubscriptionsListOptions{UserID: u.ID})
ts, err := subscriptions.List(ctx, dbSubscriptionsListOptions{UserID: u.ID})
require.NoError(t, err)
assert.Len(t, ts, 1)
ts, err = dbSubscriptions{db: db}.List(ctx, dbSubscriptionsListOptions{UserID: 123 /* invalid */})
ts, err = subscriptions.List(ctx, dbSubscriptionsListOptions{UserID: 123 /* invalid */})
require.NoError(t, err)
assert.Len(t, ts, 0)
}
@ -64,62 +70,39 @@ func TestProductSubscriptions_List(t *testing.T) {
ctx := context.Background()
u1, err := db.Users().Create(ctx, database.NewUser{Username: "u1"})
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
u2, err := db.Users().Create(ctx, database.NewUser{Username: "u2"})
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
_, err = dbSubscriptions{db: db}.Create(ctx, u1.ID, "")
if err != nil {
t.Fatal(err)
}
_, err = dbSubscriptions{db: db}.Create(ctx, u1.ID, "")
if err != nil {
t.Fatal(err)
}
subscriptions := dbSubscriptions{db: db}
{
// List all product subscriptions.
ts, err := dbSubscriptions{db: db}.List(ctx, dbSubscriptionsListOptions{})
if err != nil {
t.Fatal(err)
}
if want := 2; len(ts) != want {
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
}
count, err := dbSubscriptions{db: db}.Count(ctx, dbSubscriptionsListOptions{})
if err != nil {
t.Fatal(err)
}
if want := 2; count != want {
t.Errorf("got %d, want %d", count, want)
}
}
_, err = subscriptions.Create(ctx, u1.ID, "")
require.NoError(t, err)
_, err = subscriptions.Create(ctx, u1.ID, "")
require.NoError(t, err)
{
t.Run("List all product subscriptions", func(t *testing.T) {
ts, err := subscriptions.List(ctx, dbSubscriptionsListOptions{})
require.NoError(t, err)
assert.Equal(t, 2, len(ts))
count, err := subscriptions.Count(ctx, dbSubscriptionsListOptions{})
require.NoError(t, err)
assert.Equal(t, 2, count)
})
t.Run("List u1's product subscriptions", func(t *testing.T) {
// List u1's product subscriptions.
ts, err := dbSubscriptions{db: db}.List(ctx, dbSubscriptionsListOptions{UserID: u1.ID})
if err != nil {
t.Fatal(err)
}
if want := 2; len(ts) != want {
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
}
}
ts, err := subscriptions.List(ctx, dbSubscriptionsListOptions{UserID: u1.ID})
require.NoError(t, err)
assert.Equal(t, 2, len(ts))
})
{
// List u2's product subscriptions.
ts, err := dbSubscriptions{db: db}.List(ctx, dbSubscriptionsListOptions{UserID: u2.ID})
if err != nil {
t.Fatal(err)
}
if want := 0; len(ts) != want {
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
}
}
t.Run("List u2's product subscriptions", func(t *testing.T) {
ts, err := subscriptions.List(ctx, dbSubscriptionsListOptions{UserID: u2.ID})
require.NoError(t, err)
assert.Equal(t, 0, len(ts))
})
}
func TestProductSubscriptions_Update(t *testing.T) {
@ -128,54 +111,80 @@ func TestProductSubscriptions_Update(t *testing.T) {
ctx := context.Background()
u, err := db.Users().Create(ctx, database.NewUser{Username: "u"})
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
sub0, err := dbSubscriptions{db: db}.Create(ctx, u.ID, "")
if err != nil {
t.Fatal(err)
}
if got, err := (dbSubscriptions{db: db}).GetByID(ctx, sub0); err != nil {
t.Fatal(err)
} else if got.BillingSubscriptionID != nil {
t.Errorf("got %q, want nil", *got.BillingSubscriptionID)
}
subscriptions := dbSubscriptions{db: db}
// Set non-null value.
if err := (dbSubscriptions{db: db}).Update(ctx, sub0, dbSubscriptionUpdate{
billingSubscriptionID: &sql.NullString{
String: "x",
Valid: true,
},
}); err != nil {
t.Fatal(err)
}
if got, err := (dbSubscriptions{db: db}).GetByID(ctx, sub0); err != nil {
t.Fatal(err)
} else if want := "x"; got.BillingSubscriptionID == nil || *got.BillingSubscriptionID != want {
t.Errorf("got %v, want %q", got.BillingSubscriptionID, want)
}
sub0, err := subscriptions.Create(ctx, u.ID, "")
require.NoError(t, err)
got, err := subscriptions.GetByID(ctx, sub0)
require.NoError(t, err)
require.Nil(t, got.BillingSubscriptionID)
// Update no fields.
if err := (dbSubscriptions{db: db}).Update(ctx, sub0, dbSubscriptionUpdate{billingSubscriptionID: nil}); err != nil {
t.Fatal(err)
}
if got, err := (dbSubscriptions{db: db}).GetByID(ctx, sub0); err != nil {
t.Fatal(err)
} else if want := "x"; got.BillingSubscriptionID == nil || *got.BillingSubscriptionID != want {
t.Errorf("got %v, want %q", got.BillingSubscriptionID, want)
}
t.Run("billingSubscriptionID", func(t *testing.T) {
t.Run("set non-null value", func(t *testing.T) {
err := subscriptions.Update(ctx, sub0, dbSubscriptionUpdate{
billingSubscriptionID: &sql.NullString{
String: "x",
Valid: true,
},
})
require.NoError(t, err)
got, err := subscriptions.GetByID(ctx, sub0)
require.NoError(t, err)
autogold.Expect(valast.Addr("x").(*string)).Equal(t, got.BillingSubscriptionID)
})
// Set null value.
if err := (dbSubscriptions{db: db}).Update(ctx, sub0, dbSubscriptionUpdate{
billingSubscriptionID: &sql.NullString{Valid: false},
}); err != nil {
t.Fatal(err)
}
if got, err := (dbSubscriptions{db: db}).GetByID(ctx, sub0); err != nil {
t.Fatal(err)
} else if got.BillingSubscriptionID != nil {
t.Errorf("got %q, want nil", *got.BillingSubscriptionID)
}
t.Run("update no fields", func(t *testing.T) {
err := subscriptions.Update(ctx, sub0, dbSubscriptionUpdate{})
require.NoError(t, err)
got, err := subscriptions.GetByID(ctx, sub0)
require.NoError(t, err)
autogold.Expect(valast.Addr("x").(*string)).Equal(t, got.BillingSubscriptionID)
})
// Set null value.
t.Run("set null value", func(t *testing.T) {
err := subscriptions.Update(ctx, sub0, dbSubscriptionUpdate{
billingSubscriptionID: &sql.NullString{Valid: false},
})
require.NoError(t, err)
got, err := subscriptions.GetByID(ctx, sub0)
require.NoError(t, err)
autogold.Expect((*string)(nil)).Equal(t, got.BillingSubscriptionID)
})
})
t.Run("llmProxyAccess", func(t *testing.T) {
t.Run("set non-null values", func(t *testing.T) {
err := subscriptions.Update(ctx, sub0, dbSubscriptionUpdate{
llmProxyAccess: &graphqlbackend.UpdateLLMProxyAccessInput{
Enabled: pointify(true),
RateLimit: pointify(int32(12)),
RateLimitIntervalSeconds: pointify(int32(time.Hour.Seconds())),
},
})
require.NoError(t, err)
got, err := subscriptions.GetByID(ctx, sub0)
require.NoError(t, err)
autogold.Expect(dbLLMProxyAccess{
Enabled: true, RateLimit: valast.Addr(int32(12)).(*int32),
RateIntervalSeconds: valast.Addr(int32(3600)).(*int32),
}).Equal(t, got.LLMProxyAccess)
})
t.Run("set to zero/null values", func(t *testing.T) {
err := subscriptions.Update(ctx, sub0, dbSubscriptionUpdate{
llmProxyAccess: &graphqlbackend.UpdateLLMProxyAccessInput{
Enabled: pointify(false),
RateLimit: pointify(int32(0)),
RateLimitIntervalSeconds: pointify(int32(0)),
},
})
require.NoError(t, err)
got, err := subscriptions.GetByID(ctx, sub0)
require.NoError(t, err)
autogold.Expect(dbLLMProxyAccess{}).Equal(t, got.LLMProxyAccess)
})
})
}

View File

@ -20,9 +20,14 @@ import (
)
// productSubscription implements the GraphQL type ProductSubscription.
// It must not be copied.
type productSubscription struct {
db database.DB
v *dbSubscription
activeLicense *dbLicense
activeLicenseErr error
activeLicenseOnce sync.Once
}
// ProductSubscriptionByID looks up and returns the ProductSubscription with the given GraphQL
@ -95,15 +100,23 @@ func (r *productSubscription) Account(ctx context.Context) (*graphqlbackend.User
}
func (r *productSubscription) ActiveLicense(ctx context.Context) (graphqlbackend.ProductLicense, error) {
// Return newest license.
active, err := dbLicenses{db: r.db}.Active(ctx, r.v.ID)
if err != nil {
return nil, err
r.hydrateActiveLicense(ctx)
if r.activeLicenseErr != nil {
return nil, r.activeLicenseErr
}
if active == nil {
if r.activeLicense == nil {
return nil, nil
}
return &productLicense{db: r.db, v: active}, nil
return &productLicense{db: r.db, v: r.activeLicense}, nil
}
// hydrateActiveLicense populates r.activeLicense and r.activeLicenseErr once,
// make sure this is called before attempting to use either.
func (r *productSubscription) hydrateActiveLicense(ctx context.Context) {
// Get newest license.
r.activeLicenseOnce.Do(func() {
r.activeLicense, r.activeLicenseErr = dbLicenses{db: r.db}.Active(ctx, r.v.ID)
})
}
func (r *productSubscription) ProductLicenses(ctx context.Context, args *graphqlutil.ConnectionArgs) (graphqlbackend.ProductLicenseConnection, error) {
@ -118,6 +131,10 @@ func (r *productSubscription) ProductLicenses(ctx context.Context, args *graphql
return &productLicenseConnection{db: r.db, opt: opt}, nil
}
func (r *productSubscription) LLMProxyAccess() graphqlbackend.LLMProxyAccess {
return llmProxyAccessResolver{sub: r}
}
func (r *productSubscription) CreatedAt() gqlutil.DateTime {
return gqlutil.DateTime{Time: r.v.CreatedAt}
}
@ -159,6 +176,25 @@ func (r ProductSubscriptionLicensingResolver) CreateProductSubscription(ctx cont
return productSubscriptionByDBID(ctx, r.DB, id)
}
func (r ProductSubscriptionLicensingResolver) UpdateProductSubscription(ctx context.Context, args *graphqlbackend.UpdateProductSubscriptionArgs) (*graphqlbackend.EmptyResponse, error) {
// 🚨 SECURITY: Only site admins may update product subscriptions.
if err := auth.CheckCurrentUserIsSiteAdmin(ctx, r.DB); err != nil {
return nil, err
}
sub, err := productSubscriptionByID(ctx, r.DB, args.ID)
if err != nil {
return nil, err
}
if err := (dbSubscriptions{db: r.DB}).Update(ctx, sub.v.ID, dbSubscriptionUpdate{
llmProxyAccess: args.Update.LLMProxyAccess,
}); err != nil {
return nil, err
}
return &graphqlbackend.EmptyResponse{}, nil
}
func (r ProductSubscriptionLicensingResolver) ArchiveProductSubscription(ctx context.Context, args *graphqlbackend.ArchiveProductSubscriptionArgs) (*graphqlbackend.EmptyResponse, error) {
// 🚨 SECURITY: Only site admins may archive product subscriptions.
if err := auth.CheckCurrentUserIsSiteAdmin(ctx, r.DB); err != nil {

View File

@ -8,6 +8,7 @@ go_library(
"doc.go",
"features.go",
"licensing.go",
"llmproxy.go",
"plans.go",
"tags.go",
"user_count.go",

View File

@ -0,0 +1,26 @@
package licensing
// LLMProxyRateLimit indicates rate limits for Sourcegraph's managed LLM-proxy service.
//
// Zero values in either field indicates no access.
type LLMProxyRateLimit struct {
Limit int32
IntervalSeconds int32
}
// NewLLMProxyRateLimit applies default LLM-proxy access based on the plan.
func NewLLMProxyRateLimit(plan Plan) LLMProxyRateLimit {
switch plan {
// TODO: This is just an example for now.
case PlanEnterprise1:
return LLMProxyRateLimit{
Limit: 50,
IntervalSeconds: 60 * 60 * 24, // day
}
// TODO: Defaults for other plans
default:
return LLMProxyRateLimit{}
}
}

View File

@ -66,23 +66,7 @@ func (p Plan) IsFree() bool {
// Plan is the pricing plan of the license.
func (info *Info) Plan() Plan {
for _, tag := range info.Tags {
// A tag that begins with "plan:" indicates the license's plan.
if strings.HasPrefix(tag, planTagPrefix) {
plan := Plan(tag[len(planTagPrefix):])
if plan.isKnown() {
return plan
}
}
// Backcompat: support the old "starter" tag (which mapped to "Enterprise Starter").
if tag == "starter" {
return PlanOldEnterpriseStarter
}
}
// Backcompat: no tags means it is the old "Enterprise" plan.
return PlanOldEnterprise
return PlanFromTags(info.Tags)
}
// hasUnknownPlan returns an error if the plan is presented in the license tags
@ -101,3 +85,24 @@ func (info *Info) hasUnknownPlan() error {
}
return nil
}
// PlanFromTags returns the pricing plan of the license, based on the given tags.
func PlanFromTags(tags []string) Plan {
for _, tag := range tags {
// A tag that begins with "plan:" indicates the license's plan.
if strings.HasPrefix(tag, planTagPrefix) {
plan := Plan(tag[len(planTagPrefix):])
if plan.isKnown() {
return plan
}
}
// Backcompat: support the old "starter" tag (which mapped to "Enterprise Starter").
if tag == "starter" {
return PlanOldEnterpriseStarter
}
}
// Backcompat: no tags means it is the old "Enterprise" plan.
return PlanOldEnterprise
}

View File

@ -20056,6 +20056,45 @@
"GenerationExpression": "",
"Comment": ""
},
{
"Name": "llm_proxy_enabled",
"Index": 8,
"TypeName": "boolean",
"IsNullable": false,
"Default": "true",
"CharacterMaximumLength": 0,
"IsIdentity": false,
"IdentityGeneration": "",
"IsGenerated": "NEVER",
"GenerationExpression": "",
"Comment": "Whether or not this subscription has access to LLM-proxy"
},
{
"Name": "llm_proxy_rate_interval_seconds",
"Index": 10,
"TypeName": "integer",
"IsNullable": true,
"Default": "",
"CharacterMaximumLength": 0,
"IsIdentity": false,
"IdentityGeneration": "",
"IsGenerated": "NEVER",
"GenerationExpression": "",
"Comment": "Custom time interval over which the for LLM-proxy rate limit is applied"
},
{
"Name": "llm_proxy_rate_limit",
"Index": 9,
"TypeName": "integer",
"IsNullable": true,
"Default": "",
"CharacterMaximumLength": 0,
"IsIdentity": false,
"IdentityGeneration": "",
"IsGenerated": "NEVER",
"GenerationExpression": "",
"Comment": "Custom requests per time interval allowed for LLM-proxy"
},
{
"Name": "updated_at",
"Index": 5,

View File

@ -3001,15 +3001,18 @@ Foreign-key constraints:
# Table "public.product_subscriptions"
```
Column | Type | Collation | Nullable | Default
-------------------------+--------------------------+-----------+----------+---------
id | uuid | | not null |
user_id | integer | | not null |
billing_subscription_id | text | | |
created_at | timestamp with time zone | | not null | now()
updated_at | timestamp with time zone | | not null | now()
archived_at | timestamp with time zone | | |
account_number | text | | |
Column | Type | Collation | Nullable | Default
---------------------------------+--------------------------+-----------+----------+---------
id | uuid | | not null |
user_id | integer | | not null |
billing_subscription_id | text | | |
created_at | timestamp with time zone | | not null | now()
updated_at | timestamp with time zone | | not null | now()
archived_at | timestamp with time zone | | |
account_number | text | | |
llm_proxy_enabled | boolean | | not null | true
llm_proxy_rate_limit | integer | | |
llm_proxy_rate_interval_seconds | integer | | |
Indexes:
"product_subscriptions_pkey" PRIMARY KEY, btree (id)
Foreign-key constraints:
@ -3019,6 +3022,12 @@ Referenced by:
```
**llm_proxy_enabled**: Whether or not this subscription has access to LLM-proxy
**llm_proxy_rate_interval_seconds**: Custom time interval over which the for LLM-proxy rate limit is applied
**llm_proxy_rate_limit**: Custom requests per time interval allowed for LLM-proxy
# Table "public.query_runner_state"
```
Column | Type | Collation | Nullable | Default

View File

@ -901,6 +901,9 @@ go_library(
"frontend/1682114198_product_license_access_tokens/metadata.yaml",
"frontend/1682114198_product_license_access_tokens/up.sql",
"frontend/squashed.sql",
"frontend/1682626931_subscription_llm_proxy_state/down.sql",
"frontend/1682626931_subscription_llm_proxy_state/metadata.yaml",
"frontend/1682626931_subscription_llm_proxy_state/up.sql",
],
importpath = "github.com/sourcegraph/sourcegraph/migrations",
visibility = ["//visibility:public"],

View File

@ -0,0 +1,4 @@
ALTER TABLE product_subscriptions
DROP COLUMN IF EXISTS llm_proxy_enabled,
DROP COLUMN IF EXISTS llm_proxy_rate_limit,
DROP COLUMN IF EXISTS llm_proxy_rate_interval_seconds;

View File

@ -0,0 +1,2 @@
name: subscription_llm_proxy_state
parents: [1682012624]

View File

@ -0,0 +1,20 @@
ALTER TABLE product_subscriptions
ADD COLUMN IF NOT EXISTS llm_proxy_enabled BOOLEAN NOT NULL DEFAULT TRUE,
ADD COLUMN IF NOT EXISTS llm_proxy_rate_limit INTEGER,
ADD COLUMN IF NOT EXISTS llm_proxy_rate_interval_seconds INTEGER;
COMMENT ON COLUMN product_subscriptions.llm_proxy_enabled IS 'Whether or not this subscription has access to LLM-proxy';
COMMENT ON COLUMN product_subscriptions.llm_proxy_rate_limit IS 'Custom requests per time interval allowed for LLM-proxy';
COMMENT ON COLUMN product_subscriptions.llm_proxy_rate_interval_seconds IS 'Custom time interval over which the for LLM-proxy rate limit is applied';
-- Initially, mark any subscription that has no active license as without LLM-proxy access,
-- since there are a lot of old subscriptions out there.
UPDATE product_subscriptions
SET llm_proxy_enabled = false
WHERE id IN (
SELECT product_subscription_id
FROM product_licenses
WHERE license_expires_at > NOW()
GROUP BY product_subscription_id
HAVING COUNT(*) = 0
);

View File

@ -3732,9 +3732,18 @@ CREATE TABLE product_subscriptions (
created_at timestamp with time zone DEFAULT now() NOT NULL,
updated_at timestamp with time zone DEFAULT now() NOT NULL,
archived_at timestamp with time zone,
account_number text
account_number text,
llm_proxy_enabled boolean DEFAULT true NOT NULL,
llm_proxy_rate_limit integer,
llm_proxy_rate_interval_seconds integer
);
COMMENT ON COLUMN product_subscriptions.llm_proxy_enabled IS 'Whether or not this subscription has access to LLM-proxy';
COMMENT ON COLUMN product_subscriptions.llm_proxy_rate_limit IS 'Custom requests per time interval allowed for LLM-proxy';
COMMENT ON COLUMN product_subscriptions.llm_proxy_rate_interval_seconds IS 'Custom time interval over which the for LLM-proxy rate limit is applied';
CREATE TABLE query_runner_state (
query text,
last_executed timestamp with time zone,

File diff suppressed because it is too large Load Diff