feat/enterpriseportal: use database for reading Cody Gateway access

This commit is contained in:
Robert Lin 2024-07-18 12:01:29 -07:00
parent 5916b1f248
commit 4064629c48
17 changed files with 444 additions and 125 deletions

View File

@ -13,16 +13,22 @@ go_library(
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//cmd/enterprise-portal/internal/connectutil",
"//cmd/enterprise-portal/internal/database",
"//cmd/enterprise-portal/internal/database/codyaccess",
"//cmd/enterprise-portal/internal/dotcomdb",
"//cmd/enterprise-portal/internal/samsm2m",
"//internal/codygateway/codygatewayactor",
"//internal/codygateway/codygatewayevents",
"//internal/completions/types",
"//internal/license",
"//internal/licensing",
"//internal/productsubscription",
"//internal/trace",
"//lib/enterpriseportal/codyaccess/v1:codyaccess",
"//lib/enterpriseportal/codyaccess/v1/v1connect",
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
"//lib/errors",
"//lib/pointers",
"@com_connectrpc_connect//:connect",
"@com_github_sourcegraph_conc//pool",
"@com_github_sourcegraph_log//:log",
@ -38,7 +44,7 @@ go_test(
srcs = ["adapters_test.go"],
embed = [":codyaccessservice"],
deps = [
"//cmd/enterprise-portal/internal/dotcomdb",
"//cmd/enterprise-portal/internal/database/codyaccess",
"@com_github_hexops_autogold_v2//:autogold",
"@com_github_stretchr_testify//assert",
],

View File

@ -1,46 +1,51 @@
package codyaccessservice
import (
"encoding/hex"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess"
"github.com/sourcegraph/sourcegraph/internal/codygateway/codygatewayevents"
"github.com/sourcegraph/sourcegraph/internal/license"
"github.com/sourcegraph/sourcegraph/internal/licensing"
codyaccessv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/codyaccess/v1"
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
func convertAccessAttrsToProto(attrs *dotcomdb.CodyGatewayAccessAttributes) *codyaccessv1.CodyGatewayAccess {
func convertAccessAttrsToProto(access *codyaccess.CodyGatewayAccessWithSubscriptionDetails) *codyaccessv1.CodyGatewayAccess {
// Provide ID in prefixed format.
subscriptionID := subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.SubscriptionID
subscriptionID := subscriptionsv1.EnterpriseSubscriptionIDPrefix + access.SubscriptionID
// Always try to return the full response, since even when disabled, some
// features may be allowed via Cody Gateway (notably attributions). This
// also allows Cody Gateway to cache the state of actors that are disabled.
limits := attrs.EvaluateRateLimits()
limits := evaluateCodyGatewayAccessRateLimits(access)
return &codyaccessv1.CodyGatewayAccess{
SubscriptionId: subscriptionID,
SubscriptionDisplayName: attrs.GetSubscriptionDisplayName(),
Enabled: attrs.CodyGatewayEnabled,
SubscriptionDisplayName: access.DisplayName,
Enabled: access.Enabled,
// Rate limits return nil if not enabled, per API spec
ChatCompletionsRateLimit: nilIfNotEnabled(attrs.CodyGatewayEnabled, &codyaccessv1.CodyGatewayRateLimit{
ChatCompletionsRateLimit: nilIfNotEnabled(access.Enabled, &codyaccessv1.CodyGatewayRateLimit{
Source: limits.ChatSource,
Limit: uint64(limits.Chat.Limit),
IntervalDuration: durationpb.New(limits.Chat.IntervalDuration()),
}),
CodeCompletionsRateLimit: nilIfNotEnabled(attrs.CodyGatewayEnabled, &codyaccessv1.CodyGatewayRateLimit{
CodeCompletionsRateLimit: nilIfNotEnabled(access.Enabled, &codyaccessv1.CodyGatewayRateLimit{
Source: limits.CodeSource,
Limit: uint64(limits.Code.Limit),
IntervalDuration: durationpb.New(limits.Code.IntervalDuration()),
}),
EmbeddingsRateLimit: nilIfNotEnabled(attrs.CodyGatewayEnabled, &codyaccessv1.CodyGatewayRateLimit{
EmbeddingsRateLimit: nilIfNotEnabled(access.Enabled, &codyaccessv1.CodyGatewayRateLimit{
Source: limits.EmbeddingsSource,
Limit: uint64(limits.Embeddings.Limit),
IntervalDuration: durationpb.New(limits.Embeddings.IntervalDuration()),
}),
// This is always provided, even if access is disabled
AccessTokens: func() []*codyaccessv1.CodyGatewayAccessToken {
accessTokens := attrs.GenerateAccessTokens()
accessTokens := generateCodyGatewayAccessTokens(access)
if len(accessTokens) == 0 {
return []*codyaccessv1.CodyGatewayAccessToken{}
}
@ -56,6 +61,18 @@ func convertAccessAttrsToProto(attrs *dotcomdb.CodyGatewayAccessAttributes) *cod
}
}
func generateCodyGatewayAccessTokens(access *codyaccess.CodyGatewayAccessWithSubscriptionDetails) []string {
accessTokens := make([]string, 0, len(access.LicenseKeyHashes))
for _, t := range access.LicenseKeyHashes {
if len(t) == 0 { // query can return empty hashes, ignore these
continue
}
// See license.GenerateLicenseKeyBasedAccessToken
accessTokens = append(accessTokens, license.LicenseKeyBasedAccessTokenPrefix+hex.EncodeToString(t))
}
return accessTokens
}
func nilIfNotEnabled[T any](enabled bool, value *T) *T {
if !enabled {
return nil
@ -63,6 +80,68 @@ func nilIfNotEnabled[T any](enabled bool, value *T) *T {
return value
}
type CodyGatewayRateLimits struct {
ChatSource codyaccessv1.CodyGatewayRateLimitSource
Chat licensing.CodyGatewayRateLimit
CodeSource codyaccessv1.CodyGatewayRateLimitSource
Code licensing.CodyGatewayRateLimit
EmbeddingsSource codyaccessv1.CodyGatewayRateLimitSource
Embeddings licensing.CodyGatewayRateLimit
}
func maybeApplyOverride[T ~int32 | ~int64](limit *T, overrideValue T, overrideValid bool) codyaccessv1.CodyGatewayRateLimitSource {
if overrideValid {
*limit = overrideValue
return codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_OVERRIDE
}
// No override
return codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN
}
// evaluateCodyGatewayAccessRateLimits returns the current CodyGatewayRateLimits based on the
// plan and applying known overrides on top. This closely models the existing
// codyGatewayAccessResolver in 'cmd/frontend/internal/dotcom/productsubscription'.
func evaluateCodyGatewayAccessRateLimits(access *codyaccess.CodyGatewayAccessWithSubscriptionDetails) CodyGatewayRateLimits {
// Set defaults for everything based on active licnese plan and user count.
// If there isn't one, zero values apply.
activeLicense := pointers.DerefZero(access.ActiveLicenseInfo)
p := licensing.PlanFromTags(activeLicense.Tags)
userCount := pointers.Ptr(int(activeLicense.UserCount))
limits := CodyGatewayRateLimits{
ChatSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN,
Chat: licensing.NewCodyGatewayChatRateLimit(p, userCount),
CodeSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN,
Code: licensing.NewCodyGatewayCodeRateLimit(p, userCount),
EmbeddingsSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN,
Embeddings: licensing.NewCodyGatewayEmbeddingsRateLimit(p, userCount),
}
// Chat
limits.ChatSource = maybeApplyOverride(&limits.Chat.Limit,
access.ChatCompletionsRateLimit.Int64, access.ChatCompletionsRateLimit.Valid)
limits.ChatSource = maybeApplyOverride(&limits.Chat.IntervalSeconds,
access.ChatCompletionsRateLimitIntervalSeconds.Int32, access.ChatCompletionsRateLimitIntervalSeconds.Valid)
// Code
limits.CodeSource = maybeApplyOverride(&limits.Code.Limit,
access.CodeCompletionsRateLimit.Int64, access.CodeCompletionsRateLimit.Valid)
limits.CodeSource = maybeApplyOverride(&limits.Code.IntervalSeconds,
access.CodeCompletionsRateLimitIntervalSeconds.Int32, access.CodeCompletionsRateLimitIntervalSeconds.Valid)
// Embeddings
limits.EmbeddingsSource = maybeApplyOverride(&limits.Embeddings.Limit,
access.EmbeddingsRateLimit.Int64, access.EmbeddingsRateLimit.Valid)
limits.EmbeddingsSource = maybeApplyOverride(&limits.Embeddings.IntervalSeconds,
access.EmbeddingsRateLimitIntervalSeconds.Int32, access.EmbeddingsRateLimitIntervalSeconds.Valid)
return limits
}
func convertCodyGatewayUsageDatapoints(usage []codygatewayevents.SubscriptionUsage) []*codyaccessv1.CodyGatewayUsage_UsageDatapoint {
results := make([]*codyaccessv1.CodyGatewayUsage_UsageDatapoint, len(usage))
for i, datapoint := range usage {

View File

@ -7,19 +7,21 @@ import (
"github.com/hexops/autogold/v2"
"github.com/stretchr/testify/assert"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess"
)
func TestConvertAccessAttrsToProto(t *testing.T) {
t.Run("zero value", func(t *testing.T) {
proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{})
proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{})
assert.False(t, proto.Enabled)
})
t.Run("disabled returns access tokens", func(t *testing.T) {
proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{
CodyGatewayEnabled: false,
LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")},
proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{
CodyGatewayAccess: codyaccess.CodyGatewayAccess{
Enabled: false,
},
LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")},
})
assert.False(t, proto.Enabled)
// NOTE: These are not real access tokens
@ -31,18 +33,22 @@ func TestConvertAccessAttrsToProto(t *testing.T) {
})
t.Run("enabled with empty access token", func(t *testing.T) {
proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{
CodyGatewayEnabled: true,
LicenseKeyHashes: [][]byte{[]byte(""), nil},
proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{
CodyGatewayAccess: codyaccess.CodyGatewayAccess{
Enabled: true,
},
LicenseKeyHashes: [][]byte{[]byte(""), nil},
})
assert.True(t, proto.Enabled)
assert.Empty(t, proto.GetAccessTokens())
})
t.Run("enabled returns everything", func(t *testing.T) {
proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{
CodyGatewayEnabled: true,
LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")},
proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{
CodyGatewayAccess: codyaccess.CodyGatewayAccess{
Enabled: true,
},
LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")},
})
assert.True(t, proto.Enabled)
// NOTE: These are not real access tokens

View File

@ -12,6 +12,7 @@ import (
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/connectutil"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/samsm2m"
"github.com/sourcegraph/sourcegraph/internal/trace"
@ -23,42 +24,40 @@ import (
const Name = codyaccessv1connect.CodyAccessServiceName
type DotComDB interface {
GetCodyGatewayAccessAttributesBySubscription(ctx context.Context, subscriptionID string) (*dotcomdb.CodyGatewayAccessAttributes, error)
GetCodyGatewayAccessAttributesByAccessToken(ctx context.Context, subscriptionID string) (*dotcomdb.CodyGatewayAccessAttributes, error)
GetAllCodyGatewayAccessAttributes(ctx context.Context) ([]*dotcomdb.CodyGatewayAccessAttributes, error)
}
func RegisterV1(
logger log.Logger,
mux *http.ServeMux,
store StoreV1,
dotcom DotComDB,
opts ...connect.HandlerOption,
) {
mux.Handle(
codyaccessv1connect.NewCodyAccessServiceHandler(
&handlerV1{
logger: logger.Scoped("codyaccess.v1"),
store: store,
dotcom: dotcom,
},
NewHandlerV1(logger, store),
opts...,
),
)
}
type handlerV1 struct {
type HandlerV1 struct {
codyaccessv1connect.UnimplementedCodyAccessServiceHandler
logger log.Logger
store StoreV1
dotcom DotComDB
}
var _ codyaccessv1connect.CodyAccessServiceHandler = (*handlerV1)(nil)
func NewHandlerV1(
logger log.Logger,
store StoreV1,
) *HandlerV1 {
return &HandlerV1{
logger: logger.Scoped("codyaccess.v1"),
store: store,
}
}
func (s *handlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayAccessRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayAccessResponse], error) {
var _ codyaccessv1connect.CodyAccessServiceHandler = (*HandlerV1)(nil)
func (s *HandlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayAccessRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayAccessResponse], error) {
logger := trace.Logger(ctx, s.logger).
With(log.String("queryType", fmt.Sprintf("%T", req.Msg.GetQuery())))
@ -70,19 +69,19 @@ func (s *handlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Reque
}
logger = logger.With(clientAttrs...)
var attr *dotcomdb.CodyGatewayAccessAttributes
var attr *codyaccess.CodyGatewayAccessWithSubscriptionDetails
switch query := req.Msg.GetQuery().(type) {
case *codyaccessv1.GetCodyGatewayAccessRequest_SubscriptionId:
if len(query.SubscriptionId) == 0 {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("invalid query: subscription ID"))
}
attr, err = s.dotcom.GetCodyGatewayAccessAttributesBySubscription(ctx, query.SubscriptionId)
attr, err = s.store.GetCodyGatewayAccessBySubscription(ctx, query.SubscriptionId)
case *codyaccessv1.GetCodyGatewayAccessRequest_AccessToken:
if len(query.AccessToken) == 0 {
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("invalid query: access token"))
}
attr, err = s.dotcom.GetCodyGatewayAccessAttributesByAccessToken(ctx, query.AccessToken)
attr, err = s.store.GetCodyGatewayAccessByAccessToken(ctx, query.AccessToken)
default:
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("invalid query"))
@ -103,7 +102,7 @@ func (s *handlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Reque
}), nil
}
func (s *handlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Request[codyaccessv1.ListCodyGatewayAccessesRequest]) (*connect.Response[codyaccessv1.ListCodyGatewayAccessesResponse], error) {
func (s *HandlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Request[codyaccessv1.ListCodyGatewayAccessesRequest]) (*connect.Response[codyaccessv1.ListCodyGatewayAccessesResponse], error) {
logger := trace.Logger(ctx, s.logger)
// 🚨 SECURITY: Require approrpiate M2M scope.
@ -122,7 +121,7 @@ func (s *handlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Re
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("pagination not implemented"))
}
attrs, err := s.dotcom.GetAllCodyGatewayAccessAttributes(ctx)
attrs, err := s.store.ListCodyGatewayAccesses(ctx)
if err != nil {
if err == dotcomdb.ErrCodyGatewayAccessNotFound {
return nil, connect.NewError(connect.CodeNotFound, err)
@ -146,7 +145,7 @@ func (s *handlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Re
return connect.NewResponse(&resp), nil
}
func (s *handlerV1) GetCodyGatewayUsage(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayUsageRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayUsageResponse], error) {
func (s *HandlerV1) GetCodyGatewayUsage(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayUsageRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayUsageResponse], error) {
logger := trace.Logger(ctx, s.logger)
// 🚨 SECURITY: Require appropriate M2M scope.

View File

@ -2,13 +2,19 @@ package codyaccessservice
import (
"context"
"encoding/hex"
"strings"
"github.com/sourcegraph/conc/pool"
sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess"
"github.com/sourcegraph/sourcegraph/internal/codygateway/codygatewayactor"
"github.com/sourcegraph/sourcegraph/internal/codygateway/codygatewayevents"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
"github.com/sourcegraph/sourcegraph/internal/license"
"github.com/sourcegraph/sourcegraph/internal/productsubscription"
codyaccessv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/codyaccess/v1"
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
"github.com/sourcegraph/sourcegraph/lib/errors"
@ -24,18 +30,33 @@ type StoreV1 interface {
// is no longer active. It is critical that the caller not honor tokens where
// `.Active == false`.
IntrospectSAMSToken(ctx context.Context, token string) (*sams.IntrospectTokenResponse, error)
// GetCodyGatewayUsage retrieves recent Cody Gateway usage data.
// The subscriptionID should not be prefixed.
GetCodyGatewayUsage(ctx context.Context, subscriptionID string) (*codyaccessv1.CodyGatewayUsage, error)
// GetCodyGatewayAccessBySubscription retrieves Cody Gateway access by
// subscription ID.
GetCodyGatewayAccessBySubscription(ctx context.Context, subscriptionID string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error)
// GetCodyGatewayAccessByAccessToken retrieves Cody Gateway access details
// associated with the given access token.
GetCodyGatewayAccessByAccessToken(ctx context.Context, token string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error)
// ListCodyGatewayAccesses retrieves all Cody Gateway accesses with their
// associated subscription details.
ListCodyGatewayAccesses(ctx context.Context) ([]*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error)
}
type storeV1 struct {
SAMSClient *sams.ClientV1
CodyAccess *codyaccess.Store
CodyGatewayEvents *codygatewayevents.Service
}
type StoreV1Options struct {
SAMSClient *sams.ClientV1
DB *database.DB
// Optional.
CodyGatewayEvents *codygatewayevents.Service
}
@ -46,6 +67,7 @@ var errStoreUnimplemented = errors.New("unimplemented")
func NewStoreV1(opts StoreV1Options) StoreV1 {
return &storeV1{
SAMSClient: opts.SAMSClient,
CodyAccess: opts.DB.CodyAccess(),
CodyGatewayEvents: opts.CodyGatewayEvents,
}
}
@ -126,4 +148,34 @@ func (s *storeV1) GetCodyGatewayUsage(ctx context.Context, subscriptionID string
applyResult(usage)
}
return usage, nil
}
func (s *storeV1) GetCodyGatewayAccessBySubscription(ctx context.Context, subscriptionID string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) {
return s.CodyAccess.CodyGateway().Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
SubscriptionID: subscriptionID,
})
}
func (s *storeV1) GetCodyGatewayAccessByAccessToken(ctx context.Context, token string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) {
// Below is copied from 'func (t dbTokens) LookupProductSubscriptionIDByAccessToken'
// in 'cmd/frontend/internal/dotcom/productsubscription'.
if !strings.HasPrefix(token, productsubscription.AccessTokenPrefix) &&
!strings.HasPrefix(token, license.LicenseKeyBasedAccessTokenPrefix) {
return nil, errors.WithSafeDetails(codyaccess.ErrSubscriptionDoesNotExist, "invalid token with unknown prefix")
}
tokenSansPrefix := token[len(license.LicenseKeyBasedAccessTokenPrefix):]
decoded, err := hex.DecodeString(tokenSansPrefix)
if err != nil {
return nil, errors.WithSafeDetails(codyaccess.ErrSubscriptionDoesNotExist, "invalid token with unknown encoding")
}
// End copied code.
return s.CodyAccess.CodyGateway().Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
LicenseKeyHash: decoded,
})
}
func (s *storeV1) ListCodyGatewayAccesses(ctx context.Context) ([]*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) {
return s.CodyAccess.CodyGateway().List(ctx)
}

View File

@ -95,7 +95,8 @@ func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetai
// an active subscription exists, but explicit access is not configured. In
// this case we still need to return a valid CodyGatewayAccessWithSubscriptionDetails,
// just with empty fields.
var maybeEnabled *bool
var maybeEnabled sql.NullBool
var maybeDisplayName sql.NullString
err := row.Scan(
&a.SubscriptionID,
&maybeEnabled,
@ -106,7 +107,7 @@ func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetai
&a.EmbeddingsRateLimit,
&a.EmbeddingsRateLimitIntervalSeconds,
// Subscriptions fields
&a.DisplayName,
&maybeDisplayName,
// License fields
&a.ActiveLicenseInfo,
&a.LicenseKeyHashes,
@ -114,9 +115,9 @@ func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetai
if err != nil {
return nil, err
}
if maybeEnabled != nil {
a.Enabled = *maybeEnabled
}
a.Enabled = maybeEnabled.Bool
a.DisplayName = maybeDisplayName.String
return &a, nil
}
@ -181,27 +182,55 @@ type CodyGatewayAccessWithSubscriptionDetails struct {
DisplayName string
ActiveLicenseInfo *license.Info
LicenseKeyHashes [][]byte
// Used by GenerateAccessTokens
LicenseKeyHashes [][]byte
}
var ErrSubscriptionDoesNotExist = errors.New("subscription does not exist")
type GetCodyGatewayAccessOptions struct {
SubscriptionID string
LicenseKeyHash []byte
}
func (opts GetCodyGatewayAccessOptions) buildConds() (string, pgx.NamedArgs, error) {
if opts.SubscriptionID == "" && len(opts.LicenseKeyHash) == 0 {
return "", nil, errors.New("must specify either SubscriptionID or LicenseKeyHash")
}
args := pgx.NamedArgs{}
conds := []string{"TRUE"}
if opts.SubscriptionID != "" {
conds = append(conds, "subscription.id = @subscriptionID")
args["subscriptionID"] = opts.SubscriptionID
}
if len(opts.LicenseKeyHash) > 0 {
conds = append(conds, "@licenseKeyHash = ANY(tokens.license_key_hashes)")
args["licenseKeyHash"] = opts.LicenseKeyHash
}
return strings.Join(conds, " AND "), args, nil
}
// Get returns the Cody Gateway access for the given subscription.
func (s *CodyGatewayStore) Get(ctx context.Context, subscriptionID string) (*CodyGatewayAccessWithSubscriptionDetails, error) {
func (s *CodyGatewayStore) Get(ctx context.Context, opts GetCodyGatewayAccessOptions) (*CodyGatewayAccessWithSubscriptionDetails, error) {
conds, args, err := opts.buildConds()
if err != nil {
return nil, err
}
query := fmt.Sprintf(`SELECT
%s
FROM
enterprise_portal_cody_gateway_access AS access
%s
WHERE
subscription.id = @subscriptionID
%s
AND subscription.archived_at IS NULL`,
strings.Join(codyGatewayAccessTableColumns(), ", "),
codyGatewayAccessJoinClauses)
codyGatewayAccessJoinClauses,
conds)
sub, err := scanCodyGatewayAccess(s.db.QueryRow(ctx, query, pgx.NamedArgs{
"subscriptionID": subscriptionID,
}))
sub, err := scanCodyGatewayAccess(s.db.QueryRow(ctx, query, args))
if err != nil {
if errors.Is(err, pgx.ErrNoRows) {
// RIGHT JOIN in query ensures that if we find no result, it's
@ -297,5 +326,5 @@ func (s *CodyGatewayStore) Upsert(ctx context.Context, subscriptionID string, op
}
return nil, err
}
return s.Get(ctx, subscriptionID)
return s.Get(ctx, GetCodyGatewayAccessOptions{SubscriptionID: subscriptionID})
}

View File

@ -149,7 +149,9 @@ func TestCodyGatewayStore(t *testing.T) {
})
require.NoError(t, err)
_, err = codyaccess.NewCodyGatewayStore(db).Get(ctx, subscriptionID)
_, err = codyaccess.NewCodyGatewayStore(db).Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
SubscriptionID: subscriptionID,
})
assert.ErrorIs(t, err, codyaccess.ErrSubscriptionDoesNotExist)
})
})
@ -214,15 +216,36 @@ func CodyGatewayStoreListAndGet(t *testing.T, ctx context.Context, subscriptionI
t.Run("Get", func(t *testing.T) {
for idx, sub := range subscriptionIDs {
t.Run(fmt.Sprintf("idx=%d", idx), func(t *testing.T) {
got, err := s.Get(ctx, sub)
got, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
SubscriptionID: sub,
})
require.NoError(t, err)
assertAccess(idx, got)
// Reverse lookup by license key hash
for _, hash := range got.LicenseKeyHashes {
got2, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
LicenseKeyHash: hash,
})
require.NoError(t, err)
assert.Len(t, got2.LicenseKeyHashes, 2) // 2 valid licenses
assert.Equal(t, got, got2)
}
})
}
t.Run("ErrSubscriptionDoesNotExist", func(t *testing.T) {
_, err := s.Get(ctx, uuid.NewString())
_, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
SubscriptionID: uuid.NewString(),
})
assert.Error(t, err)
assert.ErrorIs(t, err, codyaccess.ErrSubscriptionDoesNotExist)
_, err = s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
LicenseKeyHash: []byte(uuid.NewString()),
})
assert.Error(t, err)
assert.ErrorIs(t, err, codyaccess.ErrSubscriptionDoesNotExist)
})
})
@ -241,7 +264,9 @@ func CodyGatewayStoreUpsert(t *testing.T, ctx context.Context, subscriptionIDs [
)
require.NoError(t, err)
got, err := s.Get(ctx, currentAccess.SubscriptionID)
got, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
SubscriptionID: currentAccess.SubscriptionID,
})
require.NoError(t, err)
assert.False(t, got.Enabled)
assert.Equal(t, currentAccess.SubscriptionID, got.SubscriptionID)

View File

@ -20,15 +20,15 @@ var databaseTracer = otel.Tracer("enterprise-portal/internal/database")
// DB is the database handle for the storage layer.
type DB struct {
db *pgxpool.Pool
DB *pgxpool.Pool
}
func (db *DB) Subscriptions() *subscriptions.Store {
return subscriptions.NewStore(db.db)
return subscriptions.NewStore(db.DB)
}
func (db *DB) CodyAccess() *codyaccess.Store {
return codyaccess.NewStore(db.db)
return codyaccess.NewStore(db.DB)
}
func databaseName(msp bool) string {
@ -53,11 +53,11 @@ func NewHandle(ctx context.Context, logger log.Logger, contract runtime.Contract
if err != nil {
return nil, errors.Wrap(err, "get connection pool")
}
return &DB{db: pool}, nil
return &DB{DB: pool}, nil
}
// Close closes all connections in the pool and rejects future Acquire calls.
// Blocks until all connections are returned to pool and closed.
func (db *DB) Close() {
db.db.Close()
db.DB.Close()
}

View File

@ -7,6 +7,7 @@ go_library(
tags = [TAG_INFRA_CORESERVICES],
visibility = ["//cmd/enterprise-portal:__subpackages__"],
deps = [
"//cmd/enterprise-portal/internal/database/internal/tables",
"//cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
"//internal/database/dbtest",
"@com_github_jackc_pgx_v5//:pgx",

View File

@ -16,10 +16,13 @@ import (
"gorm.io/gorm"
"gorm.io/gorm/schema"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator"
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
)
func Tables(_ *testing.T) []schema.Tabler { return tables.All() }
// NewTestDB creates a new test database and initializes the given list of
// tables for the suite. The test database is dropped after testing is completed
// unless failed.
@ -37,6 +40,7 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx
// Set up test suite database.
dbName := fmt.Sprintf("sourcegraph-test-%s-%s-%d", system, suite, time.Now().Unix())
t.Logf("Preparing database %s", dbName)
_, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
require.NoError(t, err)

View File

@ -26,7 +26,7 @@ import (
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
type importer struct {
type Importer struct {
logger log.Logger
dotcom *dotcomdb.Reader
@ -36,14 +36,29 @@ type importer struct {
codyGatewayAccess *codyaccess.CodyGatewayStore
}
var _ goroutine.Handler = (*importer)(nil)
func NewHandler(
ctx context.Context,
logger log.Logger,
dotcom *dotcomdb.Reader,
enterprisePortal *database.DB,
) *Importer {
return &Importer{
logger: logger,
dotcom: dotcom,
subscriptions: enterprisePortal.Subscriptions(),
licenses: enterprisePortal.Subscriptions().Licenses(),
codyGatewayAccess: enterprisePortal.CodyAccess().CodyGateway(),
}
}
// New returns a periodic goroutine that runs an importer that reconciles
// subscriptions, licenses, and Cody Gateway access from dotcom into the
// Enterprise Portal database.
var _ goroutine.Handler = (*Importer)(nil)
// NewPeriodicImporter returns a periodic goroutine that runs an importer that
// reconciles subscriptions, licenses, and Cody Gateway access from dotcom into
// the Enterprise Portal database.
//
// If interval is 0, the importer is disabled.
func New(
func NewPeriodicImporter(
ctx context.Context,
logger log.Logger,
dotcom *dotcomdb.Reader,
@ -56,13 +71,7 @@ func New(
}
return goroutine.NewPeriodicGoroutine(
ctx,
&importer{
logger: logger,
dotcom: dotcom,
subscriptions: enterprisePortal.Subscriptions(),
licenses: enterprisePortal.Subscriptions().Licenses(),
codyGatewayAccess: enterprisePortal.CodyAccess().CodyGateway(),
},
NewHandler(ctx, logger, dotcom, enterprisePortal),
goroutine.WithOperation(
observation.NewContext(logger, observation.Tracer(trace.GetTracer())).
Operation(observation.Op{
@ -72,7 +81,7 @@ func New(
goroutine.WithInterval(interval))
}
func (i *importer) Handle(ctx context.Context) error {
func (i *Importer) Handle(ctx context.Context) error {
l := trace.Logger(ctx, i.logger)
dotcomSubscriptions, err := i.dotcom.ListEnterpriseSubscriptions(ctx, dotcomdb.ListEnterpriseSubscriptionsOptions{})
@ -101,7 +110,7 @@ func (i *importer) Handle(ctx context.Context) error {
return nil
}
func (i *importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.SubscriptionAttributes) (err error) {
func (i *Importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.SubscriptionAttributes) (err error) {
tr, ctx := trace.New(ctx, "importSubscription",
attribute.String("dotcomSub.ID", dotcomSub.ID))
defer tr.EndWithErr(&err)
@ -227,7 +236,7 @@ func (i *importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.S
return nil
}
func (i *importer) importLicense(ctx context.Context, subscriptionID string, dotcomLicense *dotcomdb.LicenseAttributes) (err error) {
func (i *Importer) importLicense(ctx context.Context, subscriptionID string, dotcomLicense *dotcomdb.LicenseAttributes) (err error) {
tr, ctx := trace.New(ctx, "importSubscription",
attribute.String("dotcomSub.ID", subscriptionID),
attribute.String("dotcomLicense.ID", dotcomLicense.ID))

View File

@ -28,6 +28,10 @@ go_test(
],
deps = [
":dotcomdb",
"//cmd/enterprise-portal/internal/codyaccessservice",
"//cmd/enterprise-portal/internal/database",
"//cmd/enterprise-portal/internal/database/databasetest",
"//cmd/enterprise-portal/internal/database/importer",
"//cmd/frontend/dotcomproductsubscriptiontest",
"//cmd/frontend/graphqlbackend",
"//internal/database",
@ -35,11 +39,15 @@ go_test(
"//internal/database/dbtest",
"//internal/license",
"//internal/licensing",
"//lib/enterpriseportal/codyaccess/v1:codyaccess",
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
"//lib/pointers",
"@com_connectrpc_connect//:connect",
"@com_github_jackc_pgx_v4//stdlib",
"@com_github_jackc_pgx_v5//pgxpool",
"@com_github_sourcegraph_log//logtest",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//:sourcegraph-accounts-sdk-go",
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//scopes",
"@com_github_stretchr_testify//assert",
"@com_github_stretchr_testify//require",
],

View File

@ -125,10 +125,10 @@ func (c CodyGatewayAccessAttributes) EvaluateRateLimits() CodyGatewayRateLimits
Chat: licensing.NewCodyGatewayChatRateLimit(p, c.ActiveLicenseUserCount),
CodeSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN,
Code: licensing.NewCodyGatewayCodeRateLimit(p, c.ActiveLicenseUserCount, c.ActiveLicenseTags),
Code: licensing.NewCodyGatewayCodeRateLimit(p, c.ActiveLicenseUserCount),
EmbeddingsSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN,
Embeddings: licensing.NewCodyGatewayEmbeddingsRateLimit(p, c.ActiveLicenseUserCount, c.ActiveLicenseTags),
Embeddings: licensing.NewCodyGatewayEmbeddingsRateLimit(p, c.ActiveLicenseUserCount),
}
// Chat

View File

@ -6,6 +6,7 @@ import (
"testing"
"time"
"connectrpc.com/connect"
pgxstdlibv4 "github.com/jackc/pgx/v4/stdlib"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/assert"
@ -13,19 +14,28 @@ import (
"github.com/sourcegraph/log/logtest"
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"
sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/codyaccessservice"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/importer"
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
"github.com/sourcegraph/sourcegraph/cmd/frontend/dotcomproductsubscriptiontest"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/internal/database"
sgdatabase "github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/database/dbconn"
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
"github.com/sourcegraph/sourcegraph/internal/license"
"github.com/sourcegraph/sourcegraph/internal/licensing"
v1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
codyaccessv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/codyaccess/v1"
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
func newTestDotcomReader(t *testing.T, opts dotcomdb.ReaderOptions) (database.DB, *dotcomdb.Reader) {
func newTestDotcomReader(t *testing.T, opts dotcomdb.ReaderOptions) (sgdatabase.DB, *dotcomdb.Reader) {
ctx := context.Background()
// Set up a Sourcegraph test database.
@ -65,7 +75,23 @@ func newTestDotcomReader(t *testing.T, opts dotcomdb.ReaderOptions) (database.DB
r := dotcomdb.NewReader(conn, opts)
require.NoError(t, r.Ping(ctx))
return database.NewDB(logtest.Scoped(t), sgtestdb), r
return sgdatabase.NewDB(logtest.Scoped(t), sgtestdb), r
}
type mockCodyAccessV1Store struct{ codyaccessservice.StoreV1 }
func (mockCodyAccessV1Store) IntrospectSAMSToken(context.Context, string) (*sams.IntrospectTokenResponse, error) {
return &sams.IntrospectTokenResponse{
Active: true,
ExpiresAt: time.Now().Add(24 * time.Hour),
Scopes: scopes.Scopes{scopes.ToScope(scopes.ServiceEnterprisePortal, scopes.PermissionEnterprisePortalCodyAccess, scopes.ActionRead)},
}, nil
}
func mockAuthenticatedServiceRequest[T any](message *T) *connect.Request[T] {
req := connect.NewRequest(message)
req.Header().Set("authorization", "bearer foobar")
return req
}
type mockedData struct {
@ -76,7 +102,7 @@ type mockedData struct {
archivedSubscriptions int
}
func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info license.Info, cgAccess *graphqlbackend.UpdateCodyGatewayAccessInput) mockedData {
func setupDBAndInsertMockLicense(t *testing.T, dotcomdb sgdatabase.DB, info license.Info, cgAccess *graphqlbackend.UpdateCodyGatewayAccessInput) mockedData {
start := time.Now()
ctx := context.Background()
@ -87,7 +113,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
{
// Create a different subscription and license that's rubbish,
// created at the same time, to ensure we don't use it
u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "barbaz"})
u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "barbaz"})
require.NoError(t, err)
sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username)
require.NoError(t, err)
@ -104,7 +130,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
{
// Create a different subscription and license that's archived,
// created at the same time, to ensure we don't use it
u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "archived"})
u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "archived"})
require.NoError(t, err)
sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username)
require.NoError(t, err)
@ -124,7 +150,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
{
// Create a different subscription and license that's not a dev tag,
// created at the same time, to ensure we don't use it
u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "not-dev"})
u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "not-dev"})
require.NoError(t, err)
sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username)
require.NoError(t, err)
@ -138,7 +164,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
}
// Create the subscription we will assert against
u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "user"})
u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "user"})
require.NoError(t, err)
subid, err := subscriptionsdb.Create(ctx, u.ID, u.Username)
require.NoError(t, err)
@ -150,7 +176,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
result.accessTokens = append(result.accessTokens, license.GenerateLicenseKeyBasedAccessToken(key1))
_, err = licensesdb.Create(ctx, subid, key1, 2, license.Info{
CreatedAt: info.CreatedAt.Add(-time.Hour),
ExpiresAt: info.ExpiresAt.Add(-time.Hour),
ExpiresAt: info.ExpiresAt.Add(-time.Minute), // should expire first, but not be expired
Tags: []string{licensing.DevTag},
})
require.NoError(t, err)
@ -170,7 +196,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
{
// Create another different subscription and license that's also rubbish,
// created at the same time, to ensure we don't use it
u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "foobar"})
u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "foobar"})
require.NoError(t, err)
sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username)
require.NoError(t, err)
@ -198,7 +224,7 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) {
t.Parallel()
ctx := context.Background()
for _, tc := range []struct {
for i, tc := range []struct {
name string
info license.Info
cgAccess graphqlbackend.UpdateCodyGatewayAccessInput
@ -244,16 +270,37 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) {
dotcomdb, dotcomreader := newTestDotcomReader(t, dotcomdb.ReaderOptions{
DevOnly: true,
})
// First, set up a subscription and license and some other rubbish
// data to ensure we only get the license we want.
mock := setupDBAndInsertMockLicense(t, dotcomdb, tc.info, &tc.cgAccess)
// Now import the data for parity so we can compare against the
// Enterprise Portal implementation
epDB := &database.DB{
DB: databasetest.NewTestDB(t, "ep-dotcomdb", fmt.Sprintf("get-attributes-%d", i), databasetest.Tables(t)...),
}
err := importer.NewHandler(ctx, logtest.Scoped(t), dotcomreader, epDB).Handle(ctx)
require.NoError(t, err)
codyAccessService := codyaccessservice.NewHandlerV1(logtest.Scoped(t), mockCodyAccessV1Store{
StoreV1: codyaccessservice.NewStoreV1(codyaccessservice.StoreV1Options{
DB: epDB,
}),
})
t.Run("by subscription ID", func(t *testing.T) {
t.Parallel()
attr, err := dotcomreader.GetCodyGatewayAccessAttributesBySubscription(ctx, mock.targetSubscriptionID)
require.NoError(t, err)
validateAccessAttributes(t, dotcomdb, mock, attr, tc.info)
access, err := codyAccessService.GetCodyGatewayAccess(ctx, mockAuthenticatedServiceRequest(&codyaccessv1.GetCodyGatewayAccessRequest{
Query: &codyaccessv1.GetCodyGatewayAccessRequest_SubscriptionId{
SubscriptionId: mock.targetSubscriptionID,
},
}))
require.NoError(t, err)
validateAccessAttributes(t, dotcomdb, mock, attr, access.Msg.GetAccess(), tc.info)
})
t.Run("by access token", func(t *testing.T) {
@ -266,7 +313,13 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) {
attr, err := dotcomreader.GetCodyGatewayAccessAttributesByAccessToken(ctx, token)
require.NoError(t, err)
validateAccessAttributes(t, dotcomdb, mock, attr, tc.info)
access, err := codyAccessService.GetCodyGatewayAccess(ctx, mockAuthenticatedServiceRequest(&codyaccessv1.GetCodyGatewayAccessRequest{
Query: &codyaccessv1.GetCodyGatewayAccessRequest_AccessToken{
AccessToken: token,
},
}))
require.NoError(t, err)
validateAccessAttributes(t, dotcomdb, mock, attr, access.Msg.GetAccess(), tc.info)
t.Run("compare with dotcom tokens DB", func(t *testing.T) {
subID, err := dotcomproductsubscriptiontest.NewTokensDB(t, dotcomdb).
@ -281,11 +334,20 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) {
}
}
func validateAccessAttributes(t *testing.T, dotcomdb database.DB, mock mockedData, attr *dotcomdb.CodyGatewayAccessAttributes, info license.Info) {
func validateAccessAttributes(t *testing.T, dotcomdb sgdatabase.DB, mock mockedData, attr *dotcomdb.CodyGatewayAccessAttributes, access *codyaccessv1.CodyGatewayAccess, info license.Info) {
assert.Equal(t, mock.targetSubscriptionID, attr.SubscriptionID)
assert.Equal(t, subscriptionsv1.EnterpriseSubscriptionIDPrefix+mock.targetSubscriptionID, access.SubscriptionId)
assert.Equal(t, int(info.UserCount), *attr.ActiveLicenseUserCount)
assert.Len(t, attr.LicenseKeyHashes, 2)
assert.Equal(t, attr.GenerateAccessTokens(), mock.accessTokens)
assert.Equal(t, mock.accessTokens, attr.GenerateAccessTokens())
var protoAccessTokens []string
for _, t := range access.AccessTokens {
protoAccessTokens = append(protoAccessTokens, t.GetToken())
}
assert.ElementsMatch(t, mock.accessTokens, protoAccessTokens)
limits := attr.EvaluateRateLimits()
// Validate against the expected values as produced by existing resolvers
@ -300,30 +362,37 @@ func validateAccessAttributes(t *testing.T, dotcomdb database.DB, mock mockedDat
name string
expected graphqlbackend.CodyGatewayRateLimit
got licensing.CodyGatewayRateLimit
gotProto *codyaccessv1.CodyGatewayRateLimit
}{{
name: "Chat",
expected: mustWithCtx(t, expected.ChatCompletionsRateLimit),
got: limits.Chat,
gotProto: access.ChatCompletionsRateLimit,
}, {
name: "Code",
expected: mustWithCtx(t, expected.CodeCompletionsRateLimit),
got: limits.Code,
gotProto: access.CodeCompletionsRateLimit,
}, {
name: "Embeddings",
expected: mustWithCtx(t, expected.EmbeddingsRateLimit),
got: limits.Embeddings,
gotProto: access.EmbeddingsRateLimit,
}} {
t.Run(compare.name, func(t *testing.T) {
// We only care about limit and interval now
assert.Equal(t, int64(compare.expected.Limit()), compare.got.Limit, "Limit")
assert.Equal(t, compare.expected.IntervalSeconds(), compare.got.IntervalSeconds, "IntervalSeconds")
assert.Equal(t, uint64(compare.expected.Limit()), compare.gotProto.Limit, "Limit")
assert.Equal(t, int64(compare.expected.IntervalSeconds()), compare.gotProto.IntervalDuration.Seconds, "IntervalSeconds")
})
}
}
func TestGetAllCodyGatewayAccessAttributes(t *testing.T) {
t.Parallel()
dotcomdb, dotcomreader := newTestDotcomReader(t, dotcomdb.ReaderOptions{
dotcomDB, dotcomreader := newTestDotcomReader(t, dotcomdb.ReaderOptions{
DevOnly: true,
})
@ -334,21 +403,53 @@ func TestGetAllCodyGatewayAccessAttributes(t *testing.T) {
Tags: []string{licensing.PlanEnterprise1.Tag(), licensing.DevTag},
}
cgAccess := graphqlbackend.UpdateCodyGatewayAccessInput{Enabled: pointers.Ptr(true)}
mock := setupDBAndInsertMockLicense(t, dotcomdb, info, &cgAccess)
mock := setupDBAndInsertMockLicense(t, dotcomDB, info, &cgAccess)
attrs, err := dotcomreader.GetAllCodyGatewayAccessAttributes(context.Background())
// Now import the data for parity so we can compare against the
// Enterprise Portal implementation
ctx := context.Background()
epDB := &database.DB{
DB: databasetest.NewTestDB(t, "ep-dotcomdb", "get-all-attributes", databasetest.Tables(t)...),
}
err := importer.NewHandler(ctx, logtest.Scoped(t), dotcomreader, epDB).Handle(ctx)
require.NoError(t, err)
codyAccessService := codyaccessservice.NewHandlerV1(logtest.Scoped(t), mockCodyAccessV1Store{
StoreV1: codyaccessservice.NewStoreV1(codyaccessservice.StoreV1Options{
DB: epDB,
}),
})
attrs, err := dotcomreader.GetAllCodyGatewayAccessAttributes(ctx)
require.NoError(t, err)
assert.Len(t, attrs, 3) // 3 subscriptions created in setupDBAndInsertMockLicense
var found bool
accesses, err := codyAccessService.ListCodyGatewayAccesses(ctx, mockAuthenticatedServiceRequest(&codyaccessv1.ListCodyGatewayAccessesRequest{}))
require.NoError(t, err)
assert.Len(t, accesses.Msg.GetAccesses(), 3) // 3 subscriptions created in setupDBAndInsertMockLicense
var (
foundAttr *dotcomdb.CodyGatewayAccessAttributes
foundAccess *codyaccessv1.CodyGatewayAccess
)
for _, attr := range attrs {
if attr.SubscriptionID == mock.targetSubscriptionID {
found = true
validateAccessAttributes(t, dotcomdb, mock, attr, info)
foundAttr = attr
} else {
assert.False(t, attr.CodyGatewayEnabled)
}
}
assert.True(t, found)
for _, access := range accesses.Msg.GetAccesses() {
if access.SubscriptionId == subscriptionsv1.EnterpriseSubscriptionIDPrefix+mock.targetSubscriptionID {
foundAccess = access
} else {
assert.False(t, access.Enabled)
}
}
require.NotNil(t, foundAttr)
require.NotNil(t, foundAccess)
validateAccessAttributes(t, dotcomDB, mock, foundAttr, foundAccess, info)
}
func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
@ -378,7 +479,7 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
ctx := context.Background()
for _, tc := range []struct {
name string
filters []*v1.ListEnterpriseSubscriptionLicensesFilter
filters []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter
pageSize int
expect func(t *testing.T, licenses []*dotcomdb.LicenseAttributes)
}{{
@ -390,8 +491,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
},
}, {
name: "filter by subscription ID",
filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
SubscriptionId: mock.targetSubscriptionID,
},
}},
@ -403,8 +504,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
},
}, {
name: "filter by subscription ID and limit 1",
filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
SubscriptionId: mock.targetSubscriptionID,
},
}},
@ -415,12 +516,12 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
},
}, {
name: "filter by subscription ID and not archived",
filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
SubscriptionId: mock.targetSubscriptionID,
},
}, {
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
IsRevoked: false,
},
}},
@ -434,8 +535,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
},
}, {
name: "filter by is archived",
filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
IsRevoked: true,
},
}},
@ -444,8 +545,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
},
}, {
name: "filter by not archived",
filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
IsRevoked: false,
},
}},

