From 4064629c48cd053d910a673802d7beac537adcab Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Thu, 18 Jul 2024 12:01:29 -0700 Subject: [PATCH] feat/enterpriseportal: use database for reading Cody Gateway access --- .../internal/codyaccessservice/BUILD.bazel | 8 +- .../internal/codyaccessservice/adapters.go | 99 +++++++++- .../codyaccessservice/adapters_test.go | 28 +-- .../internal/codyaccessservice/v1.go | 43 +++-- .../internal/codyaccessservice/v1_store.go | 52 ++++++ .../database/codyaccess/codygateway.go | 55 ++++-- .../database/codyaccess/codygateway_test.go | 33 +++- .../internal/database/database.go | 10 +- .../database/databasetest/BUILD.bazel | 1 + .../database/databasetest/databasetest.go | 4 + .../internal/database/importer/importer.go | 41 +++-- .../internal/dotcomdb/BUILD.bazel | 8 + .../internal/dotcomdb/dotcomdb.go | 4 +- .../internal/dotcomdb/dotcomdb_test.go | 171 ++++++++++++++---- .../codygateway_graphql.go | 4 +- internal/licensing/codygateway.go | 4 +- internal/licensing/codygateway_test.go | 4 +- 17 files changed, 444 insertions(+), 125 deletions(-) diff --git a/cmd/enterprise-portal/internal/codyaccessservice/BUILD.bazel b/cmd/enterprise-portal/internal/codyaccessservice/BUILD.bazel index c24456fb45b..de5d5cd3d65 100644 --- a/cmd/enterprise-portal/internal/codyaccessservice/BUILD.bazel +++ b/cmd/enterprise-portal/internal/codyaccessservice/BUILD.bazel @@ -13,16 +13,22 @@ go_library( visibility = ["//cmd/enterprise-portal:__subpackages__"], deps = [ "//cmd/enterprise-portal/internal/connectutil", + "//cmd/enterprise-portal/internal/database", + "//cmd/enterprise-portal/internal/database/codyaccess", "//cmd/enterprise-portal/internal/dotcomdb", "//cmd/enterprise-portal/internal/samsm2m", "//internal/codygateway/codygatewayactor", "//internal/codygateway/codygatewayevents", "//internal/completions/types", + "//internal/license", + "//internal/licensing", + "//internal/productsubscription", "//internal/trace", "//lib/enterpriseportal/codyaccess/v1:codyaccess", "//lib/enterpriseportal/codyaccess/v1/v1connect", "//lib/enterpriseportal/subscriptions/v1:subscriptions", "//lib/errors", + "//lib/pointers", "@com_connectrpc_connect//:connect", "@com_github_sourcegraph_conc//pool", "@com_github_sourcegraph_log//:log", @@ -38,7 +44,7 @@ go_test( srcs = ["adapters_test.go"], embed = [":codyaccessservice"], deps = [ - "//cmd/enterprise-portal/internal/dotcomdb", + "//cmd/enterprise-portal/internal/database/codyaccess", "@com_github_hexops_autogold_v2//:autogold", "@com_github_stretchr_testify//assert", ], diff --git a/cmd/enterprise-portal/internal/codyaccessservice/adapters.go b/cmd/enterprise-portal/internal/codyaccessservice/adapters.go index b82e1e4c932..369b6c27e1f 100644 --- a/cmd/enterprise-portal/internal/codyaccessservice/adapters.go +++ b/cmd/enterprise-portal/internal/codyaccessservice/adapters.go @@ -1,46 +1,51 @@ package codyaccessservice import ( + "encoding/hex" + "google.golang.org/protobuf/types/known/durationpb" "google.golang.org/protobuf/types/known/timestamppb" - "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess" "github.com/sourcegraph/sourcegraph/internal/codygateway/codygatewayevents" + "github.com/sourcegraph/sourcegraph/internal/license" + "github.com/sourcegraph/sourcegraph/internal/licensing" codyaccessv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/codyaccess/v1" subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1" + "github.com/sourcegraph/sourcegraph/lib/pointers" ) -func convertAccessAttrsToProto(attrs *dotcomdb.CodyGatewayAccessAttributes) *codyaccessv1.CodyGatewayAccess { +func convertAccessAttrsToProto(access *codyaccess.CodyGatewayAccessWithSubscriptionDetails) *codyaccessv1.CodyGatewayAccess { // Provide ID in prefixed format. - subscriptionID := subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.SubscriptionID + subscriptionID := subscriptionsv1.EnterpriseSubscriptionIDPrefix + access.SubscriptionID // Always try to return the full response, since even when disabled, some // features may be allowed via Cody Gateway (notably attributions). This // also allows Cody Gateway to cache the state of actors that are disabled. - limits := attrs.EvaluateRateLimits() + limits := evaluateCodyGatewayAccessRateLimits(access) return &codyaccessv1.CodyGatewayAccess{ SubscriptionId: subscriptionID, - SubscriptionDisplayName: attrs.GetSubscriptionDisplayName(), - Enabled: attrs.CodyGatewayEnabled, + SubscriptionDisplayName: access.DisplayName, + Enabled: access.Enabled, // Rate limits return nil if not enabled, per API spec - ChatCompletionsRateLimit: nilIfNotEnabled(attrs.CodyGatewayEnabled, &codyaccessv1.CodyGatewayRateLimit{ + ChatCompletionsRateLimit: nilIfNotEnabled(access.Enabled, &codyaccessv1.CodyGatewayRateLimit{ Source: limits.ChatSource, Limit: uint64(limits.Chat.Limit), IntervalDuration: durationpb.New(limits.Chat.IntervalDuration()), }), - CodeCompletionsRateLimit: nilIfNotEnabled(attrs.CodyGatewayEnabled, &codyaccessv1.CodyGatewayRateLimit{ + CodeCompletionsRateLimit: nilIfNotEnabled(access.Enabled, &codyaccessv1.CodyGatewayRateLimit{ Source: limits.CodeSource, Limit: uint64(limits.Code.Limit), IntervalDuration: durationpb.New(limits.Code.IntervalDuration()), }), - EmbeddingsRateLimit: nilIfNotEnabled(attrs.CodyGatewayEnabled, &codyaccessv1.CodyGatewayRateLimit{ + EmbeddingsRateLimit: nilIfNotEnabled(access.Enabled, &codyaccessv1.CodyGatewayRateLimit{ Source: limits.EmbeddingsSource, Limit: uint64(limits.Embeddings.Limit), IntervalDuration: durationpb.New(limits.Embeddings.IntervalDuration()), }), // This is always provided, even if access is disabled AccessTokens: func() []*codyaccessv1.CodyGatewayAccessToken { - accessTokens := attrs.GenerateAccessTokens() + accessTokens := generateCodyGatewayAccessTokens(access) if len(accessTokens) == 0 { return []*codyaccessv1.CodyGatewayAccessToken{} } @@ -56,6 +61,18 @@ func convertAccessAttrsToProto(attrs *dotcomdb.CodyGatewayAccessAttributes) *cod } } +func generateCodyGatewayAccessTokens(access *codyaccess.CodyGatewayAccessWithSubscriptionDetails) []string { + accessTokens := make([]string, 0, len(access.LicenseKeyHashes)) + for _, t := range access.LicenseKeyHashes { + if len(t) == 0 { // query can return empty hashes, ignore these + continue + } + // See license.GenerateLicenseKeyBasedAccessToken + accessTokens = append(accessTokens, license.LicenseKeyBasedAccessTokenPrefix+hex.EncodeToString(t)) + } + return accessTokens +} + func nilIfNotEnabled[T any](enabled bool, value *T) *T { if !enabled { return nil @@ -63,6 +80,68 @@ func nilIfNotEnabled[T any](enabled bool, value *T) *T { return value } +type CodyGatewayRateLimits struct { + ChatSource codyaccessv1.CodyGatewayRateLimitSource + Chat licensing.CodyGatewayRateLimit + + CodeSource codyaccessv1.CodyGatewayRateLimitSource + Code licensing.CodyGatewayRateLimit + + EmbeddingsSource codyaccessv1.CodyGatewayRateLimitSource + Embeddings licensing.CodyGatewayRateLimit +} + +func maybeApplyOverride[T ~int32 | ~int64](limit *T, overrideValue T, overrideValid bool) codyaccessv1.CodyGatewayRateLimitSource { + if overrideValid { + *limit = overrideValue + return codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_OVERRIDE + } + // No override + return codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN +} + +// evaluateCodyGatewayAccessRateLimits returns the current CodyGatewayRateLimits based on the +// plan and applying known overrides on top. This closely models the existing +// codyGatewayAccessResolver in 'cmd/frontend/internal/dotcom/productsubscription'. +func evaluateCodyGatewayAccessRateLimits(access *codyaccess.CodyGatewayAccessWithSubscriptionDetails) CodyGatewayRateLimits { + // Set defaults for everything based on active licnese plan and user count. + // If there isn't one, zero values apply. + activeLicense := pointers.DerefZero(access.ActiveLicenseInfo) + p := licensing.PlanFromTags(activeLicense.Tags) + userCount := pointers.Ptr(int(activeLicense.UserCount)) + + limits := CodyGatewayRateLimits{ + ChatSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN, + Chat: licensing.NewCodyGatewayChatRateLimit(p, userCount), + + CodeSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN, + Code: licensing.NewCodyGatewayCodeRateLimit(p, userCount), + + EmbeddingsSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN, + Embeddings: licensing.NewCodyGatewayEmbeddingsRateLimit(p, userCount), + } + + // Chat + limits.ChatSource = maybeApplyOverride(&limits.Chat.Limit, + access.ChatCompletionsRateLimit.Int64, access.ChatCompletionsRateLimit.Valid) + limits.ChatSource = maybeApplyOverride(&limits.Chat.IntervalSeconds, + access.ChatCompletionsRateLimitIntervalSeconds.Int32, access.ChatCompletionsRateLimitIntervalSeconds.Valid) + + // Code + limits.CodeSource = maybeApplyOverride(&limits.Code.Limit, + access.CodeCompletionsRateLimit.Int64, access.CodeCompletionsRateLimit.Valid) + limits.CodeSource = maybeApplyOverride(&limits.Code.IntervalSeconds, + access.CodeCompletionsRateLimitIntervalSeconds.Int32, access.CodeCompletionsRateLimitIntervalSeconds.Valid) + + // Embeddings + limits.EmbeddingsSource = maybeApplyOverride(&limits.Embeddings.Limit, + access.EmbeddingsRateLimit.Int64, access.EmbeddingsRateLimit.Valid) + limits.EmbeddingsSource = maybeApplyOverride(&limits.Embeddings.IntervalSeconds, + access.EmbeddingsRateLimitIntervalSeconds.Int32, access.EmbeddingsRateLimitIntervalSeconds.Valid) + + return limits +} + func convertCodyGatewayUsageDatapoints(usage []codygatewayevents.SubscriptionUsage) []*codyaccessv1.CodyGatewayUsage_UsageDatapoint { results := make([]*codyaccessv1.CodyGatewayUsage_UsageDatapoint, len(usage)) for i, datapoint := range usage { diff --git a/cmd/enterprise-portal/internal/codyaccessservice/adapters_test.go b/cmd/enterprise-portal/internal/codyaccessservice/adapters_test.go index 175974e40e9..592a8b789c7 100644 --- a/cmd/enterprise-portal/internal/codyaccessservice/adapters_test.go +++ b/cmd/enterprise-portal/internal/codyaccessservice/adapters_test.go @@ -7,19 +7,21 @@ import ( "github.com/hexops/autogold/v2" "github.com/stretchr/testify/assert" - "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess" ) func TestConvertAccessAttrsToProto(t *testing.T) { t.Run("zero value", func(t *testing.T) { - proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{}) + proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{}) assert.False(t, proto.Enabled) }) t.Run("disabled returns access tokens", func(t *testing.T) { - proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{ - CodyGatewayEnabled: false, - LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")}, + proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{ + CodyGatewayAccess: codyaccess.CodyGatewayAccess{ + Enabled: false, + }, + LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")}, }) assert.False(t, proto.Enabled) // NOTE: These are not real access tokens @@ -31,18 +33,22 @@ func TestConvertAccessAttrsToProto(t *testing.T) { }) t.Run("enabled with empty access token", func(t *testing.T) { - proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{ - CodyGatewayEnabled: true, - LicenseKeyHashes: [][]byte{[]byte(""), nil}, + proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{ + CodyGatewayAccess: codyaccess.CodyGatewayAccess{ + Enabled: true, + }, + LicenseKeyHashes: [][]byte{[]byte(""), nil}, }) assert.True(t, proto.Enabled) assert.Empty(t, proto.GetAccessTokens()) }) t.Run("enabled returns everything", func(t *testing.T) { - proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{ - CodyGatewayEnabled: true, - LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")}, + proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{ + CodyGatewayAccess: codyaccess.CodyGatewayAccess{ + Enabled: true, + }, + LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")}, }) assert.True(t, proto.Enabled) // NOTE: These are not real access tokens diff --git a/cmd/enterprise-portal/internal/codyaccessservice/v1.go b/cmd/enterprise-portal/internal/codyaccessservice/v1.go index de206231967..6311067edfa 100644 --- a/cmd/enterprise-portal/internal/codyaccessservice/v1.go +++ b/cmd/enterprise-portal/internal/codyaccessservice/v1.go @@ -12,6 +12,7 @@ import ( "github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/connectutil" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/samsm2m" "github.com/sourcegraph/sourcegraph/internal/trace" @@ -23,42 +24,40 @@ import ( const Name = codyaccessv1connect.CodyAccessServiceName -type DotComDB interface { - GetCodyGatewayAccessAttributesBySubscription(ctx context.Context, subscriptionID string) (*dotcomdb.CodyGatewayAccessAttributes, error) - GetCodyGatewayAccessAttributesByAccessToken(ctx context.Context, subscriptionID string) (*dotcomdb.CodyGatewayAccessAttributes, error) - GetAllCodyGatewayAccessAttributes(ctx context.Context) ([]*dotcomdb.CodyGatewayAccessAttributes, error) -} - func RegisterV1( logger log.Logger, mux *http.ServeMux, store StoreV1, - dotcom DotComDB, opts ...connect.HandlerOption, ) { mux.Handle( codyaccessv1connect.NewCodyAccessServiceHandler( - &handlerV1{ - logger: logger.Scoped("codyaccess.v1"), - store: store, - dotcom: dotcom, - }, + NewHandlerV1(logger, store), opts..., ), ) } -type handlerV1 struct { +type HandlerV1 struct { codyaccessv1connect.UnimplementedCodyAccessServiceHandler logger log.Logger store StoreV1 - dotcom DotComDB } -var _ codyaccessv1connect.CodyAccessServiceHandler = (*handlerV1)(nil) +func NewHandlerV1( + logger log.Logger, + store StoreV1, +) *HandlerV1 { + return &HandlerV1{ + logger: logger.Scoped("codyaccess.v1"), + store: store, + } +} -func (s *handlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayAccessRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayAccessResponse], error) { +var _ codyaccessv1connect.CodyAccessServiceHandler = (*HandlerV1)(nil) + +func (s *HandlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayAccessRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayAccessResponse], error) { logger := trace.Logger(ctx, s.logger). With(log.String("queryType", fmt.Sprintf("%T", req.Msg.GetQuery()))) @@ -70,19 +69,19 @@ func (s *handlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Reque } logger = logger.With(clientAttrs...) - var attr *dotcomdb.CodyGatewayAccessAttributes + var attr *codyaccess.CodyGatewayAccessWithSubscriptionDetails switch query := req.Msg.GetQuery().(type) { case *codyaccessv1.GetCodyGatewayAccessRequest_SubscriptionId: if len(query.SubscriptionId) == 0 { return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("invalid query: subscription ID")) } - attr, err = s.dotcom.GetCodyGatewayAccessAttributesBySubscription(ctx, query.SubscriptionId) + attr, err = s.store.GetCodyGatewayAccessBySubscription(ctx, query.SubscriptionId) case *codyaccessv1.GetCodyGatewayAccessRequest_AccessToken: if len(query.AccessToken) == 0 { return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("invalid query: access token")) } - attr, err = s.dotcom.GetCodyGatewayAccessAttributesByAccessToken(ctx, query.AccessToken) + attr, err = s.store.GetCodyGatewayAccessByAccessToken(ctx, query.AccessToken) default: return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("invalid query")) @@ -103,7 +102,7 @@ func (s *handlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Reque }), nil } -func (s *handlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Request[codyaccessv1.ListCodyGatewayAccessesRequest]) (*connect.Response[codyaccessv1.ListCodyGatewayAccessesResponse], error) { +func (s *HandlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Request[codyaccessv1.ListCodyGatewayAccessesRequest]) (*connect.Response[codyaccessv1.ListCodyGatewayAccessesResponse], error) { logger := trace.Logger(ctx, s.logger) // 🚨 SECURITY: Require approrpiate M2M scope. @@ -122,7 +121,7 @@ func (s *handlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Re return nil, connect.NewError(connect.CodeUnimplemented, errors.New("pagination not implemented")) } - attrs, err := s.dotcom.GetAllCodyGatewayAccessAttributes(ctx) + attrs, err := s.store.ListCodyGatewayAccesses(ctx) if err != nil { if err == dotcomdb.ErrCodyGatewayAccessNotFound { return nil, connect.NewError(connect.CodeNotFound, err) @@ -146,7 +145,7 @@ func (s *handlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Re return connect.NewResponse(&resp), nil } -func (s *handlerV1) GetCodyGatewayUsage(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayUsageRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayUsageResponse], error) { +func (s *HandlerV1) GetCodyGatewayUsage(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayUsageRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayUsageResponse], error) { logger := trace.Logger(ctx, s.logger) // 🚨 SECURITY: Require appropriate M2M scope. diff --git a/cmd/enterprise-portal/internal/codyaccessservice/v1_store.go b/cmd/enterprise-portal/internal/codyaccessservice/v1_store.go index 674fcb2d4dc..ba9bc998721 100644 --- a/cmd/enterprise-portal/internal/codyaccessservice/v1_store.go +++ b/cmd/enterprise-portal/internal/codyaccessservice/v1_store.go @@ -2,13 +2,19 @@ package codyaccessservice import ( "context" + "encoding/hex" + "strings" "github.com/sourcegraph/conc/pool" sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess" "github.com/sourcegraph/sourcegraph/internal/codygateway/codygatewayactor" "github.com/sourcegraph/sourcegraph/internal/codygateway/codygatewayevents" "github.com/sourcegraph/sourcegraph/internal/completions/types" + "github.com/sourcegraph/sourcegraph/internal/license" + "github.com/sourcegraph/sourcegraph/internal/productsubscription" codyaccessv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/codyaccess/v1" subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1" "github.com/sourcegraph/sourcegraph/lib/errors" @@ -24,18 +30,33 @@ type StoreV1 interface { // is no longer active. It is critical that the caller not honor tokens where // `.Active == false`. IntrospectSAMSToken(ctx context.Context, token string) (*sams.IntrospectTokenResponse, error) + // GetCodyGatewayUsage retrieves recent Cody Gateway usage data. // The subscriptionID should not be prefixed. GetCodyGatewayUsage(ctx context.Context, subscriptionID string) (*codyaccessv1.CodyGatewayUsage, error) + + // GetCodyGatewayAccessBySubscription retrieves Cody Gateway access by + // subscription ID. + GetCodyGatewayAccessBySubscription(ctx context.Context, subscriptionID string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) + + // GetCodyGatewayAccessByAccessToken retrieves Cody Gateway access details + // associated with the given access token. + GetCodyGatewayAccessByAccessToken(ctx context.Context, token string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) + + // ListCodyGatewayAccesses retrieves all Cody Gateway accesses with their + // associated subscription details. + ListCodyGatewayAccesses(ctx context.Context) ([]*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) } type storeV1 struct { SAMSClient *sams.ClientV1 + CodyAccess *codyaccess.Store CodyGatewayEvents *codygatewayevents.Service } type StoreV1Options struct { SAMSClient *sams.ClientV1 + DB *database.DB // Optional. CodyGatewayEvents *codygatewayevents.Service } @@ -46,6 +67,7 @@ var errStoreUnimplemented = errors.New("unimplemented") func NewStoreV1(opts StoreV1Options) StoreV1 { return &storeV1{ SAMSClient: opts.SAMSClient, + CodyAccess: opts.DB.CodyAccess(), CodyGatewayEvents: opts.CodyGatewayEvents, } } @@ -126,4 +148,34 @@ func (s *storeV1) GetCodyGatewayUsage(ctx context.Context, subscriptionID string applyResult(usage) } return usage, nil + +} + +func (s *storeV1) GetCodyGatewayAccessBySubscription(ctx context.Context, subscriptionID string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) { + return s.CodyAccess.CodyGateway().Get(ctx, codyaccess.GetCodyGatewayAccessOptions{ + SubscriptionID: subscriptionID, + }) +} + +func (s *storeV1) GetCodyGatewayAccessByAccessToken(ctx context.Context, token string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) { + // Below is copied from 'func (t dbTokens) LookupProductSubscriptionIDByAccessToken' + // in 'cmd/frontend/internal/dotcom/productsubscription'. + if !strings.HasPrefix(token, productsubscription.AccessTokenPrefix) && + !strings.HasPrefix(token, license.LicenseKeyBasedAccessTokenPrefix) { + return nil, errors.WithSafeDetails(codyaccess.ErrSubscriptionDoesNotExist, "invalid token with unknown prefix") + } + tokenSansPrefix := token[len(license.LicenseKeyBasedAccessTokenPrefix):] + decoded, err := hex.DecodeString(tokenSansPrefix) + if err != nil { + return nil, errors.WithSafeDetails(codyaccess.ErrSubscriptionDoesNotExist, "invalid token with unknown encoding") + } + // End copied code. + + return s.CodyAccess.CodyGateway().Get(ctx, codyaccess.GetCodyGatewayAccessOptions{ + LicenseKeyHash: decoded, + }) +} + +func (s *storeV1) ListCodyGatewayAccesses(ctx context.Context) ([]*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) { + return s.CodyAccess.CodyGateway().List(ctx) } diff --git a/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go b/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go index b6acf8812f3..76b79f308a6 100644 --- a/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go +++ b/cmd/enterprise-portal/internal/database/codyaccess/codygateway.go @@ -95,7 +95,8 @@ func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetai // an active subscription exists, but explicit access is not configured. In // this case we still need to return a valid CodyGatewayAccessWithSubscriptionDetails, // just with empty fields. - var maybeEnabled *bool + var maybeEnabled sql.NullBool + var maybeDisplayName sql.NullString err := row.Scan( &a.SubscriptionID, &maybeEnabled, @@ -106,7 +107,7 @@ func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetai &a.EmbeddingsRateLimit, &a.EmbeddingsRateLimitIntervalSeconds, // Subscriptions fields - &a.DisplayName, + &maybeDisplayName, // License fields &a.ActiveLicenseInfo, &a.LicenseKeyHashes, @@ -114,9 +115,9 @@ func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetai if err != nil { return nil, err } - if maybeEnabled != nil { - a.Enabled = *maybeEnabled - } + a.Enabled = maybeEnabled.Bool + a.DisplayName = maybeDisplayName.String + return &a, nil } @@ -181,27 +182,55 @@ type CodyGatewayAccessWithSubscriptionDetails struct { DisplayName string ActiveLicenseInfo *license.Info - LicenseKeyHashes [][]byte + + // Used by GenerateAccessTokens + LicenseKeyHashes [][]byte } var ErrSubscriptionDoesNotExist = errors.New("subscription does not exist") +type GetCodyGatewayAccessOptions struct { + SubscriptionID string + LicenseKeyHash []byte +} + +func (opts GetCodyGatewayAccessOptions) buildConds() (string, pgx.NamedArgs, error) { + if opts.SubscriptionID == "" && len(opts.LicenseKeyHash) == 0 { + return "", nil, errors.New("must specify either SubscriptionID or LicenseKeyHash") + } + + args := pgx.NamedArgs{} + conds := []string{"TRUE"} + if opts.SubscriptionID != "" { + conds = append(conds, "subscription.id = @subscriptionID") + args["subscriptionID"] = opts.SubscriptionID + } + if len(opts.LicenseKeyHash) > 0 { + conds = append(conds, "@licenseKeyHash = ANY(tokens.license_key_hashes)") + args["licenseKeyHash"] = opts.LicenseKeyHash + } + return strings.Join(conds, " AND "), args, nil +} + // Get returns the Cody Gateway access for the given subscription. -func (s *CodyGatewayStore) Get(ctx context.Context, subscriptionID string) (*CodyGatewayAccessWithSubscriptionDetails, error) { +func (s *CodyGatewayStore) Get(ctx context.Context, opts GetCodyGatewayAccessOptions) (*CodyGatewayAccessWithSubscriptionDetails, error) { + conds, args, err := opts.buildConds() + if err != nil { + return nil, err + } query := fmt.Sprintf(`SELECT %s FROM enterprise_portal_cody_gateway_access AS access %s WHERE - subscription.id = @subscriptionID + %s AND subscription.archived_at IS NULL`, strings.Join(codyGatewayAccessTableColumns(), ", "), - codyGatewayAccessJoinClauses) + codyGatewayAccessJoinClauses, + conds) - sub, err := scanCodyGatewayAccess(s.db.QueryRow(ctx, query, pgx.NamedArgs{ - "subscriptionID": subscriptionID, - })) + sub, err := scanCodyGatewayAccess(s.db.QueryRow(ctx, query, args)) if err != nil { if errors.Is(err, pgx.ErrNoRows) { // RIGHT JOIN in query ensures that if we find no result, it's @@ -297,5 +326,5 @@ func (s *CodyGatewayStore) Upsert(ctx context.Context, subscriptionID string, op } return nil, err } - return s.Get(ctx, subscriptionID) + return s.Get(ctx, GetCodyGatewayAccessOptions{SubscriptionID: subscriptionID}) } diff --git a/cmd/enterprise-portal/internal/database/codyaccess/codygateway_test.go b/cmd/enterprise-portal/internal/database/codyaccess/codygateway_test.go index b226545d106..964d02b8910 100644 --- a/cmd/enterprise-portal/internal/database/codyaccess/codygateway_test.go +++ b/cmd/enterprise-portal/internal/database/codyaccess/codygateway_test.go @@ -149,7 +149,9 @@ func TestCodyGatewayStore(t *testing.T) { }) require.NoError(t, err) - _, err = codyaccess.NewCodyGatewayStore(db).Get(ctx, subscriptionID) + _, err = codyaccess.NewCodyGatewayStore(db).Get(ctx, codyaccess.GetCodyGatewayAccessOptions{ + SubscriptionID: subscriptionID, + }) assert.ErrorIs(t, err, codyaccess.ErrSubscriptionDoesNotExist) }) }) @@ -214,15 +216,36 @@ func CodyGatewayStoreListAndGet(t *testing.T, ctx context.Context, subscriptionI t.Run("Get", func(t *testing.T) { for idx, sub := range subscriptionIDs { t.Run(fmt.Sprintf("idx=%d", idx), func(t *testing.T) { - got, err := s.Get(ctx, sub) + got, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{ + SubscriptionID: sub, + }) require.NoError(t, err) assertAccess(idx, got) + + // Reverse lookup by license key hash + for _, hash := range got.LicenseKeyHashes { + got2, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{ + LicenseKeyHash: hash, + }) + require.NoError(t, err) + assert.Len(t, got2.LicenseKeyHashes, 2) // 2 valid licenses + assert.Equal(t, got, got2) + } }) } t.Run("ErrSubscriptionDoesNotExist", func(t *testing.T) { - _, err := s.Get(ctx, uuid.NewString()) + _, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{ + SubscriptionID: uuid.NewString(), + }) + assert.Error(t, err) + assert.ErrorIs(t, err, codyaccess.ErrSubscriptionDoesNotExist) + + _, err = s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{ + LicenseKeyHash: []byte(uuid.NewString()), + }) + assert.Error(t, err) assert.ErrorIs(t, err, codyaccess.ErrSubscriptionDoesNotExist) }) }) @@ -241,7 +264,9 @@ func CodyGatewayStoreUpsert(t *testing.T, ctx context.Context, subscriptionIDs [ ) require.NoError(t, err) - got, err := s.Get(ctx, currentAccess.SubscriptionID) + got, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{ + SubscriptionID: currentAccess.SubscriptionID, + }) require.NoError(t, err) assert.False(t, got.Enabled) assert.Equal(t, currentAccess.SubscriptionID, got.SubscriptionID) diff --git a/cmd/enterprise-portal/internal/database/database.go b/cmd/enterprise-portal/internal/database/database.go index 440c4dc12bc..c1b3c855710 100644 --- a/cmd/enterprise-portal/internal/database/database.go +++ b/cmd/enterprise-portal/internal/database/database.go @@ -20,15 +20,15 @@ var databaseTracer = otel.Tracer("enterprise-portal/internal/database") // DB is the database handle for the storage layer. type DB struct { - db *pgxpool.Pool + DB *pgxpool.Pool } func (db *DB) Subscriptions() *subscriptions.Store { - return subscriptions.NewStore(db.db) + return subscriptions.NewStore(db.DB) } func (db *DB) CodyAccess() *codyaccess.Store { - return codyaccess.NewStore(db.db) + return codyaccess.NewStore(db.DB) } func databaseName(msp bool) string { @@ -53,11 +53,11 @@ func NewHandle(ctx context.Context, logger log.Logger, contract runtime.Contract if err != nil { return nil, errors.Wrap(err, "get connection pool") } - return &DB{db: pool}, nil + return &DB{DB: pool}, nil } // Close closes all connections in the pool and rejects future Acquire calls. // Blocks until all connections are returned to pool and closed. func (db *DB) Close() { - db.db.Close() + db.DB.Close() } diff --git a/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel b/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel index d2a6fe3656d..fa1cd5cc43a 100644 --- a/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel +++ b/cmd/enterprise-portal/internal/database/databasetest/BUILD.bazel @@ -7,6 +7,7 @@ go_library( tags = [TAG_INFRA_CORESERVICES], visibility = ["//cmd/enterprise-portal:__subpackages__"], deps = [ + "//cmd/enterprise-portal/internal/database/internal/tables", "//cmd/enterprise-portal/internal/database/internal/tables/custommigrator", "//internal/database/dbtest", "@com_github_jackc_pgx_v5//:pgx", diff --git a/cmd/enterprise-portal/internal/database/databasetest/databasetest.go b/cmd/enterprise-portal/internal/database/databasetest/databasetest.go index 455444e94b9..553877e5987 100644 --- a/cmd/enterprise-portal/internal/database/databasetest/databasetest.go +++ b/cmd/enterprise-portal/internal/database/databasetest/databasetest.go @@ -16,10 +16,13 @@ import ( "gorm.io/gorm" "gorm.io/gorm/schema" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator" "github.com/sourcegraph/sourcegraph/internal/database/dbtest" ) +func Tables(_ *testing.T) []schema.Tabler { return tables.All() } + // NewTestDB creates a new test database and initializes the given list of // tables for the suite. The test database is dropped after testing is completed // unless failed. @@ -37,6 +40,7 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx // Set up test suite database. dbName := fmt.Sprintf("sourcegraph-test-%s-%s-%d", system, suite, time.Now().Unix()) + t.Logf("Preparing database %s", dbName) _, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName)) require.NoError(t, err) diff --git a/cmd/enterprise-portal/internal/database/importer/importer.go b/cmd/enterprise-portal/internal/database/importer/importer.go index 02c0acc8e17..e3ce6757b34 100644 --- a/cmd/enterprise-portal/internal/database/importer/importer.go +++ b/cmd/enterprise-portal/internal/database/importer/importer.go @@ -26,7 +26,7 @@ import ( "github.com/sourcegraph/sourcegraph/lib/pointers" ) -type importer struct { +type Importer struct { logger log.Logger dotcom *dotcomdb.Reader @@ -36,14 +36,29 @@ type importer struct { codyGatewayAccess *codyaccess.CodyGatewayStore } -var _ goroutine.Handler = (*importer)(nil) +func NewHandler( + ctx context.Context, + logger log.Logger, + dotcom *dotcomdb.Reader, + enterprisePortal *database.DB, +) *Importer { + return &Importer{ + logger: logger, + dotcom: dotcom, + subscriptions: enterprisePortal.Subscriptions(), + licenses: enterprisePortal.Subscriptions().Licenses(), + codyGatewayAccess: enterprisePortal.CodyAccess().CodyGateway(), + } +} -// New returns a periodic goroutine that runs an importer that reconciles -// subscriptions, licenses, and Cody Gateway access from dotcom into the -// Enterprise Portal database. +var _ goroutine.Handler = (*Importer)(nil) + +// NewPeriodicImporter returns a periodic goroutine that runs an importer that +// reconciles subscriptions, licenses, and Cody Gateway access from dotcom into +// the Enterprise Portal database. // // If interval is 0, the importer is disabled. -func New( +func NewPeriodicImporter( ctx context.Context, logger log.Logger, dotcom *dotcomdb.Reader, @@ -56,13 +71,7 @@ func New( } return goroutine.NewPeriodicGoroutine( ctx, - &importer{ - logger: logger, - dotcom: dotcom, - subscriptions: enterprisePortal.Subscriptions(), - licenses: enterprisePortal.Subscriptions().Licenses(), - codyGatewayAccess: enterprisePortal.CodyAccess().CodyGateway(), - }, + NewHandler(ctx, logger, dotcom, enterprisePortal), goroutine.WithOperation( observation.NewContext(logger, observation.Tracer(trace.GetTracer())). Operation(observation.Op{ @@ -72,7 +81,7 @@ func New( goroutine.WithInterval(interval)) } -func (i *importer) Handle(ctx context.Context) error { +func (i *Importer) Handle(ctx context.Context) error { l := trace.Logger(ctx, i.logger) dotcomSubscriptions, err := i.dotcom.ListEnterpriseSubscriptions(ctx, dotcomdb.ListEnterpriseSubscriptionsOptions{}) @@ -101,7 +110,7 @@ func (i *importer) Handle(ctx context.Context) error { return nil } -func (i *importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.SubscriptionAttributes) (err error) { +func (i *Importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.SubscriptionAttributes) (err error) { tr, ctx := trace.New(ctx, "importSubscription", attribute.String("dotcomSub.ID", dotcomSub.ID)) defer tr.EndWithErr(&err) @@ -227,7 +236,7 @@ func (i *importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.S return nil } -func (i *importer) importLicense(ctx context.Context, subscriptionID string, dotcomLicense *dotcomdb.LicenseAttributes) (err error) { +func (i *Importer) importLicense(ctx context.Context, subscriptionID string, dotcomLicense *dotcomdb.LicenseAttributes) (err error) { tr, ctx := trace.New(ctx, "importSubscription", attribute.String("dotcomSub.ID", subscriptionID), attribute.String("dotcomLicense.ID", dotcomLicense.ID)) diff --git a/cmd/enterprise-portal/internal/dotcomdb/BUILD.bazel b/cmd/enterprise-portal/internal/dotcomdb/BUILD.bazel index af7203c02cd..64af771532e 100644 --- a/cmd/enterprise-portal/internal/dotcomdb/BUILD.bazel +++ b/cmd/enterprise-portal/internal/dotcomdb/BUILD.bazel @@ -28,6 +28,10 @@ go_test( ], deps = [ ":dotcomdb", + "//cmd/enterprise-portal/internal/codyaccessservice", + "//cmd/enterprise-portal/internal/database", + "//cmd/enterprise-portal/internal/database/databasetest", + "//cmd/enterprise-portal/internal/database/importer", "//cmd/frontend/dotcomproductsubscriptiontest", "//cmd/frontend/graphqlbackend", "//internal/database", @@ -35,11 +39,15 @@ go_test( "//internal/database/dbtest", "//internal/license", "//internal/licensing", + "//lib/enterpriseportal/codyaccess/v1:codyaccess", "//lib/enterpriseportal/subscriptions/v1:subscriptions", "//lib/pointers", + "@com_connectrpc_connect//:connect", "@com_github_jackc_pgx_v4//stdlib", "@com_github_jackc_pgx_v5//pgxpool", "@com_github_sourcegraph_log//logtest", + "@com_github_sourcegraph_sourcegraph_accounts_sdk_go//:sourcegraph-accounts-sdk-go", + "@com_github_sourcegraph_sourcegraph_accounts_sdk_go//scopes", "@com_github_stretchr_testify//assert", "@com_github_stretchr_testify//require", ], diff --git a/cmd/enterprise-portal/internal/dotcomdb/dotcomdb.go b/cmd/enterprise-portal/internal/dotcomdb/dotcomdb.go index 304e5e270b2..74add521f5b 100644 --- a/cmd/enterprise-portal/internal/dotcomdb/dotcomdb.go +++ b/cmd/enterprise-portal/internal/dotcomdb/dotcomdb.go @@ -125,10 +125,10 @@ func (c CodyGatewayAccessAttributes) EvaluateRateLimits() CodyGatewayRateLimits Chat: licensing.NewCodyGatewayChatRateLimit(p, c.ActiveLicenseUserCount), CodeSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN, - Code: licensing.NewCodyGatewayCodeRateLimit(p, c.ActiveLicenseUserCount, c.ActiveLicenseTags), + Code: licensing.NewCodyGatewayCodeRateLimit(p, c.ActiveLicenseUserCount), EmbeddingsSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN, - Embeddings: licensing.NewCodyGatewayEmbeddingsRateLimit(p, c.ActiveLicenseUserCount, c.ActiveLicenseTags), + Embeddings: licensing.NewCodyGatewayEmbeddingsRateLimit(p, c.ActiveLicenseUserCount), } // Chat diff --git a/cmd/enterprise-portal/internal/dotcomdb/dotcomdb_test.go b/cmd/enterprise-portal/internal/dotcomdb/dotcomdb_test.go index 81eabf8ae09..9b88dfbcc40 100644 --- a/cmd/enterprise-portal/internal/dotcomdb/dotcomdb_test.go +++ b/cmd/enterprise-portal/internal/dotcomdb/dotcomdb_test.go @@ -6,6 +6,7 @@ import ( "testing" "time" + "connectrpc.com/connect" pgxstdlibv4 "github.com/jackc/pgx/v4/stdlib" "github.com/jackc/pgx/v5/pgxpool" "github.com/stretchr/testify/assert" @@ -13,19 +14,28 @@ import ( "github.com/sourcegraph/log/logtest" + "github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes" + + sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go" + + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/codyaccessservice" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest" + "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/importer" "github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb" "github.com/sourcegraph/sourcegraph/cmd/frontend/dotcomproductsubscriptiontest" "github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend" - "github.com/sourcegraph/sourcegraph/internal/database" + sgdatabase "github.com/sourcegraph/sourcegraph/internal/database" "github.com/sourcegraph/sourcegraph/internal/database/dbconn" "github.com/sourcegraph/sourcegraph/internal/database/dbtest" "github.com/sourcegraph/sourcegraph/internal/license" "github.com/sourcegraph/sourcegraph/internal/licensing" - v1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1" + codyaccessv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/codyaccess/v1" + subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1" "github.com/sourcegraph/sourcegraph/lib/pointers" ) -func newTestDotcomReader(t *testing.T, opts dotcomdb.ReaderOptions) (database.DB, *dotcomdb.Reader) { +func newTestDotcomReader(t *testing.T, opts dotcomdb.ReaderOptions) (sgdatabase.DB, *dotcomdb.Reader) { ctx := context.Background() // Set up a Sourcegraph test database. @@ -65,7 +75,23 @@ func newTestDotcomReader(t *testing.T, opts dotcomdb.ReaderOptions) (database.DB r := dotcomdb.NewReader(conn, opts) require.NoError(t, r.Ping(ctx)) - return database.NewDB(logtest.Scoped(t), sgtestdb), r + return sgdatabase.NewDB(logtest.Scoped(t), sgtestdb), r +} + +type mockCodyAccessV1Store struct{ codyaccessservice.StoreV1 } + +func (mockCodyAccessV1Store) IntrospectSAMSToken(context.Context, string) (*sams.IntrospectTokenResponse, error) { + return &sams.IntrospectTokenResponse{ + Active: true, + ExpiresAt: time.Now().Add(24 * time.Hour), + Scopes: scopes.Scopes{scopes.ToScope(scopes.ServiceEnterprisePortal, scopes.PermissionEnterprisePortalCodyAccess, scopes.ActionRead)}, + }, nil +} + +func mockAuthenticatedServiceRequest[T any](message *T) *connect.Request[T] { + req := connect.NewRequest(message) + req.Header().Set("authorization", "bearer foobar") + return req } type mockedData struct { @@ -76,7 +102,7 @@ type mockedData struct { archivedSubscriptions int } -func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info license.Info, cgAccess *graphqlbackend.UpdateCodyGatewayAccessInput) mockedData { +func setupDBAndInsertMockLicense(t *testing.T, dotcomdb sgdatabase.DB, info license.Info, cgAccess *graphqlbackend.UpdateCodyGatewayAccessInput) mockedData { start := time.Now() ctx := context.Background() @@ -87,7 +113,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens { // Create a different subscription and license that's rubbish, // created at the same time, to ensure we don't use it - u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "barbaz"}) + u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "barbaz"}) require.NoError(t, err) sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username) require.NoError(t, err) @@ -104,7 +130,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens { // Create a different subscription and license that's archived, // created at the same time, to ensure we don't use it - u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "archived"}) + u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "archived"}) require.NoError(t, err) sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username) require.NoError(t, err) @@ -124,7 +150,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens { // Create a different subscription and license that's not a dev tag, // created at the same time, to ensure we don't use it - u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "not-dev"}) + u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "not-dev"}) require.NoError(t, err) sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username) require.NoError(t, err) @@ -138,7 +164,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens } // Create the subscription we will assert against - u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "user"}) + u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "user"}) require.NoError(t, err) subid, err := subscriptionsdb.Create(ctx, u.ID, u.Username) require.NoError(t, err) @@ -150,7 +176,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens result.accessTokens = append(result.accessTokens, license.GenerateLicenseKeyBasedAccessToken(key1)) _, err = licensesdb.Create(ctx, subid, key1, 2, license.Info{ CreatedAt: info.CreatedAt.Add(-time.Hour), - ExpiresAt: info.ExpiresAt.Add(-time.Hour), + ExpiresAt: info.ExpiresAt.Add(-time.Minute), // should expire first, but not be expired Tags: []string{licensing.DevTag}, }) require.NoError(t, err) @@ -170,7 +196,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens { // Create another different subscription and license that's also rubbish, // created at the same time, to ensure we don't use it - u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "foobar"}) + u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "foobar"}) require.NoError(t, err) sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username) require.NoError(t, err) @@ -198,7 +224,7 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) { t.Parallel() ctx := context.Background() - for _, tc := range []struct { + for i, tc := range []struct { name string info license.Info cgAccess graphqlbackend.UpdateCodyGatewayAccessInput @@ -244,16 +270,37 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) { dotcomdb, dotcomreader := newTestDotcomReader(t, dotcomdb.ReaderOptions{ DevOnly: true, }) + // First, set up a subscription and license and some other rubbish // data to ensure we only get the license we want. mock := setupDBAndInsertMockLicense(t, dotcomdb, tc.info, &tc.cgAccess) + // Now import the data for parity so we can compare against the + // Enterprise Portal implementation + epDB := &database.DB{ + DB: databasetest.NewTestDB(t, "ep-dotcomdb", fmt.Sprintf("get-attributes-%d", i), databasetest.Tables(t)...), + } + err := importer.NewHandler(ctx, logtest.Scoped(t), dotcomreader, epDB).Handle(ctx) + require.NoError(t, err) + codyAccessService := codyaccessservice.NewHandlerV1(logtest.Scoped(t), mockCodyAccessV1Store{ + StoreV1: codyaccessservice.NewStoreV1(codyaccessservice.StoreV1Options{ + DB: epDB, + }), + }) + t.Run("by subscription ID", func(t *testing.T) { t.Parallel() attr, err := dotcomreader.GetCodyGatewayAccessAttributesBySubscription(ctx, mock.targetSubscriptionID) require.NoError(t, err) - validateAccessAttributes(t, dotcomdb, mock, attr, tc.info) + access, err := codyAccessService.GetCodyGatewayAccess(ctx, mockAuthenticatedServiceRequest(&codyaccessv1.GetCodyGatewayAccessRequest{ + Query: &codyaccessv1.GetCodyGatewayAccessRequest_SubscriptionId{ + SubscriptionId: mock.targetSubscriptionID, + }, + })) + require.NoError(t, err) + + validateAccessAttributes(t, dotcomdb, mock, attr, access.Msg.GetAccess(), tc.info) }) t.Run("by access token", func(t *testing.T) { @@ -266,7 +313,13 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) { attr, err := dotcomreader.GetCodyGatewayAccessAttributesByAccessToken(ctx, token) require.NoError(t, err) - validateAccessAttributes(t, dotcomdb, mock, attr, tc.info) + access, err := codyAccessService.GetCodyGatewayAccess(ctx, mockAuthenticatedServiceRequest(&codyaccessv1.GetCodyGatewayAccessRequest{ + Query: &codyaccessv1.GetCodyGatewayAccessRequest_AccessToken{ + AccessToken: token, + }, + })) + require.NoError(t, err) + validateAccessAttributes(t, dotcomdb, mock, attr, access.Msg.GetAccess(), tc.info) t.Run("compare with dotcom tokens DB", func(t *testing.T) { subID, err := dotcomproductsubscriptiontest.NewTokensDB(t, dotcomdb). @@ -281,11 +334,20 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) { } } -func validateAccessAttributes(t *testing.T, dotcomdb database.DB, mock mockedData, attr *dotcomdb.CodyGatewayAccessAttributes, info license.Info) { +func validateAccessAttributes(t *testing.T, dotcomdb sgdatabase.DB, mock mockedData, attr *dotcomdb.CodyGatewayAccessAttributes, access *codyaccessv1.CodyGatewayAccess, info license.Info) { assert.Equal(t, mock.targetSubscriptionID, attr.SubscriptionID) + assert.Equal(t, subscriptionsv1.EnterpriseSubscriptionIDPrefix+mock.targetSubscriptionID, access.SubscriptionId) + assert.Equal(t, int(info.UserCount), *attr.ActiveLicenseUserCount) assert.Len(t, attr.LicenseKeyHashes, 2) - assert.Equal(t, attr.GenerateAccessTokens(), mock.accessTokens) + + assert.Equal(t, mock.accessTokens, attr.GenerateAccessTokens()) + var protoAccessTokens []string + for _, t := range access.AccessTokens { + protoAccessTokens = append(protoAccessTokens, t.GetToken()) + } + assert.ElementsMatch(t, mock.accessTokens, protoAccessTokens) + limits := attr.EvaluateRateLimits() // Validate against the expected values as produced by existing resolvers @@ -300,30 +362,37 @@ func validateAccessAttributes(t *testing.T, dotcomdb database.DB, mock mockedDat name string expected graphqlbackend.CodyGatewayRateLimit got licensing.CodyGatewayRateLimit + gotProto *codyaccessv1.CodyGatewayRateLimit }{{ name: "Chat", expected: mustWithCtx(t, expected.ChatCompletionsRateLimit), got: limits.Chat, + gotProto: access.ChatCompletionsRateLimit, }, { name: "Code", expected: mustWithCtx(t, expected.CodeCompletionsRateLimit), got: limits.Code, + gotProto: access.CodeCompletionsRateLimit, }, { name: "Embeddings", expected: mustWithCtx(t, expected.EmbeddingsRateLimit), got: limits.Embeddings, + gotProto: access.EmbeddingsRateLimit, }} { t.Run(compare.name, func(t *testing.T) { // We only care about limit and interval now assert.Equal(t, int64(compare.expected.Limit()), compare.got.Limit, "Limit") assert.Equal(t, compare.expected.IntervalSeconds(), compare.got.IntervalSeconds, "IntervalSeconds") + + assert.Equal(t, uint64(compare.expected.Limit()), compare.gotProto.Limit, "Limit") + assert.Equal(t, int64(compare.expected.IntervalSeconds()), compare.gotProto.IntervalDuration.Seconds, "IntervalSeconds") }) } } func TestGetAllCodyGatewayAccessAttributes(t *testing.T) { t.Parallel() - dotcomdb, dotcomreader := newTestDotcomReader(t, dotcomdb.ReaderOptions{ + dotcomDB, dotcomreader := newTestDotcomReader(t, dotcomdb.ReaderOptions{ DevOnly: true, }) @@ -334,21 +403,53 @@ func TestGetAllCodyGatewayAccessAttributes(t *testing.T) { Tags: []string{licensing.PlanEnterprise1.Tag(), licensing.DevTag}, } cgAccess := graphqlbackend.UpdateCodyGatewayAccessInput{Enabled: pointers.Ptr(true)} - mock := setupDBAndInsertMockLicense(t, dotcomdb, info, &cgAccess) + mock := setupDBAndInsertMockLicense(t, dotcomDB, info, &cgAccess) - attrs, err := dotcomreader.GetAllCodyGatewayAccessAttributes(context.Background()) + // Now import the data for parity so we can compare against the + // Enterprise Portal implementation + ctx := context.Background() + epDB := &database.DB{ + DB: databasetest.NewTestDB(t, "ep-dotcomdb", "get-all-attributes", databasetest.Tables(t)...), + } + err := importer.NewHandler(ctx, logtest.Scoped(t), dotcomreader, epDB).Handle(ctx) + require.NoError(t, err) + codyAccessService := codyaccessservice.NewHandlerV1(logtest.Scoped(t), mockCodyAccessV1Store{ + StoreV1: codyaccessservice.NewStoreV1(codyaccessservice.StoreV1Options{ + DB: epDB, + }), + }) + + attrs, err := dotcomreader.GetAllCodyGatewayAccessAttributes(ctx) require.NoError(t, err) assert.Len(t, attrs, 3) // 3 subscriptions created in setupDBAndInsertMockLicense - var found bool + + accesses, err := codyAccessService.ListCodyGatewayAccesses(ctx, mockAuthenticatedServiceRequest(&codyaccessv1.ListCodyGatewayAccessesRequest{})) + require.NoError(t, err) + assert.Len(t, accesses.Msg.GetAccesses(), 3) // 3 subscriptions created in setupDBAndInsertMockLicense + + var ( + foundAttr *dotcomdb.CodyGatewayAccessAttributes + foundAccess *codyaccessv1.CodyGatewayAccess + ) for _, attr := range attrs { if attr.SubscriptionID == mock.targetSubscriptionID { - found = true - validateAccessAttributes(t, dotcomdb, mock, attr, info) + foundAttr = attr } else { assert.False(t, attr.CodyGatewayEnabled) } } - assert.True(t, found) + for _, access := range accesses.Msg.GetAccesses() { + if access.SubscriptionId == subscriptionsv1.EnterpriseSubscriptionIDPrefix+mock.targetSubscriptionID { + foundAccess = access + } else { + assert.False(t, access.Enabled) + } + } + require.NotNil(t, foundAttr) + require.NotNil(t, foundAccess) + + validateAccessAttributes(t, dotcomDB, mock, foundAttr, foundAccess, info) + } func TestListEnterpriseSubscriptionLicenses(t *testing.T) { @@ -378,7 +479,7 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) { ctx := context.Background() for _, tc := range []struct { name string - filters []*v1.ListEnterpriseSubscriptionLicensesFilter + filters []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter pageSize int expect func(t *testing.T, licenses []*dotcomdb.LicenseAttributes) }{{ @@ -390,8 +491,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) { }, }, { name: "filter by subscription ID", - filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{ - Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{ + filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{ + Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{ SubscriptionId: mock.targetSubscriptionID, }, }}, @@ -403,8 +504,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) { }, }, { name: "filter by subscription ID and limit 1", - filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{ - Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{ + filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{ + Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{ SubscriptionId: mock.targetSubscriptionID, }, }}, @@ -415,12 +516,12 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) { }, }, { name: "filter by subscription ID and not archived", - filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{ - Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{ + filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{ + Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{ SubscriptionId: mock.targetSubscriptionID, }, }, { - Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{ + Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{ IsRevoked: false, }, }}, @@ -434,8 +535,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) { }, }, { name: "filter by is archived", - filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{ - Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{ + filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{ + Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{ IsRevoked: true, }, }}, @@ -444,8 +545,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) { }, }, { name: "filter by not archived", - filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{ - Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{ + filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{ + Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{ IsRevoked: false, }, }}, diff --git a/cmd/frontend/internal/dotcom/productsubscription/codygateway_graphql.go b/cmd/frontend/internal/dotcom/productsubscription/codygateway_graphql.go index cdd3ba73d77..62a22af783a 100644 --- a/cmd/frontend/internal/dotcom/productsubscription/codygateway_graphql.go +++ b/cmd/frontend/internal/dotcom/productsubscription/codygateway_graphql.go @@ -91,7 +91,7 @@ func (r codyGatewayAccessResolver) CodeCompletionsRateLimit(ctx context.Context) var source graphqlbackend.CodyGatewayRateLimitSource if activeLicense != nil { source = graphqlbackend.CodyGatewayRateLimitSourcePlan - rateLimit = licensing.NewCodyGatewayCodeRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount, activeLicense.LicenseTags) + rateLimit = licensing.NewCodyGatewayCodeRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount) } // Apply overrides @@ -131,7 +131,7 @@ func (r codyGatewayAccessResolver) EmbeddingsRateLimit(ctx context.Context) (gra var source graphqlbackend.CodyGatewayRateLimitSource if activeLicense != nil { source = graphqlbackend.CodyGatewayRateLimitSourcePlan - rateLimit = licensing.NewCodyGatewayEmbeddingsRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount, activeLicense.LicenseTags) + rateLimit = licensing.NewCodyGatewayEmbeddingsRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount) } // Apply overrides diff --git a/internal/licensing/codygateway.go b/internal/licensing/codygateway.go index 2e665007f04..99264921000 100644 --- a/internal/licensing/codygateway.go +++ b/internal/licensing/codygateway.go @@ -53,7 +53,7 @@ func NewCodyGatewayChatRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit } // NewCodyGatewayCodeRateLimit applies default Cody Gateway access based on the plan. -func NewCodyGatewayCodeRateLimit(plan Plan, userCount *int, licenseTags []string) CodyGatewayRateLimit { +func NewCodyGatewayCodeRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit { uc := 0 if userCount != nil { uc = *userCount @@ -87,7 +87,7 @@ func NewCodyGatewayCodeRateLimit(plan Plan, userCount *int, licenseTags []string const tokensPerDollar = int(1 / (0.0001 / 1_000)) // NewCodyGatewayEmbeddingsRateLimit applies default Cody Gateway access based on the plan. -func NewCodyGatewayEmbeddingsRateLimit(plan Plan, userCount *int, licenseTags []string) CodyGatewayRateLimit { +func NewCodyGatewayEmbeddingsRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit { uc := 0 if userCount != nil { uc = *userCount diff --git a/internal/licensing/codygateway_test.go b/internal/licensing/codygateway_test.go index ac241bcfa41..a61ea875ea9 100644 --- a/internal/licensing/codygateway_test.go +++ b/internal/licensing/codygateway_test.go @@ -104,7 +104,7 @@ func TestCodyGatewayCodeRateLimit(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := NewCodyGatewayCodeRateLimit(tt.plan, tt.userCount, tt.licenseTags) + got := NewCodyGatewayCodeRateLimit(tt.plan, tt.userCount) if diff := cmp.Diff(got, tt.want); diff != "" { t.Fatalf("incorrect rate limit computed: %s", diff) } @@ -151,7 +151,7 @@ func TestCodyGatewayEmbeddingsRateLimit(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := NewCodyGatewayEmbeddingsRateLimit(tt.plan, tt.userCount, tt.licenseTags) + got := NewCodyGatewayEmbeddingsRateLimit(tt.plan, tt.userCount) if diff := cmp.Diff(got, tt.want); diff != "" { t.Fatalf("incorrect rate limit computed: %s", diff) }