mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 15:31:48 +00:00
feat/enterpriseportal: use database for reading Cody Gateway access
This commit is contained in:
parent
5916b1f248
commit
4064629c48
@ -13,16 +13,22 @@ go_library(
|
||||
visibility = ["//cmd/enterprise-portal:__subpackages__"],
|
||||
deps = [
|
||||
"//cmd/enterprise-portal/internal/connectutil",
|
||||
"//cmd/enterprise-portal/internal/database",
|
||||
"//cmd/enterprise-portal/internal/database/codyaccess",
|
||||
"//cmd/enterprise-portal/internal/dotcomdb",
|
||||
"//cmd/enterprise-portal/internal/samsm2m",
|
||||
"//internal/codygateway/codygatewayactor",
|
||||
"//internal/codygateway/codygatewayevents",
|
||||
"//internal/completions/types",
|
||||
"//internal/license",
|
||||
"//internal/licensing",
|
||||
"//internal/productsubscription",
|
||||
"//internal/trace",
|
||||
"//lib/enterpriseportal/codyaccess/v1:codyaccess",
|
||||
"//lib/enterpriseportal/codyaccess/v1/v1connect",
|
||||
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
|
||||
"//lib/errors",
|
||||
"//lib/pointers",
|
||||
"@com_connectrpc_connect//:connect",
|
||||
"@com_github_sourcegraph_conc//pool",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
@ -38,7 +44,7 @@ go_test(
|
||||
srcs = ["adapters_test.go"],
|
||||
embed = [":codyaccessservice"],
|
||||
deps = [
|
||||
"//cmd/enterprise-portal/internal/dotcomdb",
|
||||
"//cmd/enterprise-portal/internal/database/codyaccess",
|
||||
"@com_github_hexops_autogold_v2//:autogold",
|
||||
"@com_github_stretchr_testify//assert",
|
||||
],
|
||||
|
||||
@ -1,46 +1,51 @@
|
||||
package codyaccessservice
|
||||
|
||||
import (
|
||||
"encoding/hex"
|
||||
|
||||
"google.golang.org/protobuf/types/known/durationpb"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway/codygatewayevents"
|
||||
"github.com/sourcegraph/sourcegraph/internal/license"
|
||||
"github.com/sourcegraph/sourcegraph/internal/licensing"
|
||||
codyaccessv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/codyaccess/v1"
|
||||
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
func convertAccessAttrsToProto(attrs *dotcomdb.CodyGatewayAccessAttributes) *codyaccessv1.CodyGatewayAccess {
|
||||
func convertAccessAttrsToProto(access *codyaccess.CodyGatewayAccessWithSubscriptionDetails) *codyaccessv1.CodyGatewayAccess {
|
||||
// Provide ID in prefixed format.
|
||||
subscriptionID := subscriptionsv1.EnterpriseSubscriptionIDPrefix + attrs.SubscriptionID
|
||||
subscriptionID := subscriptionsv1.EnterpriseSubscriptionIDPrefix + access.SubscriptionID
|
||||
|
||||
// Always try to return the full response, since even when disabled, some
|
||||
// features may be allowed via Cody Gateway (notably attributions). This
|
||||
// also allows Cody Gateway to cache the state of actors that are disabled.
|
||||
limits := attrs.EvaluateRateLimits()
|
||||
limits := evaluateCodyGatewayAccessRateLimits(access)
|
||||
return &codyaccessv1.CodyGatewayAccess{
|
||||
SubscriptionId: subscriptionID,
|
||||
SubscriptionDisplayName: attrs.GetSubscriptionDisplayName(),
|
||||
Enabled: attrs.CodyGatewayEnabled,
|
||||
SubscriptionDisplayName: access.DisplayName,
|
||||
Enabled: access.Enabled,
|
||||
// Rate limits return nil if not enabled, per API spec
|
||||
ChatCompletionsRateLimit: nilIfNotEnabled(attrs.CodyGatewayEnabled, &codyaccessv1.CodyGatewayRateLimit{
|
||||
ChatCompletionsRateLimit: nilIfNotEnabled(access.Enabled, &codyaccessv1.CodyGatewayRateLimit{
|
||||
Source: limits.ChatSource,
|
||||
Limit: uint64(limits.Chat.Limit),
|
||||
IntervalDuration: durationpb.New(limits.Chat.IntervalDuration()),
|
||||
}),
|
||||
CodeCompletionsRateLimit: nilIfNotEnabled(attrs.CodyGatewayEnabled, &codyaccessv1.CodyGatewayRateLimit{
|
||||
CodeCompletionsRateLimit: nilIfNotEnabled(access.Enabled, &codyaccessv1.CodyGatewayRateLimit{
|
||||
Source: limits.CodeSource,
|
||||
Limit: uint64(limits.Code.Limit),
|
||||
IntervalDuration: durationpb.New(limits.Code.IntervalDuration()),
|
||||
}),
|
||||
EmbeddingsRateLimit: nilIfNotEnabled(attrs.CodyGatewayEnabled, &codyaccessv1.CodyGatewayRateLimit{
|
||||
EmbeddingsRateLimit: nilIfNotEnabled(access.Enabled, &codyaccessv1.CodyGatewayRateLimit{
|
||||
Source: limits.EmbeddingsSource,
|
||||
Limit: uint64(limits.Embeddings.Limit),
|
||||
IntervalDuration: durationpb.New(limits.Embeddings.IntervalDuration()),
|
||||
}),
|
||||
// This is always provided, even if access is disabled
|
||||
AccessTokens: func() []*codyaccessv1.CodyGatewayAccessToken {
|
||||
accessTokens := attrs.GenerateAccessTokens()
|
||||
accessTokens := generateCodyGatewayAccessTokens(access)
|
||||
if len(accessTokens) == 0 {
|
||||
return []*codyaccessv1.CodyGatewayAccessToken{}
|
||||
}
|
||||
@ -56,6 +61,18 @@ func convertAccessAttrsToProto(attrs *dotcomdb.CodyGatewayAccessAttributes) *cod
|
||||
}
|
||||
}
|
||||
|
||||
func generateCodyGatewayAccessTokens(access *codyaccess.CodyGatewayAccessWithSubscriptionDetails) []string {
|
||||
accessTokens := make([]string, 0, len(access.LicenseKeyHashes))
|
||||
for _, t := range access.LicenseKeyHashes {
|
||||
if len(t) == 0 { // query can return empty hashes, ignore these
|
||||
continue
|
||||
}
|
||||
// See license.GenerateLicenseKeyBasedAccessToken
|
||||
accessTokens = append(accessTokens, license.LicenseKeyBasedAccessTokenPrefix+hex.EncodeToString(t))
|
||||
}
|
||||
return accessTokens
|
||||
}
|
||||
|
||||
func nilIfNotEnabled[T any](enabled bool, value *T) *T {
|
||||
if !enabled {
|
||||
return nil
|
||||
@ -63,6 +80,68 @@ func nilIfNotEnabled[T any](enabled bool, value *T) *T {
|
||||
return value
|
||||
}
|
||||
|
||||
type CodyGatewayRateLimits struct {
|
||||
ChatSource codyaccessv1.CodyGatewayRateLimitSource
|
||||
Chat licensing.CodyGatewayRateLimit
|
||||
|
||||
CodeSource codyaccessv1.CodyGatewayRateLimitSource
|
||||
Code licensing.CodyGatewayRateLimit
|
||||
|
||||
EmbeddingsSource codyaccessv1.CodyGatewayRateLimitSource
|
||||
Embeddings licensing.CodyGatewayRateLimit
|
||||
}
|
||||
|
||||
func maybeApplyOverride[T ~int32 | ~int64](limit *T, overrideValue T, overrideValid bool) codyaccessv1.CodyGatewayRateLimitSource {
|
||||
if overrideValid {
|
||||
*limit = overrideValue
|
||||
return codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_OVERRIDE
|
||||
}
|
||||
// No override
|
||||
return codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN
|
||||
}
|
||||
|
||||
// evaluateCodyGatewayAccessRateLimits returns the current CodyGatewayRateLimits based on the
|
||||
// plan and applying known overrides on top. This closely models the existing
|
||||
// codyGatewayAccessResolver in 'cmd/frontend/internal/dotcom/productsubscription'.
|
||||
func evaluateCodyGatewayAccessRateLimits(access *codyaccess.CodyGatewayAccessWithSubscriptionDetails) CodyGatewayRateLimits {
|
||||
// Set defaults for everything based on active licnese plan and user count.
|
||||
// If there isn't one, zero values apply.
|
||||
activeLicense := pointers.DerefZero(access.ActiveLicenseInfo)
|
||||
p := licensing.PlanFromTags(activeLicense.Tags)
|
||||
userCount := pointers.Ptr(int(activeLicense.UserCount))
|
||||
|
||||
limits := CodyGatewayRateLimits{
|
||||
ChatSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN,
|
||||
Chat: licensing.NewCodyGatewayChatRateLimit(p, userCount),
|
||||
|
||||
CodeSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN,
|
||||
Code: licensing.NewCodyGatewayCodeRateLimit(p, userCount),
|
||||
|
||||
EmbeddingsSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN,
|
||||
Embeddings: licensing.NewCodyGatewayEmbeddingsRateLimit(p, userCount),
|
||||
}
|
||||
|
||||
// Chat
|
||||
limits.ChatSource = maybeApplyOverride(&limits.Chat.Limit,
|
||||
access.ChatCompletionsRateLimit.Int64, access.ChatCompletionsRateLimit.Valid)
|
||||
limits.ChatSource = maybeApplyOverride(&limits.Chat.IntervalSeconds,
|
||||
access.ChatCompletionsRateLimitIntervalSeconds.Int32, access.ChatCompletionsRateLimitIntervalSeconds.Valid)
|
||||
|
||||
// Code
|
||||
limits.CodeSource = maybeApplyOverride(&limits.Code.Limit,
|
||||
access.CodeCompletionsRateLimit.Int64, access.CodeCompletionsRateLimit.Valid)
|
||||
limits.CodeSource = maybeApplyOverride(&limits.Code.IntervalSeconds,
|
||||
access.CodeCompletionsRateLimitIntervalSeconds.Int32, access.CodeCompletionsRateLimitIntervalSeconds.Valid)
|
||||
|
||||
// Embeddings
|
||||
limits.EmbeddingsSource = maybeApplyOverride(&limits.Embeddings.Limit,
|
||||
access.EmbeddingsRateLimit.Int64, access.EmbeddingsRateLimit.Valid)
|
||||
limits.EmbeddingsSource = maybeApplyOverride(&limits.Embeddings.IntervalSeconds,
|
||||
access.EmbeddingsRateLimitIntervalSeconds.Int32, access.EmbeddingsRateLimitIntervalSeconds.Valid)
|
||||
|
||||
return limits
|
||||
}
|
||||
|
||||
func convertCodyGatewayUsageDatapoints(usage []codygatewayevents.SubscriptionUsage) []*codyaccessv1.CodyGatewayUsage_UsageDatapoint {
|
||||
results := make([]*codyaccessv1.CodyGatewayUsage_UsageDatapoint, len(usage))
|
||||
for i, datapoint := range usage {
|
||||
|
||||
@ -7,19 +7,21 @@ import (
|
||||
"github.com/hexops/autogold/v2"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess"
|
||||
)
|
||||
|
||||
func TestConvertAccessAttrsToProto(t *testing.T) {
|
||||
t.Run("zero value", func(t *testing.T) {
|
||||
proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{})
|
||||
proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{})
|
||||
assert.False(t, proto.Enabled)
|
||||
})
|
||||
|
||||
t.Run("disabled returns access tokens", func(t *testing.T) {
|
||||
proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{
|
||||
CodyGatewayEnabled: false,
|
||||
LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")},
|
||||
proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{
|
||||
CodyGatewayAccess: codyaccess.CodyGatewayAccess{
|
||||
Enabled: false,
|
||||
},
|
||||
LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")},
|
||||
})
|
||||
assert.False(t, proto.Enabled)
|
||||
// NOTE: These are not real access tokens
|
||||
@ -31,18 +33,22 @@ func TestConvertAccessAttrsToProto(t *testing.T) {
|
||||
})
|
||||
|
||||
t.Run("enabled with empty access token", func(t *testing.T) {
|
||||
proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{
|
||||
CodyGatewayEnabled: true,
|
||||
LicenseKeyHashes: [][]byte{[]byte(""), nil},
|
||||
proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{
|
||||
CodyGatewayAccess: codyaccess.CodyGatewayAccess{
|
||||
Enabled: true,
|
||||
},
|
||||
LicenseKeyHashes: [][]byte{[]byte(""), nil},
|
||||
})
|
||||
assert.True(t, proto.Enabled)
|
||||
assert.Empty(t, proto.GetAccessTokens())
|
||||
})
|
||||
|
||||
t.Run("enabled returns everything", func(t *testing.T) {
|
||||
proto := convertAccessAttrsToProto(&dotcomdb.CodyGatewayAccessAttributes{
|
||||
CodyGatewayEnabled: true,
|
||||
LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")},
|
||||
proto := convertAccessAttrsToProto(&codyaccess.CodyGatewayAccessWithSubscriptionDetails{
|
||||
CodyGatewayAccess: codyaccess.CodyGatewayAccess{
|
||||
Enabled: true,
|
||||
},
|
||||
LicenseKeyHashes: [][]byte{[]byte("abc"), []byte("efg")},
|
||||
})
|
||||
assert.True(t, proto.Enabled)
|
||||
// NOTE: These are not real access tokens
|
||||
|
||||
@ -12,6 +12,7 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/connectutil"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/samsm2m"
|
||||
"github.com/sourcegraph/sourcegraph/internal/trace"
|
||||
@ -23,42 +24,40 @@ import (
|
||||
|
||||
const Name = codyaccessv1connect.CodyAccessServiceName
|
||||
|
||||
type DotComDB interface {
|
||||
GetCodyGatewayAccessAttributesBySubscription(ctx context.Context, subscriptionID string) (*dotcomdb.CodyGatewayAccessAttributes, error)
|
||||
GetCodyGatewayAccessAttributesByAccessToken(ctx context.Context, subscriptionID string) (*dotcomdb.CodyGatewayAccessAttributes, error)
|
||||
GetAllCodyGatewayAccessAttributes(ctx context.Context) ([]*dotcomdb.CodyGatewayAccessAttributes, error)
|
||||
}
|
||||
|
||||
func RegisterV1(
|
||||
logger log.Logger,
|
||||
mux *http.ServeMux,
|
||||
store StoreV1,
|
||||
dotcom DotComDB,
|
||||
opts ...connect.HandlerOption,
|
||||
) {
|
||||
mux.Handle(
|
||||
codyaccessv1connect.NewCodyAccessServiceHandler(
|
||||
&handlerV1{
|
||||
logger: logger.Scoped("codyaccess.v1"),
|
||||
store: store,
|
||||
dotcom: dotcom,
|
||||
},
|
||||
NewHandlerV1(logger, store),
|
||||
opts...,
|
||||
),
|
||||
)
|
||||
}
|
||||
|
||||
type handlerV1 struct {
|
||||
type HandlerV1 struct {
|
||||
codyaccessv1connect.UnimplementedCodyAccessServiceHandler
|
||||
|
||||
logger log.Logger
|
||||
store StoreV1
|
||||
dotcom DotComDB
|
||||
}
|
||||
|
||||
var _ codyaccessv1connect.CodyAccessServiceHandler = (*handlerV1)(nil)
|
||||
func NewHandlerV1(
|
||||
logger log.Logger,
|
||||
store StoreV1,
|
||||
) *HandlerV1 {
|
||||
return &HandlerV1{
|
||||
logger: logger.Scoped("codyaccess.v1"),
|
||||
store: store,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *handlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayAccessRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayAccessResponse], error) {
|
||||
var _ codyaccessv1connect.CodyAccessServiceHandler = (*HandlerV1)(nil)
|
||||
|
||||
func (s *HandlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayAccessRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayAccessResponse], error) {
|
||||
logger := trace.Logger(ctx, s.logger).
|
||||
With(log.String("queryType", fmt.Sprintf("%T", req.Msg.GetQuery())))
|
||||
|
||||
@ -70,19 +69,19 @@ func (s *handlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Reque
|
||||
}
|
||||
logger = logger.With(clientAttrs...)
|
||||
|
||||
var attr *dotcomdb.CodyGatewayAccessAttributes
|
||||
var attr *codyaccess.CodyGatewayAccessWithSubscriptionDetails
|
||||
switch query := req.Msg.GetQuery().(type) {
|
||||
case *codyaccessv1.GetCodyGatewayAccessRequest_SubscriptionId:
|
||||
if len(query.SubscriptionId) == 0 {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("invalid query: subscription ID"))
|
||||
}
|
||||
attr, err = s.dotcom.GetCodyGatewayAccessAttributesBySubscription(ctx, query.SubscriptionId)
|
||||
attr, err = s.store.GetCodyGatewayAccessBySubscription(ctx, query.SubscriptionId)
|
||||
|
||||
case *codyaccessv1.GetCodyGatewayAccessRequest_AccessToken:
|
||||
if len(query.AccessToken) == 0 {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("invalid query: access token"))
|
||||
}
|
||||
attr, err = s.dotcom.GetCodyGatewayAccessAttributesByAccessToken(ctx, query.AccessToken)
|
||||
attr, err = s.store.GetCodyGatewayAccessByAccessToken(ctx, query.AccessToken)
|
||||
|
||||
default:
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, errors.New("invalid query"))
|
||||
@ -103,7 +102,7 @@ func (s *handlerV1) GetCodyGatewayAccess(ctx context.Context, req *connect.Reque
|
||||
}), nil
|
||||
}
|
||||
|
||||
func (s *handlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Request[codyaccessv1.ListCodyGatewayAccessesRequest]) (*connect.Response[codyaccessv1.ListCodyGatewayAccessesResponse], error) {
|
||||
func (s *HandlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Request[codyaccessv1.ListCodyGatewayAccessesRequest]) (*connect.Response[codyaccessv1.ListCodyGatewayAccessesResponse], error) {
|
||||
logger := trace.Logger(ctx, s.logger)
|
||||
|
||||
// 🚨 SECURITY: Require approrpiate M2M scope.
|
||||
@ -122,7 +121,7 @@ func (s *handlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Re
|
||||
return nil, connect.NewError(connect.CodeUnimplemented, errors.New("pagination not implemented"))
|
||||
}
|
||||
|
||||
attrs, err := s.dotcom.GetAllCodyGatewayAccessAttributes(ctx)
|
||||
attrs, err := s.store.ListCodyGatewayAccesses(ctx)
|
||||
if err != nil {
|
||||
if err == dotcomdb.ErrCodyGatewayAccessNotFound {
|
||||
return nil, connect.NewError(connect.CodeNotFound, err)
|
||||
@ -146,7 +145,7 @@ func (s *handlerV1) ListCodyGatewayAccesses(ctx context.Context, req *connect.Re
|
||||
return connect.NewResponse(&resp), nil
|
||||
}
|
||||
|
||||
func (s *handlerV1) GetCodyGatewayUsage(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayUsageRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayUsageResponse], error) {
|
||||
func (s *HandlerV1) GetCodyGatewayUsage(ctx context.Context, req *connect.Request[codyaccessv1.GetCodyGatewayUsageRequest]) (*connect.Response[codyaccessv1.GetCodyGatewayUsageResponse], error) {
|
||||
logger := trace.Logger(ctx, s.logger)
|
||||
|
||||
// 🚨 SECURITY: Require appropriate M2M scope.
|
||||
|
||||
@ -2,13 +2,19 @@ package codyaccessservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/conc/pool"
|
||||
|
||||
sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/codyaccess"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway/codygatewayactor"
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway/codygatewayevents"
|
||||
"github.com/sourcegraph/sourcegraph/internal/completions/types"
|
||||
"github.com/sourcegraph/sourcegraph/internal/license"
|
||||
"github.com/sourcegraph/sourcegraph/internal/productsubscription"
|
||||
codyaccessv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/codyaccess/v1"
|
||||
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
@ -24,18 +30,33 @@ type StoreV1 interface {
|
||||
// is no longer active. It is critical that the caller not honor tokens where
|
||||
// `.Active == false`.
|
||||
IntrospectSAMSToken(ctx context.Context, token string) (*sams.IntrospectTokenResponse, error)
|
||||
|
||||
// GetCodyGatewayUsage retrieves recent Cody Gateway usage data.
|
||||
// The subscriptionID should not be prefixed.
|
||||
GetCodyGatewayUsage(ctx context.Context, subscriptionID string) (*codyaccessv1.CodyGatewayUsage, error)
|
||||
|
||||
// GetCodyGatewayAccessBySubscription retrieves Cody Gateway access by
|
||||
// subscription ID.
|
||||
GetCodyGatewayAccessBySubscription(ctx context.Context, subscriptionID string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error)
|
||||
|
||||
// GetCodyGatewayAccessByAccessToken retrieves Cody Gateway access details
|
||||
// associated with the given access token.
|
||||
GetCodyGatewayAccessByAccessToken(ctx context.Context, token string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error)
|
||||
|
||||
// ListCodyGatewayAccesses retrieves all Cody Gateway accesses with their
|
||||
// associated subscription details.
|
||||
ListCodyGatewayAccesses(ctx context.Context) ([]*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error)
|
||||
}
|
||||
|
||||
type storeV1 struct {
|
||||
SAMSClient *sams.ClientV1
|
||||
CodyAccess *codyaccess.Store
|
||||
CodyGatewayEvents *codygatewayevents.Service
|
||||
}
|
||||
|
||||
type StoreV1Options struct {
|
||||
SAMSClient *sams.ClientV1
|
||||
DB *database.DB
|
||||
// Optional.
|
||||
CodyGatewayEvents *codygatewayevents.Service
|
||||
}
|
||||
@ -46,6 +67,7 @@ var errStoreUnimplemented = errors.New("unimplemented")
|
||||
func NewStoreV1(opts StoreV1Options) StoreV1 {
|
||||
return &storeV1{
|
||||
SAMSClient: opts.SAMSClient,
|
||||
CodyAccess: opts.DB.CodyAccess(),
|
||||
CodyGatewayEvents: opts.CodyGatewayEvents,
|
||||
}
|
||||
}
|
||||
@ -126,4 +148,34 @@ func (s *storeV1) GetCodyGatewayUsage(ctx context.Context, subscriptionID string
|
||||
applyResult(usage)
|
||||
}
|
||||
return usage, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *storeV1) GetCodyGatewayAccessBySubscription(ctx context.Context, subscriptionID string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) {
|
||||
return s.CodyAccess.CodyGateway().Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
|
||||
SubscriptionID: subscriptionID,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *storeV1) GetCodyGatewayAccessByAccessToken(ctx context.Context, token string) (*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) {
|
||||
// Below is copied from 'func (t dbTokens) LookupProductSubscriptionIDByAccessToken'
|
||||
// in 'cmd/frontend/internal/dotcom/productsubscription'.
|
||||
if !strings.HasPrefix(token, productsubscription.AccessTokenPrefix) &&
|
||||
!strings.HasPrefix(token, license.LicenseKeyBasedAccessTokenPrefix) {
|
||||
return nil, errors.WithSafeDetails(codyaccess.ErrSubscriptionDoesNotExist, "invalid token with unknown prefix")
|
||||
}
|
||||
tokenSansPrefix := token[len(license.LicenseKeyBasedAccessTokenPrefix):]
|
||||
decoded, err := hex.DecodeString(tokenSansPrefix)
|
||||
if err != nil {
|
||||
return nil, errors.WithSafeDetails(codyaccess.ErrSubscriptionDoesNotExist, "invalid token with unknown encoding")
|
||||
}
|
||||
// End copied code.
|
||||
|
||||
return s.CodyAccess.CodyGateway().Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
|
||||
LicenseKeyHash: decoded,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *storeV1) ListCodyGatewayAccesses(ctx context.Context) ([]*codyaccess.CodyGatewayAccessWithSubscriptionDetails, error) {
|
||||
return s.CodyAccess.CodyGateway().List(ctx)
|
||||
}
|
||||
|
||||
@ -95,7 +95,8 @@ func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetai
|
||||
// an active subscription exists, but explicit access is not configured. In
|
||||
// this case we still need to return a valid CodyGatewayAccessWithSubscriptionDetails,
|
||||
// just with empty fields.
|
||||
var maybeEnabled *bool
|
||||
var maybeEnabled sql.NullBool
|
||||
var maybeDisplayName sql.NullString
|
||||
err := row.Scan(
|
||||
&a.SubscriptionID,
|
||||
&maybeEnabled,
|
||||
@ -106,7 +107,7 @@ func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetai
|
||||
&a.EmbeddingsRateLimit,
|
||||
&a.EmbeddingsRateLimitIntervalSeconds,
|
||||
// Subscriptions fields
|
||||
&a.DisplayName,
|
||||
&maybeDisplayName,
|
||||
// License fields
|
||||
&a.ActiveLicenseInfo,
|
||||
&a.LicenseKeyHashes,
|
||||
@ -114,9 +115,9 @@ func scanCodyGatewayAccess(row pgx.Row) (*CodyGatewayAccessWithSubscriptionDetai
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if maybeEnabled != nil {
|
||||
a.Enabled = *maybeEnabled
|
||||
}
|
||||
a.Enabled = maybeEnabled.Bool
|
||||
a.DisplayName = maybeDisplayName.String
|
||||
|
||||
return &a, nil
|
||||
}
|
||||
|
||||
@ -181,27 +182,55 @@ type CodyGatewayAccessWithSubscriptionDetails struct {
|
||||
DisplayName string
|
||||
|
||||
ActiveLicenseInfo *license.Info
|
||||
LicenseKeyHashes [][]byte
|
||||
|
||||
// Used by GenerateAccessTokens
|
||||
LicenseKeyHashes [][]byte
|
||||
}
|
||||
|
||||
var ErrSubscriptionDoesNotExist = errors.New("subscription does not exist")
|
||||
|
||||
type GetCodyGatewayAccessOptions struct {
|
||||
SubscriptionID string
|
||||
LicenseKeyHash []byte
|
||||
}
|
||||
|
||||
func (opts GetCodyGatewayAccessOptions) buildConds() (string, pgx.NamedArgs, error) {
|
||||
if opts.SubscriptionID == "" && len(opts.LicenseKeyHash) == 0 {
|
||||
return "", nil, errors.New("must specify either SubscriptionID or LicenseKeyHash")
|
||||
}
|
||||
|
||||
args := pgx.NamedArgs{}
|
||||
conds := []string{"TRUE"}
|
||||
if opts.SubscriptionID != "" {
|
||||
conds = append(conds, "subscription.id = @subscriptionID")
|
||||
args["subscriptionID"] = opts.SubscriptionID
|
||||
}
|
||||
if len(opts.LicenseKeyHash) > 0 {
|
||||
conds = append(conds, "@licenseKeyHash = ANY(tokens.license_key_hashes)")
|
||||
args["licenseKeyHash"] = opts.LicenseKeyHash
|
||||
}
|
||||
return strings.Join(conds, " AND "), args, nil
|
||||
}
|
||||
|
||||
// Get returns the Cody Gateway access for the given subscription.
|
||||
func (s *CodyGatewayStore) Get(ctx context.Context, subscriptionID string) (*CodyGatewayAccessWithSubscriptionDetails, error) {
|
||||
func (s *CodyGatewayStore) Get(ctx context.Context, opts GetCodyGatewayAccessOptions) (*CodyGatewayAccessWithSubscriptionDetails, error) {
|
||||
conds, args, err := opts.buildConds()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query := fmt.Sprintf(`SELECT
|
||||
%s
|
||||
FROM
|
||||
enterprise_portal_cody_gateway_access AS access
|
||||
%s
|
||||
WHERE
|
||||
subscription.id = @subscriptionID
|
||||
%s
|
||||
AND subscription.archived_at IS NULL`,
|
||||
strings.Join(codyGatewayAccessTableColumns(), ", "),
|
||||
codyGatewayAccessJoinClauses)
|
||||
codyGatewayAccessJoinClauses,
|
||||
conds)
|
||||
|
||||
sub, err := scanCodyGatewayAccess(s.db.QueryRow(ctx, query, pgx.NamedArgs{
|
||||
"subscriptionID": subscriptionID,
|
||||
}))
|
||||
sub, err := scanCodyGatewayAccess(s.db.QueryRow(ctx, query, args))
|
||||
if err != nil {
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
// RIGHT JOIN in query ensures that if we find no result, it's
|
||||
@ -297,5 +326,5 @@ func (s *CodyGatewayStore) Upsert(ctx context.Context, subscriptionID string, op
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return s.Get(ctx, subscriptionID)
|
||||
return s.Get(ctx, GetCodyGatewayAccessOptions{SubscriptionID: subscriptionID})
|
||||
}
|
||||
|
||||
@ -149,7 +149,9 @@ func TestCodyGatewayStore(t *testing.T) {
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = codyaccess.NewCodyGatewayStore(db).Get(ctx, subscriptionID)
|
||||
_, err = codyaccess.NewCodyGatewayStore(db).Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
|
||||
SubscriptionID: subscriptionID,
|
||||
})
|
||||
assert.ErrorIs(t, err, codyaccess.ErrSubscriptionDoesNotExist)
|
||||
})
|
||||
})
|
||||
@ -214,15 +216,36 @@ func CodyGatewayStoreListAndGet(t *testing.T, ctx context.Context, subscriptionI
|
||||
t.Run("Get", func(t *testing.T) {
|
||||
for idx, sub := range subscriptionIDs {
|
||||
t.Run(fmt.Sprintf("idx=%d", idx), func(t *testing.T) {
|
||||
got, err := s.Get(ctx, sub)
|
||||
got, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
|
||||
SubscriptionID: sub,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
assertAccess(idx, got)
|
||||
|
||||
// Reverse lookup by license key hash
|
||||
for _, hash := range got.LicenseKeyHashes {
|
||||
got2, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
|
||||
LicenseKeyHash: hash,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, got2.LicenseKeyHashes, 2) // 2 valid licenses
|
||||
assert.Equal(t, got, got2)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("ErrSubscriptionDoesNotExist", func(t *testing.T) {
|
||||
_, err := s.Get(ctx, uuid.NewString())
|
||||
_, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
|
||||
SubscriptionID: uuid.NewString(),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, codyaccess.ErrSubscriptionDoesNotExist)
|
||||
|
||||
_, err = s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
|
||||
LicenseKeyHash: []byte(uuid.NewString()),
|
||||
})
|
||||
assert.Error(t, err)
|
||||
assert.ErrorIs(t, err, codyaccess.ErrSubscriptionDoesNotExist)
|
||||
})
|
||||
})
|
||||
@ -241,7 +264,9 @@ func CodyGatewayStoreUpsert(t *testing.T, ctx context.Context, subscriptionIDs [
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
got, err := s.Get(ctx, currentAccess.SubscriptionID)
|
||||
got, err := s.Get(ctx, codyaccess.GetCodyGatewayAccessOptions{
|
||||
SubscriptionID: currentAccess.SubscriptionID,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
assert.False(t, got.Enabled)
|
||||
assert.Equal(t, currentAccess.SubscriptionID, got.SubscriptionID)
|
||||
|
||||
@ -20,15 +20,15 @@ var databaseTracer = otel.Tracer("enterprise-portal/internal/database")
|
||||
|
||||
// DB is the database handle for the storage layer.
|
||||
type DB struct {
|
||||
db *pgxpool.Pool
|
||||
DB *pgxpool.Pool
|
||||
}
|
||||
|
||||
func (db *DB) Subscriptions() *subscriptions.Store {
|
||||
return subscriptions.NewStore(db.db)
|
||||
return subscriptions.NewStore(db.DB)
|
||||
}
|
||||
|
||||
func (db *DB) CodyAccess() *codyaccess.Store {
|
||||
return codyaccess.NewStore(db.db)
|
||||
return codyaccess.NewStore(db.DB)
|
||||
}
|
||||
|
||||
func databaseName(msp bool) string {
|
||||
@ -53,11 +53,11 @@ func NewHandle(ctx context.Context, logger log.Logger, contract runtime.Contract
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "get connection pool")
|
||||
}
|
||||
return &DB{db: pool}, nil
|
||||
return &DB{DB: pool}, nil
|
||||
}
|
||||
|
||||
// Close closes all connections in the pool and rejects future Acquire calls.
|
||||
// Blocks until all connections are returned to pool and closed.
|
||||
func (db *DB) Close() {
|
||||
db.db.Close()
|
||||
db.DB.Close()
|
||||
}
|
||||
|
||||
@ -7,6 +7,7 @@ go_library(
|
||||
tags = [TAG_INFRA_CORESERVICES],
|
||||
visibility = ["//cmd/enterprise-portal:__subpackages__"],
|
||||
deps = [
|
||||
"//cmd/enterprise-portal/internal/database/internal/tables",
|
||||
"//cmd/enterprise-portal/internal/database/internal/tables/custommigrator",
|
||||
"//internal/database/dbtest",
|
||||
"@com_github_jackc_pgx_v5//:pgx",
|
||||
|
||||
@ -16,10 +16,13 @@ import (
|
||||
"gorm.io/gorm"
|
||||
"gorm.io/gorm/schema"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/internal/tables/custommigrator"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
|
||||
)
|
||||
|
||||
func Tables(_ *testing.T) []schema.Tabler { return tables.All() }
|
||||
|
||||
// NewTestDB creates a new test database and initializes the given list of
|
||||
// tables for the suite. The test database is dropped after testing is completed
|
||||
// unless failed.
|
||||
@ -37,6 +40,7 @@ func NewTestDB(t testing.TB, system, suite string, tables ...schema.Tabler) *pgx
|
||||
|
||||
// Set up test suite database.
|
||||
dbName := fmt.Sprintf("sourcegraph-test-%s-%s-%d", system, suite, time.Now().Unix())
|
||||
t.Logf("Preparing database %s", dbName)
|
||||
_, err = sqlDB.Exec(fmt.Sprintf("DROP DATABASE IF EXISTS %q", dbName))
|
||||
require.NoError(t, err)
|
||||
|
||||
|
||||
@ -26,7 +26,7 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
type importer struct {
|
||||
type Importer struct {
|
||||
logger log.Logger
|
||||
|
||||
dotcom *dotcomdb.Reader
|
||||
@ -36,14 +36,29 @@ type importer struct {
|
||||
codyGatewayAccess *codyaccess.CodyGatewayStore
|
||||
}
|
||||
|
||||
var _ goroutine.Handler = (*importer)(nil)
|
||||
func NewHandler(
|
||||
ctx context.Context,
|
||||
logger log.Logger,
|
||||
dotcom *dotcomdb.Reader,
|
||||
enterprisePortal *database.DB,
|
||||
) *Importer {
|
||||
return &Importer{
|
||||
logger: logger,
|
||||
dotcom: dotcom,
|
||||
subscriptions: enterprisePortal.Subscriptions(),
|
||||
licenses: enterprisePortal.Subscriptions().Licenses(),
|
||||
codyGatewayAccess: enterprisePortal.CodyAccess().CodyGateway(),
|
||||
}
|
||||
}
|
||||
|
||||
// New returns a periodic goroutine that runs an importer that reconciles
|
||||
// subscriptions, licenses, and Cody Gateway access from dotcom into the
|
||||
// Enterprise Portal database.
|
||||
var _ goroutine.Handler = (*Importer)(nil)
|
||||
|
||||
// NewPeriodicImporter returns a periodic goroutine that runs an importer that
|
||||
// reconciles subscriptions, licenses, and Cody Gateway access from dotcom into
|
||||
// the Enterprise Portal database.
|
||||
//
|
||||
// If interval is 0, the importer is disabled.
|
||||
func New(
|
||||
func NewPeriodicImporter(
|
||||
ctx context.Context,
|
||||
logger log.Logger,
|
||||
dotcom *dotcomdb.Reader,
|
||||
@ -56,13 +71,7 @@ func New(
|
||||
}
|
||||
return goroutine.NewPeriodicGoroutine(
|
||||
ctx,
|
||||
&importer{
|
||||
logger: logger,
|
||||
dotcom: dotcom,
|
||||
subscriptions: enterprisePortal.Subscriptions(),
|
||||
licenses: enterprisePortal.Subscriptions().Licenses(),
|
||||
codyGatewayAccess: enterprisePortal.CodyAccess().CodyGateway(),
|
||||
},
|
||||
NewHandler(ctx, logger, dotcom, enterprisePortal),
|
||||
goroutine.WithOperation(
|
||||
observation.NewContext(logger, observation.Tracer(trace.GetTracer())).
|
||||
Operation(observation.Op{
|
||||
@ -72,7 +81,7 @@ func New(
|
||||
goroutine.WithInterval(interval))
|
||||
}
|
||||
|
||||
func (i *importer) Handle(ctx context.Context) error {
|
||||
func (i *Importer) Handle(ctx context.Context) error {
|
||||
l := trace.Logger(ctx, i.logger)
|
||||
|
||||
dotcomSubscriptions, err := i.dotcom.ListEnterpriseSubscriptions(ctx, dotcomdb.ListEnterpriseSubscriptionsOptions{})
|
||||
@ -101,7 +110,7 @@ func (i *importer) Handle(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.SubscriptionAttributes) (err error) {
|
||||
func (i *Importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.SubscriptionAttributes) (err error) {
|
||||
tr, ctx := trace.New(ctx, "importSubscription",
|
||||
attribute.String("dotcomSub.ID", dotcomSub.ID))
|
||||
defer tr.EndWithErr(&err)
|
||||
@ -227,7 +236,7 @@ func (i *importer) importSubscription(ctx context.Context, dotcomSub *dotcomdb.S
|
||||
return nil
|
||||
}
|
||||
|
||||
func (i *importer) importLicense(ctx context.Context, subscriptionID string, dotcomLicense *dotcomdb.LicenseAttributes) (err error) {
|
||||
func (i *Importer) importLicense(ctx context.Context, subscriptionID string, dotcomLicense *dotcomdb.LicenseAttributes) (err error) {
|
||||
tr, ctx := trace.New(ctx, "importSubscription",
|
||||
attribute.String("dotcomSub.ID", subscriptionID),
|
||||
attribute.String("dotcomLicense.ID", dotcomLicense.ID))
|
||||
|
||||
@ -28,6 +28,10 @@ go_test(
|
||||
],
|
||||
deps = [
|
||||
":dotcomdb",
|
||||
"//cmd/enterprise-portal/internal/codyaccessservice",
|
||||
"//cmd/enterprise-portal/internal/database",
|
||||
"//cmd/enterprise-portal/internal/database/databasetest",
|
||||
"//cmd/enterprise-portal/internal/database/importer",
|
||||
"//cmd/frontend/dotcomproductsubscriptiontest",
|
||||
"//cmd/frontend/graphqlbackend",
|
||||
"//internal/database",
|
||||
@ -35,11 +39,15 @@ go_test(
|
||||
"//internal/database/dbtest",
|
||||
"//internal/license",
|
||||
"//internal/licensing",
|
||||
"//lib/enterpriseportal/codyaccess/v1:codyaccess",
|
||||
"//lib/enterpriseportal/subscriptions/v1:subscriptions",
|
||||
"//lib/pointers",
|
||||
"@com_connectrpc_connect//:connect",
|
||||
"@com_github_jackc_pgx_v4//stdlib",
|
||||
"@com_github_jackc_pgx_v5//pgxpool",
|
||||
"@com_github_sourcegraph_log//logtest",
|
||||
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//:sourcegraph-accounts-sdk-go",
|
||||
"@com_github_sourcegraph_sourcegraph_accounts_sdk_go//scopes",
|
||||
"@com_github_stretchr_testify//assert",
|
||||
"@com_github_stretchr_testify//require",
|
||||
],
|
||||
|
||||
@ -125,10 +125,10 @@ func (c CodyGatewayAccessAttributes) EvaluateRateLimits() CodyGatewayRateLimits
|
||||
Chat: licensing.NewCodyGatewayChatRateLimit(p, c.ActiveLicenseUserCount),
|
||||
|
||||
CodeSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN,
|
||||
Code: licensing.NewCodyGatewayCodeRateLimit(p, c.ActiveLicenseUserCount, c.ActiveLicenseTags),
|
||||
Code: licensing.NewCodyGatewayCodeRateLimit(p, c.ActiveLicenseUserCount),
|
||||
|
||||
EmbeddingsSource: codyaccessv1.CodyGatewayRateLimitSource_CODY_GATEWAY_RATE_LIMIT_SOURCE_PLAN,
|
||||
Embeddings: licensing.NewCodyGatewayEmbeddingsRateLimit(p, c.ActiveLicenseUserCount, c.ActiveLicenseTags),
|
||||
Embeddings: licensing.NewCodyGatewayEmbeddingsRateLimit(p, c.ActiveLicenseUserCount),
|
||||
}
|
||||
|
||||
// Chat
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
pgxstdlibv4 "github.com/jackc/pgx/v4/stdlib"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/stretchr/testify/assert"
|
||||
@ -13,19 +14,28 @@ import (
|
||||
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph-accounts-sdk-go/scopes"
|
||||
|
||||
sams "github.com/sourcegraph/sourcegraph-accounts-sdk-go"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/codyaccessservice"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/databasetest"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/database/importer"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/enterprise-portal/internal/dotcomdb"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/dotcomproductsubscriptiontest"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database"
|
||||
sgdatabase "github.com/sourcegraph/sourcegraph/internal/database"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/dbconn"
|
||||
"github.com/sourcegraph/sourcegraph/internal/database/dbtest"
|
||||
"github.com/sourcegraph/sourcegraph/internal/license"
|
||||
"github.com/sourcegraph/sourcegraph/internal/licensing"
|
||||
v1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
|
||||
codyaccessv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/codyaccess/v1"
|
||||
subscriptionsv1 "github.com/sourcegraph/sourcegraph/lib/enterpriseportal/subscriptions/v1"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
func newTestDotcomReader(t *testing.T, opts dotcomdb.ReaderOptions) (database.DB, *dotcomdb.Reader) {
|
||||
func newTestDotcomReader(t *testing.T, opts dotcomdb.ReaderOptions) (sgdatabase.DB, *dotcomdb.Reader) {
|
||||
ctx := context.Background()
|
||||
|
||||
// Set up a Sourcegraph test database.
|
||||
@ -65,7 +75,23 @@ func newTestDotcomReader(t *testing.T, opts dotcomdb.ReaderOptions) (database.DB
|
||||
r := dotcomdb.NewReader(conn, opts)
|
||||
require.NoError(t, r.Ping(ctx))
|
||||
|
||||
return database.NewDB(logtest.Scoped(t), sgtestdb), r
|
||||
return sgdatabase.NewDB(logtest.Scoped(t), sgtestdb), r
|
||||
}
|
||||
|
||||
type mockCodyAccessV1Store struct{ codyaccessservice.StoreV1 }
|
||||
|
||||
func (mockCodyAccessV1Store) IntrospectSAMSToken(context.Context, string) (*sams.IntrospectTokenResponse, error) {
|
||||
return &sams.IntrospectTokenResponse{
|
||||
Active: true,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
Scopes: scopes.Scopes{scopes.ToScope(scopes.ServiceEnterprisePortal, scopes.PermissionEnterprisePortalCodyAccess, scopes.ActionRead)},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func mockAuthenticatedServiceRequest[T any](message *T) *connect.Request[T] {
|
||||
req := connect.NewRequest(message)
|
||||
req.Header().Set("authorization", "bearer foobar")
|
||||
return req
|
||||
}
|
||||
|
||||
type mockedData struct {
|
||||
@ -76,7 +102,7 @@ type mockedData struct {
|
||||
archivedSubscriptions int
|
||||
}
|
||||
|
||||
func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info license.Info, cgAccess *graphqlbackend.UpdateCodyGatewayAccessInput) mockedData {
|
||||
func setupDBAndInsertMockLicense(t *testing.T, dotcomdb sgdatabase.DB, info license.Info, cgAccess *graphqlbackend.UpdateCodyGatewayAccessInput) mockedData {
|
||||
start := time.Now()
|
||||
|
||||
ctx := context.Background()
|
||||
@ -87,7 +113,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
|
||||
{
|
||||
// Create a different subscription and license that's rubbish,
|
||||
// created at the same time, to ensure we don't use it
|
||||
u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "barbaz"})
|
||||
u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "barbaz"})
|
||||
require.NoError(t, err)
|
||||
sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username)
|
||||
require.NoError(t, err)
|
||||
@ -104,7 +130,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
|
||||
{
|
||||
// Create a different subscription and license that's archived,
|
||||
// created at the same time, to ensure we don't use it
|
||||
u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "archived"})
|
||||
u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "archived"})
|
||||
require.NoError(t, err)
|
||||
sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username)
|
||||
require.NoError(t, err)
|
||||
@ -124,7 +150,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
|
||||
{
|
||||
// Create a different subscription and license that's not a dev tag,
|
||||
// created at the same time, to ensure we don't use it
|
||||
u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "not-dev"})
|
||||
u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "not-dev"})
|
||||
require.NoError(t, err)
|
||||
sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username)
|
||||
require.NoError(t, err)
|
||||
@ -138,7 +164,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
|
||||
}
|
||||
|
||||
// Create the subscription we will assert against
|
||||
u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "user"})
|
||||
u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "user"})
|
||||
require.NoError(t, err)
|
||||
subid, err := subscriptionsdb.Create(ctx, u.ID, u.Username)
|
||||
require.NoError(t, err)
|
||||
@ -150,7 +176,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
|
||||
result.accessTokens = append(result.accessTokens, license.GenerateLicenseKeyBasedAccessToken(key1))
|
||||
_, err = licensesdb.Create(ctx, subid, key1, 2, license.Info{
|
||||
CreatedAt: info.CreatedAt.Add(-time.Hour),
|
||||
ExpiresAt: info.ExpiresAt.Add(-time.Hour),
|
||||
ExpiresAt: info.ExpiresAt.Add(-time.Minute), // should expire first, but not be expired
|
||||
Tags: []string{licensing.DevTag},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
@ -170,7 +196,7 @@ func setupDBAndInsertMockLicense(t *testing.T, dotcomdb database.DB, info licens
|
||||
{
|
||||
// Create another different subscription and license that's also rubbish,
|
||||
// created at the same time, to ensure we don't use it
|
||||
u, err := dotcomdb.Users().Create(ctx, database.NewUser{Username: "foobar"})
|
||||
u, err := dotcomdb.Users().Create(ctx, sgdatabase.NewUser{Username: "foobar"})
|
||||
require.NoError(t, err)
|
||||
sub, err := subscriptionsdb.Create(ctx, u.ID, u.Username)
|
||||
require.NoError(t, err)
|
||||
@ -198,7 +224,7 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
for _, tc := range []struct {
|
||||
for i, tc := range []struct {
|
||||
name string
|
||||
info license.Info
|
||||
cgAccess graphqlbackend.UpdateCodyGatewayAccessInput
|
||||
@ -244,16 +270,37 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) {
|
||||
dotcomdb, dotcomreader := newTestDotcomReader(t, dotcomdb.ReaderOptions{
|
||||
DevOnly: true,
|
||||
})
|
||||
|
||||
// First, set up a subscription and license and some other rubbish
|
||||
// data to ensure we only get the license we want.
|
||||
mock := setupDBAndInsertMockLicense(t, dotcomdb, tc.info, &tc.cgAccess)
|
||||
|
||||
// Now import the data for parity so we can compare against the
|
||||
// Enterprise Portal implementation
|
||||
epDB := &database.DB{
|
||||
DB: databasetest.NewTestDB(t, "ep-dotcomdb", fmt.Sprintf("get-attributes-%d", i), databasetest.Tables(t)...),
|
||||
}
|
||||
err := importer.NewHandler(ctx, logtest.Scoped(t), dotcomreader, epDB).Handle(ctx)
|
||||
require.NoError(t, err)
|
||||
codyAccessService := codyaccessservice.NewHandlerV1(logtest.Scoped(t), mockCodyAccessV1Store{
|
||||
StoreV1: codyaccessservice.NewStoreV1(codyaccessservice.StoreV1Options{
|
||||
DB: epDB,
|
||||
}),
|
||||
})
|
||||
|
||||
t.Run("by subscription ID", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
attr, err := dotcomreader.GetCodyGatewayAccessAttributesBySubscription(ctx, mock.targetSubscriptionID)
|
||||
require.NoError(t, err)
|
||||
validateAccessAttributes(t, dotcomdb, mock, attr, tc.info)
|
||||
access, err := codyAccessService.GetCodyGatewayAccess(ctx, mockAuthenticatedServiceRequest(&codyaccessv1.GetCodyGatewayAccessRequest{
|
||||
Query: &codyaccessv1.GetCodyGatewayAccessRequest_SubscriptionId{
|
||||
SubscriptionId: mock.targetSubscriptionID,
|
||||
},
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
|
||||
validateAccessAttributes(t, dotcomdb, mock, attr, access.Msg.GetAccess(), tc.info)
|
||||
})
|
||||
|
||||
t.Run("by access token", func(t *testing.T) {
|
||||
@ -266,7 +313,13 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) {
|
||||
|
||||
attr, err := dotcomreader.GetCodyGatewayAccessAttributesByAccessToken(ctx, token)
|
||||
require.NoError(t, err)
|
||||
validateAccessAttributes(t, dotcomdb, mock, attr, tc.info)
|
||||
access, err := codyAccessService.GetCodyGatewayAccess(ctx, mockAuthenticatedServiceRequest(&codyaccessv1.GetCodyGatewayAccessRequest{
|
||||
Query: &codyaccessv1.GetCodyGatewayAccessRequest_AccessToken{
|
||||
AccessToken: token,
|
||||
},
|
||||
}))
|
||||
require.NoError(t, err)
|
||||
validateAccessAttributes(t, dotcomdb, mock, attr, access.Msg.GetAccess(), tc.info)
|
||||
|
||||
t.Run("compare with dotcom tokens DB", func(t *testing.T) {
|
||||
subID, err := dotcomproductsubscriptiontest.NewTokensDB(t, dotcomdb).
|
||||
@ -281,11 +334,20 @@ func TestGetCodyGatewayAccessAttributes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func validateAccessAttributes(t *testing.T, dotcomdb database.DB, mock mockedData, attr *dotcomdb.CodyGatewayAccessAttributes, info license.Info) {
|
||||
func validateAccessAttributes(t *testing.T, dotcomdb sgdatabase.DB, mock mockedData, attr *dotcomdb.CodyGatewayAccessAttributes, access *codyaccessv1.CodyGatewayAccess, info license.Info) {
|
||||
assert.Equal(t, mock.targetSubscriptionID, attr.SubscriptionID)
|
||||
assert.Equal(t, subscriptionsv1.EnterpriseSubscriptionIDPrefix+mock.targetSubscriptionID, access.SubscriptionId)
|
||||
|
||||
assert.Equal(t, int(info.UserCount), *attr.ActiveLicenseUserCount)
|
||||
assert.Len(t, attr.LicenseKeyHashes, 2)
|
||||
assert.Equal(t, attr.GenerateAccessTokens(), mock.accessTokens)
|
||||
|
||||
assert.Equal(t, mock.accessTokens, attr.GenerateAccessTokens())
|
||||
var protoAccessTokens []string
|
||||
for _, t := range access.AccessTokens {
|
||||
protoAccessTokens = append(protoAccessTokens, t.GetToken())
|
||||
}
|
||||
assert.ElementsMatch(t, mock.accessTokens, protoAccessTokens)
|
||||
|
||||
limits := attr.EvaluateRateLimits()
|
||||
|
||||
// Validate against the expected values as produced by existing resolvers
|
||||
@ -300,30 +362,37 @@ func validateAccessAttributes(t *testing.T, dotcomdb database.DB, mock mockedDat
|
||||
name string
|
||||
expected graphqlbackend.CodyGatewayRateLimit
|
||||
got licensing.CodyGatewayRateLimit
|
||||
gotProto *codyaccessv1.CodyGatewayRateLimit
|
||||
}{{
|
||||
name: "Chat",
|
||||
expected: mustWithCtx(t, expected.ChatCompletionsRateLimit),
|
||||
got: limits.Chat,
|
||||
gotProto: access.ChatCompletionsRateLimit,
|
||||
}, {
|
||||
name: "Code",
|
||||
expected: mustWithCtx(t, expected.CodeCompletionsRateLimit),
|
||||
got: limits.Code,
|
||||
gotProto: access.CodeCompletionsRateLimit,
|
||||
}, {
|
||||
name: "Embeddings",
|
||||
expected: mustWithCtx(t, expected.EmbeddingsRateLimit),
|
||||
got: limits.Embeddings,
|
||||
gotProto: access.EmbeddingsRateLimit,
|
||||
}} {
|
||||
t.Run(compare.name, func(t *testing.T) {
|
||||
// We only care about limit and interval now
|
||||
assert.Equal(t, int64(compare.expected.Limit()), compare.got.Limit, "Limit")
|
||||
assert.Equal(t, compare.expected.IntervalSeconds(), compare.got.IntervalSeconds, "IntervalSeconds")
|
||||
|
||||
assert.Equal(t, uint64(compare.expected.Limit()), compare.gotProto.Limit, "Limit")
|
||||
assert.Equal(t, int64(compare.expected.IntervalSeconds()), compare.gotProto.IntervalDuration.Seconds, "IntervalSeconds")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetAllCodyGatewayAccessAttributes(t *testing.T) {
|
||||
t.Parallel()
|
||||
dotcomdb, dotcomreader := newTestDotcomReader(t, dotcomdb.ReaderOptions{
|
||||
dotcomDB, dotcomreader := newTestDotcomReader(t, dotcomdb.ReaderOptions{
|
||||
DevOnly: true,
|
||||
})
|
||||
|
||||
@ -334,21 +403,53 @@ func TestGetAllCodyGatewayAccessAttributes(t *testing.T) {
|
||||
Tags: []string{licensing.PlanEnterprise1.Tag(), licensing.DevTag},
|
||||
}
|
||||
cgAccess := graphqlbackend.UpdateCodyGatewayAccessInput{Enabled: pointers.Ptr(true)}
|
||||
mock := setupDBAndInsertMockLicense(t, dotcomdb, info, &cgAccess)
|
||||
mock := setupDBAndInsertMockLicense(t, dotcomDB, info, &cgAccess)
|
||||
|
||||
attrs, err := dotcomreader.GetAllCodyGatewayAccessAttributes(context.Background())
|
||||
// Now import the data for parity so we can compare against the
|
||||
// Enterprise Portal implementation
|
||||
ctx := context.Background()
|
||||
epDB := &database.DB{
|
||||
DB: databasetest.NewTestDB(t, "ep-dotcomdb", "get-all-attributes", databasetest.Tables(t)...),
|
||||
}
|
||||
err := importer.NewHandler(ctx, logtest.Scoped(t), dotcomreader, epDB).Handle(ctx)
|
||||
require.NoError(t, err)
|
||||
codyAccessService := codyaccessservice.NewHandlerV1(logtest.Scoped(t), mockCodyAccessV1Store{
|
||||
StoreV1: codyaccessservice.NewStoreV1(codyaccessservice.StoreV1Options{
|
||||
DB: epDB,
|
||||
}),
|
||||
})
|
||||
|
||||
attrs, err := dotcomreader.GetAllCodyGatewayAccessAttributes(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, attrs, 3) // 3 subscriptions created in setupDBAndInsertMockLicense
|
||||
var found bool
|
||||
|
||||
accesses, err := codyAccessService.ListCodyGatewayAccesses(ctx, mockAuthenticatedServiceRequest(&codyaccessv1.ListCodyGatewayAccessesRequest{}))
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, accesses.Msg.GetAccesses(), 3) // 3 subscriptions created in setupDBAndInsertMockLicense
|
||||
|
||||
var (
|
||||
foundAttr *dotcomdb.CodyGatewayAccessAttributes
|
||||
foundAccess *codyaccessv1.CodyGatewayAccess
|
||||
)
|
||||
for _, attr := range attrs {
|
||||
if attr.SubscriptionID == mock.targetSubscriptionID {
|
||||
found = true
|
||||
validateAccessAttributes(t, dotcomdb, mock, attr, info)
|
||||
foundAttr = attr
|
||||
} else {
|
||||
assert.False(t, attr.CodyGatewayEnabled)
|
||||
}
|
||||
}
|
||||
assert.True(t, found)
|
||||
for _, access := range accesses.Msg.GetAccesses() {
|
||||
if access.SubscriptionId == subscriptionsv1.EnterpriseSubscriptionIDPrefix+mock.targetSubscriptionID {
|
||||
foundAccess = access
|
||||
} else {
|
||||
assert.False(t, access.Enabled)
|
||||
}
|
||||
}
|
||||
require.NotNil(t, foundAttr)
|
||||
require.NotNil(t, foundAccess)
|
||||
|
||||
validateAccessAttributes(t, dotcomDB, mock, foundAttr, foundAccess, info)
|
||||
|
||||
}
|
||||
|
||||
func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
|
||||
@ -378,7 +479,7 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
for _, tc := range []struct {
|
||||
name string
|
||||
filters []*v1.ListEnterpriseSubscriptionLicensesFilter
|
||||
filters []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter
|
||||
pageSize int
|
||||
expect func(t *testing.T, licenses []*dotcomdb.LicenseAttributes)
|
||||
}{{
|
||||
@ -390,8 +491,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "filter by subscription ID",
|
||||
filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{
|
||||
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
|
||||
filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{
|
||||
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
|
||||
SubscriptionId: mock.targetSubscriptionID,
|
||||
},
|
||||
}},
|
||||
@ -403,8 +504,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "filter by subscription ID and limit 1",
|
||||
filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{
|
||||
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
|
||||
filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{
|
||||
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
|
||||
SubscriptionId: mock.targetSubscriptionID,
|
||||
},
|
||||
}},
|
||||
@ -415,12 +516,12 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "filter by subscription ID and not archived",
|
||||
filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{
|
||||
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
|
||||
filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{
|
||||
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_SubscriptionId{
|
||||
SubscriptionId: mock.targetSubscriptionID,
|
||||
},
|
||||
}, {
|
||||
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
|
||||
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
|
||||
IsRevoked: false,
|
||||
},
|
||||
}},
|
||||
@ -434,8 +535,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "filter by is archived",
|
||||
filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{
|
||||
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
|
||||
filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{
|
||||
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
|
||||
IsRevoked: true,
|
||||
},
|
||||
}},
|
||||
@ -444,8 +545,8 @@ func TestListEnterpriseSubscriptionLicenses(t *testing.T) {
|
||||
},
|
||||
}, {
|
||||
name: "filter by not archived",
|
||||
filters: []*v1.ListEnterpriseSubscriptionLicensesFilter{{
|
||||
Filter: &v1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
|
||||
filters: []*subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter{{
|
||||
Filter: &subscriptionsv1.ListEnterpriseSubscriptionLicensesFilter_IsRevoked{
|
||||
IsRevoked: false,
|
||||
},
|
||||
}},
|
||||
|
||||
@ -91,7 +91,7 @@ func (r codyGatewayAccessResolver) CodeCompletionsRateLimit(ctx context.Context)
|
||||
var source graphqlbackend.CodyGatewayRateLimitSource
|
||||
if activeLicense != nil {
|
||||
source = graphqlbackend.CodyGatewayRateLimitSourcePlan
|
||||
rateLimit = licensing.NewCodyGatewayCodeRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount, activeLicense.LicenseTags)
|
||||
rateLimit = licensing.NewCodyGatewayCodeRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount)
|
||||
}
|
||||
|
||||
// Apply overrides
|
||||
@ -131,7 +131,7 @@ func (r codyGatewayAccessResolver) EmbeddingsRateLimit(ctx context.Context) (gra
|
||||
var source graphqlbackend.CodyGatewayRateLimitSource
|
||||
if activeLicense != nil {
|
||||
source = graphqlbackend.CodyGatewayRateLimitSourcePlan
|
||||
rateLimit = licensing.NewCodyGatewayEmbeddingsRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount, activeLicense.LicenseTags)
|
||||
rateLimit = licensing.NewCodyGatewayEmbeddingsRateLimit(licensing.PlanFromTags(activeLicense.LicenseTags), activeLicense.LicenseUserCount)
|
||||
}
|
||||
|
||||
// Apply overrides
|
||||
|
||||
@ -53,7 +53,7 @@ func NewCodyGatewayChatRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit
|
||||
}
|
||||
|
||||
// NewCodyGatewayCodeRateLimit applies default Cody Gateway access based on the plan.
|
||||
func NewCodyGatewayCodeRateLimit(plan Plan, userCount *int, licenseTags []string) CodyGatewayRateLimit {
|
||||
func NewCodyGatewayCodeRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit {
|
||||
uc := 0
|
||||
if userCount != nil {
|
||||
uc = *userCount
|
||||
@ -87,7 +87,7 @@ func NewCodyGatewayCodeRateLimit(plan Plan, userCount *int, licenseTags []string
|
||||
const tokensPerDollar = int(1 / (0.0001 / 1_000))
|
||||
|
||||
// NewCodyGatewayEmbeddingsRateLimit applies default Cody Gateway access based on the plan.
|
||||
func NewCodyGatewayEmbeddingsRateLimit(plan Plan, userCount *int, licenseTags []string) CodyGatewayRateLimit {
|
||||
func NewCodyGatewayEmbeddingsRateLimit(plan Plan, userCount *int) CodyGatewayRateLimit {
|
||||
uc := 0
|
||||
if userCount != nil {
|
||||
uc = *userCount
|
||||
|
||||
@ -104,7 +104,7 @@ func TestCodyGatewayCodeRateLimit(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewCodyGatewayCodeRateLimit(tt.plan, tt.userCount, tt.licenseTags)
|
||||
got := NewCodyGatewayCodeRateLimit(tt.plan, tt.userCount)
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Fatalf("incorrect rate limit computed: %s", diff)
|
||||
}
|
||||
@ -151,7 +151,7 @@ func TestCodyGatewayEmbeddingsRateLimit(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := NewCodyGatewayEmbeddingsRateLimit(tt.plan, tt.userCount, tt.licenseTags)
|
||||
got := NewCodyGatewayEmbeddingsRateLimit(tt.plan, tt.userCount)
|
||||
if diff := cmp.Diff(got, tt.want); diff != "" {
|
||||
t.Fatalf("incorrect rate limit computed: %s", diff)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user