View File

@ -91,7 +91,7 @@ func (r codyGatewayAccessResolver) CodeCompletionsRateLimit(ctx context.Context)
var source graphqlbackend.CodyGatewayRateLimitSource
if activeLicense != nil {
source = graphqlbackend.CodyGatewayRateLimitSourcePlan
rateLimit = licensing.NewCodyGatewayCodeRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount, activeLicense.LicenseTags)
rateLimit = licensing.NewCodyGatewayCodeRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount)
}
// Apply overrides
@ -131,7 +131,7 @@ func (r codyGatewayAccessResolver) EmbeddingsRateLimit(ctx context.Context) (gra
var source graphqlbackend.CodyGatewayRateLimitSource
if activeLicense != nil {
source = graphqlbackend.CodyGatewayRateLimitSourcePlan
rateLimit = licensing.NewCodyGatewayEmbeddingsRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount, activeLicense.LicenseTags)
rateLimit = licensing.NewCodyGatewayEmbeddingsRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount)
}
// Apply overrides

View File

@ -53,7 +53,7 @@ func NewCodyGatewayChatRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit
}
// NewCodyGatewayCodeRateLimit applies default Cody Gateway access based on the plan.
func NewCodyGatewayCodeRateLimit(plan Plan, userCount *int, licenseTags []string) CodyGatewayRateLimit {
func NewCodyGatewayCodeRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit {
uc := 0
if userCount != nil {
uc = *userCount
@ -87,7 +87,7 @@ func NewCodyGatewayCodeRateLimit(plan Plan, userCount *int, licenseTags []string
const tokensPerDollar = int(1 / (0.0001 / 1_000))
// NewCodyGatewayEmbeddingsRateLimit applies default Cody Gateway access based on the plan.
func NewCodyGatewayEmbeddingsRateLimit(plan Plan, userCount *int, licenseTags []string) CodyGatewayRateLimit {
func NewCodyGatewayEmbeddingsRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit {
uc := 0
if userCount != nil {
uc = *userCount

View File

@ -104,7 +104,7 @@ func TestCodyGatewayCodeRateLimit(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewCodyGatewayCodeRateLimit(tt.plan, tt.userCount, tt.licenseTags)
got := NewCodyGatewayCodeRateLimit(tt.plan, tt.userCount)
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Fatalf("incorrect rate limit computed: %s", diff)
}
@ -151,7 +151,7 @@ func TestCodyGatewayEmbeddingsRateLimit(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := NewCodyGatewayEmbeddingsRateLimit(tt.plan, tt.userCount, tt.licenseTags)
got := NewCodyGatewayEmbeddingsRateLimit(tt.plan, tt.userCount)
if diff := cmp.Diff(got, tt.want); diff != "" {
t.Fatalf("incorrect rate limit computed: %s", diff)
}