cody-gateway: use subscription account name for usage notification (#54648)

This PR changes the Cody Gateway usage notifications for product
subscriptions to mention the account name that is bound to the product
subscription in a best-effort manner, and fallback in the following
order:

1. Extract the account name from `customer:` license tag if present
2. Use the account username
3. For any reason account username isn't available, continue using the
subscription ID (UUID)

## Test plan

Unit tests are added.

Manual test:

<img width="827" alt="CleanShot 2023-07-05 at 21 51 10@2x"
src="https://github.com/sourcegraph/sourcegraph/assets/2946214/6475f689-9d89-47e9-9a24-8aaa9420267b">

---------

Co-authored-by: Robert Lin <robert@bobheadxi.dev>
This commit is contained in:
Joe Chen 2023-07-06 13:41:42 -04:00 committed by GitHub
parent c2de33541a
commit faabc3a35b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 378 additions and 183 deletions

View File

@ -22,11 +22,14 @@ type Actor struct {
Key string `json:"key"`
// ID is the identifier for this actor's rate-limiting pool. It is not a sensitive
// value. It must be set for all valid actors - if empty, the actor must be invalid
// and must not not have any feature access.
// and must not have any feature access.
//
// For example, for product subscriptions this is the subscription UUID. For
// Sourcegraph.com users, this is the string representation of the user ID.
ID string `json:"id"`
// Name is the human-readable name for this actor, e.g. username, account name.
// Optional for implementations - if unset, ID will be returned from GetName().
Name string `json:"name"`
// AccessEnabled is an evaluated field that summarizes whether or not Cody Gateway access
// is enabled.
//
@ -41,6 +44,24 @@ type Actor struct {
Source Source `json:"-"`
}
func (a *Actor) GetID() string {
return a.ID
}
func (a *Actor) GetName() string {
if a.Name == "" {
return a.ID
}
return a.Name
}
func (a *Actor) GetSource() codygateway.ActorSource {
if a.Source == nil {
return "unknown"
}
return codygateway.ActorSource(a.Source.Name())
}
type contextKey int
const actorKey contextKey = iota
@ -158,7 +179,7 @@ func (a *Actor) Limiter(
UpdateRateLimitTTL: a.LastUpdated != nil && time.Since(*a.LastUpdated) < 5*time.Minute,
NowFunc: time.Now,
RateLimitAlerter: func(ctx context.Context, usageRatio float32, ttl time.Duration) {
rateLimitNotifier(ctx, a.ID, codygateway.ActorSource(a.Source.Name()), feature, usageRatio, ttl)
rateLimitNotifier(ctx, a, feature, usageRatio, ttl)
},
}

View File

@ -28,8 +28,9 @@ func (s *Source) Get(ctx context.Context, token string) (*actor.Actor, error) {
return nil, actor.ErrNotFromSource{}
}
return &actor.Actor{
ID: "anonymous", // TODO: Make this IP-based?
Key: token,
ID: "anonymous", // TODO: Make this IP-based?
Name: "anonymous", // TODO: Make this IP-based?
AccessEnabled: s.allowAnonymous,
// Some basic defaults for chat and code completions.
RateLimits: map[codygateway.Feature]actor.RateLimit{

View File

@ -144,6 +144,7 @@ func newActor(source *Source, cacheKey string, user dotcom.DotcomUserState, conc
a := &actor.Actor{
Key: cacheKey,
ID: userID,
Name: user.Username,
AccessEnabled: userID != "" && user.GetCodyGatewayAccess().Enabled,
RateLimits: zeroRequestsAllowed(),
LastUpdated: &now,

View File

@ -54,11 +54,11 @@ var _ actor.Source = &Source{}
var _ actor.SourceUpdater = &Source{}
var _ actor.SourceSyncer = &Source{}
func NewSource(logger log.Logger, cache httpcache.Cache, dotComClient graphql.Client, internalMode bool, concurrencyConfig codygateway.ActorConcurrencyLimitConfig) *Source {
func NewSource(logger log.Logger, cache httpcache.Cache, dotcomClient graphql.Client, internalMode bool, concurrencyConfig codygateway.ActorConcurrencyLimitConfig) *Source {
return &Source{
log: logger.Scoped("productsubscriptions", "product subscription actor source"),
cache: cache,
dotcom: dotComClient,
dotcom: dotcomClient,
internalMode: internalMode,
@ -216,8 +216,32 @@ func (s *Source) fetchAndCache(ctx context.Context, token string) (*actor.Actor,
return act, nil
}
// getSubscriptionAccountName attempts to get the account name from the product
// subscription. It returns an empty string if no account name is available.
func getSubscriptionAccountName(s dotcom.ProductSubscriptionState) string {
// 1. Check if the special "customer:" tag is present
if s.ActiveLicense != nil && s.ActiveLicense.Info != nil {
for _, tag := range s.ActiveLicense.Info.Tags {
if strings.HasPrefix(tag, "customer:") {
return strings.TrimPrefix(tag, "customer:")
}
}
}
// 2. Use the username of the account
if s.Account != nil && s.Account.Username != "" {
return s.Account.Username
}
return ""
}
// newActor creates an actor from Sourcegraph.com product subscription state.
func newActor(source *Source, token string, s dotcom.ProductSubscriptionState, internalMode bool, concurrencyConfig codygateway.ActorConcurrencyLimitConfig) *actor.Actor {
name := getSubscriptionAccountName(s)
if name == "" {
name = s.Uuid
}
// In internal mode, only allow dev and internal licenses.
disallowedLicense := internalMode &&
(s.ActiveLicense == nil || s.ActiveLicense.Info == nil ||
@ -227,6 +251,7 @@ func newActor(source *Source, token string, s dotcom.ProductSubscriptionState, i
a := &actor.Actor{
Key: token,
ID: s.Uuid,
Name: name,
AccessEnabled: !disallowedLicense && !s.IsArchived && s.CodyGatewayAccess.Enabled,
RateLimits: map[codygateway.Feature]actor.RateLimit{},
LastUpdated: &now,

View File

@ -115,3 +115,44 @@ func TestNewActor(t *testing.T) {
})
}
}
func TestGetSubscriptionAccountName(t *testing.T) {
tests := []struct {
name string
mockUsername string
mockTags []string
wantName string
}{
{
name: "has special license tag",
mockUsername: "alice",
mockTags: []string{"trial", "customer:acme"},
wantName: "acme",
},
{
name: "use account username",
mockUsername: "alice",
mockTags: []string{"plan:enterprise-1"},
wantName: "alice",
},
{
name: "no account name",
wantName: "",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
got := getSubscriptionAccountName(dotcom.ProductSubscriptionState{
Account: &dotcom.ProductSubscriptionStateAccountUser{
Username: test.mockUsername,
},
ActiveLicense: &dotcom.ProductSubscriptionStateActiveLicenseProductLicense{
Info: &dotcom.ProductSubscriptionStateActiveLicenseProductLicenseInfo{
Tags: test.mockTags,
},
},
})
assert.Equal(t, test.wantName, got)
})
}
}

View File

@ -56,7 +56,7 @@ func TestAuthenticatorMiddleware(t *testing.T) {
t.Run("authenticated without cache hit", func(t *testing.T) {
cache := NewMockCache()
client := NewMockClient()
client := dotcom.NewMockClient()
client.MakeRequestFunc.SetDefaultHook(func(_ context.Context, _ *graphql.Request, resp *graphql.Response) error {
resp.Data.(*dotcom.CheckAccessTokenResponse).Dotcom = dotcom.CheckAccessTokenDotcomDotcomQuery{
ProductSubscriptionByAccessToken: dotcom.CheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTokenProductSubscription{
@ -109,7 +109,7 @@ func TestAuthenticatorMiddleware(t *testing.T) {
[]byte(`{"id":"UHJvZHVjdFN1YnNjcmlwdGlvbjoiNjQ1MmE4ZmMtZTY1MC00NWE3LWEwYTItMzU3Zjc3NmIzYjQ2Ig==","accessEnabled":true,"rateLimit":null}`),
true,
)
client := NewMockClient()
client := dotcom.NewMockClient()
next := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
require.NotNil(t, actor.FromContext(r.Context()))
w.WriteHeader(http.StatusOK)
@ -133,7 +133,7 @@ func TestAuthenticatorMiddleware(t *testing.T) {
[]byte(`{"id":"UHJvZHVjdFN1YnNjcmlwdGlvbjoiNjQ1MmE4ZmMtZTY1MC00NWE3LWEwYTItMzU3Zjc3NmIzYjQ2Ig==","accessEnabled":false,"rateLimit":null}`),
true,
)
client := NewMockClient()
client := dotcom.NewMockClient()
w := httptest.NewRecorder()
r := httptest.NewRequest(http.MethodPost, "/", strings.NewReader(`{}`))
@ -148,7 +148,7 @@ func TestAuthenticatorMiddleware(t *testing.T) {
t.Run("access token denied from sources", func(t *testing.T) {
cache := NewMockCache()
client := NewMockClient()
client := dotcom.NewMockClient()
client.MakeRequestFunc.SetDefaultHook(func(_ context.Context, _ *graphql.Request, resp *graphql.Response) error {
return gqlerror.List{
{
@ -171,7 +171,7 @@ func TestAuthenticatorMiddleware(t *testing.T) {
t.Run("server error from sources", func(t *testing.T) {
cache := NewMockCache()
client := NewMockClient()
client := dotcom.NewMockClient()
client.MakeRequestFunc.SetDefaultHook(func(_ context.Context, _ *graphql.Request, resp *graphql.Response) error {
return errors.New("server error")
})
@ -189,7 +189,7 @@ func TestAuthenticatorMiddleware(t *testing.T) {
t.Run("internal mode, authenticated but not dev license", func(t *testing.T) {
cache := NewMockCache()
client := NewMockClient()
client := dotcom.NewMockClient()
client.MakeRequestFunc.SetDefaultHook(func(_ context.Context, _ *graphql.Request, resp *graphql.Response) error {
resp.Data.(*dotcom.CheckAccessTokenResponse).Dotcom = dotcom.CheckAccessTokenDotcomDotcomQuery{
ProductSubscriptionByAccessToken: dotcom.CheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTokenProductSubscription{
@ -238,7 +238,7 @@ func TestAuthenticatorMiddleware(t *testing.T) {
t.Run("internal mode, authenticated dev license", func(t *testing.T) {
cache := NewMockCache()
client := NewMockClient()
client := dotcom.NewMockClient()
client.MakeRequestFunc.SetDefaultHook(func(_ context.Context, _ *graphql.Request, resp *graphql.Response) error {
resp.Data.(*dotcom.CheckAccessTokenResponse).Dotcom = dotcom.CheckAccessTokenDotcomDotcomQuery{
ProductSubscriptionByAccessToken: dotcom.CheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTokenProductSubscription{
@ -287,7 +287,7 @@ func TestAuthenticatorMiddleware(t *testing.T) {
t.Run("internal mode, authenticated internal license", func(t *testing.T) {
cache := NewMockCache()
client := NewMockClient()
client := dotcom.NewMockClient()
client.MakeRequestFunc.SetDefaultHook(func(_ context.Context, _ *graphql.Request, resp *graphql.Response) error {
resp.Data.(*dotcom.CheckAccessTokenResponse).Dotcom = dotcom.CheckAccessTokenDotcomDotcomQuery{
ProductSubscriptionByAccessToken: dotcom.CheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTokenProductSubscription{

View File

@ -7,10 +7,8 @@
package auth
import (
"context"
"sync"
graphql "github.com/Khan/genqlient/graphql"
httpcache "github.com/gregjones/httpcache"
)
@ -390,152 +388,3 @@ func (c CacheSetFuncCall) Args() []interface{} {
func (c CacheSetFuncCall) Results() []interface{} {
return []interface{}{}
}
// MockClient is a mock implementation of the Client interface (from the
// package github.com/Khan/genqlient/graphql) used for unit testing.
type MockClient struct {
// MakeRequestFunc is an instance of a mock function object controlling
// the behavior of the method MakeRequest.
MakeRequestFunc *ClientMakeRequestFunc
}
// NewMockClient creates a new mock of the Client interface. All methods
// return zero values for all results, unless overwritten.
func NewMockClient() *MockClient {
return &MockClient{
MakeRequestFunc: &ClientMakeRequestFunc{
defaultHook: func(context.Context, *graphql.Request, *graphql.Response) (r0 error) {
return
},
},
}
}
// NewStrictMockClient creates a new mock of the Client interface. All
// methods panic on invocation, unless overwritten.
func NewStrictMockClient() *MockClient {
return &MockClient{
MakeRequestFunc: &ClientMakeRequestFunc{
defaultHook: func(context.Context, *graphql.Request, *graphql.Response) error {
panic("unexpected invocation of MockClient.MakeRequest")
},
},
}
}
// NewMockClientFrom creates a new mock of the MockClient interface. All
// methods delegate to the given implementation, unless overwritten.
func NewMockClientFrom(i graphql.Client) *MockClient {
return &MockClient{
MakeRequestFunc: &ClientMakeRequestFunc{
defaultHook: i.MakeRequest,
},
}
}
// ClientMakeRequestFunc describes the behavior when the MakeRequest method
// of the parent MockClient instance is invoked.
type ClientMakeRequestFunc struct {
defaultHook func(context.Context, *graphql.Request, *graphql.Response) error
hooks []func(context.Context, *graphql.Request, *graphql.Response) error
history []ClientMakeRequestFuncCall
mutex sync.Mutex
}
// MakeRequest delegates to the next hook function in the queue and stores
// the parameter and result values of this invocation.
func (m *MockClient) MakeRequest(v0 context.Context, v1 *graphql.Request, v2 *graphql.Response) error {
r0 := m.MakeRequestFunc.nextHook()(v0, v1, v2)
m.MakeRequestFunc.appendCall(ClientMakeRequestFuncCall{v0, v1, v2, r0})
return r0
}
// SetDefaultHook sets function that is called when the MakeRequest method
// of the parent MockClient instance is invoked and the hook queue is empty.
func (f *ClientMakeRequestFunc) SetDefaultHook(hook func(context.Context, *graphql.Request, *graphql.Response) error) {
f.defaultHook = hook
}
// PushHook adds a function to the end of hook queue. Each invocation of the
// MakeRequest method of the parent MockClient instance invokes the hook at
// the front of the queue and discards it. After the queue is empty, the
// default hook function is invoked for any future action.
func (f *ClientMakeRequestFunc) PushHook(hook func(context.Context, *graphql.Request, *graphql.Response) error) {
f.mutex.Lock()
f.hooks = append(f.hooks, hook)
f.mutex.Unlock()
}
// SetDefaultReturn calls SetDefaultHook with a function that returns the
// given values.
func (f *ClientMakeRequestFunc) SetDefaultReturn(r0 error) {
f.SetDefaultHook(func(context.Context, *graphql.Request, *graphql.Response) error {
return r0
})
}
// PushReturn calls PushHook with a function that returns the given values.
func (f *ClientMakeRequestFunc) PushReturn(r0 error) {
f.PushHook(func(context.Context, *graphql.Request, *graphql.Response) error {
return r0
})
}
func (f *ClientMakeRequestFunc) nextHook() func(context.Context, *graphql.Request, *graphql.Response) error {
f.mutex.Lock()
defer f.mutex.Unlock()
if len(f.hooks) == 0 {
return f.defaultHook
}
hook := f.hooks[0]
f.hooks = f.hooks[1:]
return hook
}
func (f *ClientMakeRequestFunc) appendCall(r0 ClientMakeRequestFuncCall) {
f.mutex.Lock()
f.history = append(f.history, r0)
f.mutex.Unlock()
}
// History returns a sequence of ClientMakeRequestFuncCall objects
// describing the invocations of this function.
func (f *ClientMakeRequestFunc) History() []ClientMakeRequestFuncCall {
f.mutex.Lock()
history := make([]ClientMakeRequestFuncCall, len(f.history))
copy(history, f.history)
f.mutex.Unlock()
return history
}
// ClientMakeRequestFuncCall is an object that describes an invocation of
// method MakeRequest on an instance of MockClient.
type ClientMakeRequestFuncCall struct {
// Arg0 is the value of the 1st argument passed to this method
// invocation.
Arg0 context.Context
// Arg1 is the value of the 2nd argument passed to this method
// invocation.
Arg1 *graphql.Request
// Arg2 is the value of the 3rd argument passed to this method
// invocation.
Arg2 *graphql.Response
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 error
}
// Args returns an interface slice containing the arguments of this
// invocation.
func (c ClientMakeRequestFuncCall) Args() []interface{} {
return []interface{}{c.Arg0, c.Arg1, c.Arg2}
}
// Results returns an interface slice containing the results of this
// invocation.
func (c ClientMakeRequestFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}

View File

@ -5,6 +5,7 @@ go_library(
srcs = [
"dotcom.go",
"gen.go",
"mocks.go",
"operations.go",
],
importpath = "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/dotcom",

View File

@ -0,0 +1,163 @@
// Code generated by go-mockgen 1.3.7; DO NOT EDIT.
//
// This file was generated by running `sg generate` (or `go-mockgen`) at the root of
// this repository. To add additional mocks to this or another package, add a new entry
// to the mockgen.yaml file in the root of this repository.
package dotcom
import (
"context"
"sync"
graphql "github.com/Khan/genqlient/graphql"
)
// MockClient is a mock implementation of the Client interface (from the
// package github.com/Khan/genqlient/graphql) used for unit testing.
type MockClient struct {
// MakeRequestFunc is an instance of a mock function object controlling
// the behavior of the method MakeRequest.
MakeRequestFunc *ClientMakeRequestFunc
}
// NewMockClient creates a new mock of the Client interface. All methods
// return zero values for all results, unless overwritten.
func NewMockClient() *MockClient {
return &MockClient{
MakeRequestFunc: &ClientMakeRequestFunc{
defaultHook: func(context.Context, *graphql.Request, *graphql.Response) (r0 error) {
return
},
},
}
}
// NewStrictMockClient creates a new mock of the Client interface. All
// methods panic on invocation, unless overwritten.
func NewStrictMockClient() *MockClient {
return &MockClient{
MakeRequestFunc: &ClientMakeRequestFunc{
defaultHook: func(context.Context, *graphql.Request, *graphql.Response) error {
panic("unexpected invocation of MockClient.MakeRequest")
},
},
}
}
// NewMockClientFrom creates a new mock of the MockClient interface. All
// methods delegate to the given implementation, unless overwritten.
func NewMockClientFrom(i graphql.Client) *MockClient {
return &MockClient{
MakeRequestFunc: &ClientMakeRequestFunc{
defaultHook: i.MakeRequest,
},
}
}
// ClientMakeRequestFunc describes the behavior when the MakeRequest method
// of the parent MockClient instance is invoked.
type ClientMakeRequestFunc struct {
defaultHook func(context.Context, *graphql.Request, *graphql.Response) error
hooks []func(context.Context, *graphql.Request, *graphql.Response) error
history []ClientMakeRequestFuncCall
mutex sync.Mutex
}
// MakeRequest delegates to the next hook function in the queue and stores
// the parameter and result values of this invocation.
func (m *MockClient) MakeRequest(v0 context.Context, v1 *graphql.Request, v2 *graphql.Response) error {
r0 := m.MakeRequestFunc.nextHook()(v0, v1, v2)
m.MakeRequestFunc.appendCall(ClientMakeRequestFuncCall{v0, v1, v2, r0})
return r0
}
// SetDefaultHook sets function that is called when the MakeRequest method
// of the parent MockClient instance is invoked and the hook queue is empty.
func (f *ClientMakeRequestFunc) SetDefaultHook(hook func(context.Context, *graphql.Request, *graphql.Response) error) {
f.defaultHook = hook
}
// PushHook adds a function to the end of hook queue. Each invocation of the
// MakeRequest method of the parent MockClient instance invokes the hook at
// the front of the queue and discards it. After the queue is empty, the
// default hook function is invoked for any future action.
func (f *ClientMakeRequestFunc) PushHook(hook func(context.Context, *graphql.Request, *graphql.Response) error) {
f.mutex.Lock()
f.hooks = append(f.hooks, hook)
f.mutex.Unlock()
}
// SetDefaultReturn calls SetDefaultHook with a function that returns the
// given values.
func (f *ClientMakeRequestFunc) SetDefaultReturn(r0 error) {
f.SetDefaultHook(func(context.Context, *graphql.Request, *graphql.Response) error {
return r0
})
}
// PushReturn calls PushHook with a function that returns the given values.
func (f *ClientMakeRequestFunc) PushReturn(r0 error) {
f.PushHook(func(context.Context, *graphql.Request, *graphql.Response) error {
return r0
})
}
func (f *ClientMakeRequestFunc) nextHook() func(context.Context, *graphql.Request, *graphql.Response) error {
f.mutex.Lock()
defer f.mutex.Unlock()
if len(f.hooks) == 0 {
return f.defaultHook
}
hook := f.hooks[0]
f.hooks = f.hooks[1:]
return hook
}
func (f *ClientMakeRequestFunc) appendCall(r0 ClientMakeRequestFuncCall) {
f.mutex.Lock()
f.history = append(f.history, r0)
f.mutex.Unlock()
}
// History returns a sequence of ClientMakeRequestFuncCall objects
// describing the invocations of this function.
func (f *ClientMakeRequestFunc) History() []ClientMakeRequestFuncCall {
f.mutex.Lock()
history := make([]ClientMakeRequestFuncCall, len(f.history))
copy(history, f.history)
f.mutex.Unlock()
return history
}
// ClientMakeRequestFuncCall is an object that describes an invocation of
// method MakeRequest on an instance of MockClient.
type ClientMakeRequestFuncCall struct {
// Arg0 is the value of the 1st argument passed to this method
// invocation.
Arg0 context.Context
// Arg1 is the value of the 2nd argument passed to this method
// invocation.
Arg1 *graphql.Request
// Arg2 is the value of the 3rd argument passed to this method
// invocation.
Arg2 *graphql.Response
// Result0 is the value of the 1st result returned from this method
// invocation.
Result0 error
}
// Args returns an interface slice containing the arguments of this
// invocation.
func (c ClientMakeRequestFuncCall) Args() []interface{} {
return []interface{}{c.Arg0, c.Arg1, c.Arg2}
}
// Results returns an interface slice containing the results of this
// invocation.
func (c ClientMakeRequestFuncCall) Results() []interface{} {
return []interface{}{c.Result0}
}

View File

@ -49,6 +49,11 @@ func (v *CheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTokenProduc
return v.ProductSubscriptionState.Uuid
}
// GetAccount returns CheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTokenProductSubscription.Account, and is useful for accessing the field via an interface.
func (v *CheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTokenProductSubscription) GetAccount() *ProductSubscriptionStateAccountUser {
return v.ProductSubscriptionState.Account
}
// GetIsArchived returns CheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTokenProductSubscription.IsArchived, and is useful for accessing the field via an interface.
func (v *CheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTokenProductSubscription) GetIsArchived() bool {
return v.ProductSubscriptionState.IsArchived
@ -94,6 +99,8 @@ type __premarshalCheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTok
Uuid string `json:"uuid"`
Account *ProductSubscriptionStateAccountUser `json:"account"`
IsArchived bool `json:"isArchived"`
CodyGatewayAccess ProductSubscriptionStateCodyGatewayAccess `json:"codyGatewayAccess"`
@ -114,6 +121,7 @@ func (v *CheckAccessTokenDotcomDotcomQueryProductSubscriptionByAccessTokenProduc
retval.Id = v.ProductSubscriptionState.Id
retval.Uuid = v.ProductSubscriptionState.Uuid
retval.Account = v.ProductSubscriptionState.Account
retval.IsArchived = v.ProductSubscriptionState.IsArchived
retval.CodyGatewayAccess = v.ProductSubscriptionState.CodyGatewayAccess
retval.ActiveLicense = v.ProductSubscriptionState.ActiveLicense
@ -163,6 +171,11 @@ func (v *CheckDotcomUserAccessTokenDotcomDotcomQueryCodyGatewayDotcomUserByToken
return v.DotcomUserState.Id
}
// GetUsername returns CheckDotcomUserAccessTokenDotcomDotcomQueryCodyGatewayDotcomUserByTokenCodyGatewayDotcomUser.Username, and is useful for accessing the field via an interface.
func (v *CheckDotcomUserAccessTokenDotcomDotcomQueryCodyGatewayDotcomUserByTokenCodyGatewayDotcomUser) GetUsername() string {
return v.DotcomUserState.Username
}
// GetCodyGatewayAccess returns CheckDotcomUserAccessTokenDotcomDotcomQueryCodyGatewayDotcomUserByTokenCodyGatewayDotcomUser.CodyGatewayAccess, and is useful for accessing the field via an interface.
func (v *CheckDotcomUserAccessTokenDotcomDotcomQueryCodyGatewayDotcomUserByTokenCodyGatewayDotcomUser) GetCodyGatewayAccess() DotcomUserStateCodyGatewayAccess {
return v.DotcomUserState.CodyGatewayAccess
@ -196,6 +209,8 @@ func (v *CheckDotcomUserAccessTokenDotcomDotcomQueryCodyGatewayDotcomUserByToken
type __premarshalCheckDotcomUserAccessTokenDotcomDotcomQueryCodyGatewayDotcomUserByTokenCodyGatewayDotcomUser struct {
Id string `json:"id"`
Username string `json:"username"`
CodyGatewayAccess DotcomUserStateCodyGatewayAccess `json:"codyGatewayAccess"`
}
@ -211,6 +226,7 @@ func (v *CheckDotcomUserAccessTokenDotcomDotcomQueryCodyGatewayDotcomUserByToken
var retval __premarshalCheckDotcomUserAccessTokenDotcomDotcomQueryCodyGatewayDotcomUserByTokenCodyGatewayDotcomUser
retval.Id = v.DotcomUserState.Id
retval.Username = v.DotcomUserState.Username
retval.CodyGatewayAccess = v.DotcomUserState.CodyGatewayAccess
return &retval, nil
}
@ -527,6 +543,8 @@ const (
type DotcomUserState struct {
// The id of the user
Id string `json:"id"`
// The user name of the user
Username string `json:"username"`
// Cody Gateway access granted to this user. Properties may be inferred from dotcom site config, or be defined in overrides on the user.
CodyGatewayAccess DotcomUserStateCodyGatewayAccess `json:"codyGatewayAccess"`
}
@ -534,6 +552,9 @@ type DotcomUserState struct {
// GetId returns DotcomUserState.Id, and is useful for accessing the field via an interface.
func (v *DotcomUserState) GetId() string { return v.Id }
// GetUsername returns DotcomUserState.Username, and is useful for accessing the field via an interface.
func (v *DotcomUserState) GetUsername() string { return v.Username }
// GetCodyGatewayAccess returns DotcomUserState.CodyGatewayAccess, and is useful for accessing the field via an interface.
func (v *DotcomUserState) GetCodyGatewayAccess() DotcomUserStateCodyGatewayAccess {
return v.CodyGatewayAccess
@ -644,6 +665,11 @@ func (v *ListProductSubscriptionFields) GetId() string { return v.ProductSubscri
// GetUuid returns ListProductSubscriptionFields.Uuid, and is useful for accessing the field via an interface.
func (v *ListProductSubscriptionFields) GetUuid() string { return v.ProductSubscriptionState.Uuid }
// GetAccount returns ListProductSubscriptionFields.Account, and is useful for accessing the field via an interface.
func (v *ListProductSubscriptionFields) GetAccount() *ProductSubscriptionStateAccountUser {
return v.ProductSubscriptionState.Account
}
// GetIsArchived returns ListProductSubscriptionFields.IsArchived, and is useful for accessing the field via an interface.
func (v *ListProductSubscriptionFields) GetIsArchived() bool {
return v.ProductSubscriptionState.IsArchived
@ -691,6 +717,8 @@ type __premarshalListProductSubscriptionFields struct {
Uuid string `json:"uuid"`
Account *ProductSubscriptionStateAccountUser `json:"account"`
IsArchived bool `json:"isArchived"`
CodyGatewayAccess ProductSubscriptionStateCodyGatewayAccess `json:"codyGatewayAccess"`
@ -712,6 +740,7 @@ func (v *ListProductSubscriptionFields) __premarshalJSON() (*__premarshalListPro
retval.SourcegraphAccessTokens = v.SourcegraphAccessTokens
retval.Id = v.ProductSubscriptionState.Id
retval.Uuid = v.ProductSubscriptionState.Uuid
retval.Account = v.ProductSubscriptionState.Account
retval.IsArchived = v.ProductSubscriptionState.IsArchived
retval.CodyGatewayAccess = v.ProductSubscriptionState.CodyGatewayAccess
retval.ActiveLicense = v.ProductSubscriptionState.ActiveLicense
@ -788,6 +817,11 @@ func (v *ListProductSubscriptionsDotcomDotcomQueryProductSubscriptionsProductSub
return v.ListProductSubscriptionFields.ProductSubscriptionState.Uuid
}
// GetAccount returns ListProductSubscriptionsDotcomDotcomQueryProductSubscriptionsProductSubscriptionConnectionNodesProductSubscription.Account, and is useful for accessing the field via an interface.
func (v *ListProductSubscriptionsDotcomDotcomQueryProductSubscriptionsProductSubscriptionConnectionNodesProductSubscription) GetAccount() *ProductSubscriptionStateAccountUser {
return v.ListProductSubscriptionFields.ProductSubscriptionState.Account
}
// GetIsArchived returns ListProductSubscriptionsDotcomDotcomQueryProductSubscriptionsProductSubscriptionConnectionNodesProductSubscription.IsArchived, and is useful for accessing the field via an interface.
func (v *ListProductSubscriptionsDotcomDotcomQueryProductSubscriptionsProductSubscriptionConnectionNodesProductSubscription) GetIsArchived() bool {
return v.ListProductSubscriptionFields.ProductSubscriptionState.IsArchived
@ -835,6 +869,8 @@ type __premarshalListProductSubscriptionsDotcomDotcomQueryProductSubscriptionsPr
Uuid string `json:"uuid"`
Account *ProductSubscriptionStateAccountUser `json:"account"`
IsArchived bool `json:"isArchived"`
CodyGatewayAccess ProductSubscriptionStateCodyGatewayAccess `json:"codyGatewayAccess"`
@ -856,6 +892,7 @@ func (v *ListProductSubscriptionsDotcomDotcomQueryProductSubscriptionsProductSub
retval.SourcegraphAccessTokens = v.ListProductSubscriptionFields.SourcegraphAccessTokens
retval.Id = v.ListProductSubscriptionFields.ProductSubscriptionState.Id
retval.Uuid = v.ListProductSubscriptionFields.ProductSubscriptionState.Uuid
retval.Account = v.ListProductSubscriptionFields.ProductSubscriptionState.Account
retval.IsArchived = v.ListProductSubscriptionFields.ProductSubscriptionState.IsArchived
retval.CodyGatewayAccess = v.ListProductSubscriptionFields.ProductSubscriptionState.CodyGatewayAccess
retval.ActiveLicense = v.ListProductSubscriptionFields.ProductSubscriptionState.ActiveLicense
@ -907,6 +944,8 @@ type ProductSubscriptionState struct {
// The unique UUID of this product subscription. Unlike ProductSubscription.id, this does not
// encode the type and is not a GraphQL node ID.
Uuid string `json:"uuid"`
// The user (i.e., customer) to whom this subscription is granted, or null if the account has been deleted.
Account *ProductSubscriptionStateAccountUser `json:"account"`
// Whether this product subscription was archived.
IsArchived bool `json:"isArchived"`
// Cody Gateway access granted to this subscription. Properties may be inferred from the active license, or be defined in overrides.
@ -921,6 +960,11 @@ func (v *ProductSubscriptionState) GetId() string { return v.Id }
// GetUuid returns ProductSubscriptionState.Uuid, and is useful for accessing the field via an interface.
func (v *ProductSubscriptionState) GetUuid() string { return v.Uuid }
// GetAccount returns ProductSubscriptionState.Account, and is useful for accessing the field via an interface.
func (v *ProductSubscriptionState) GetAccount() *ProductSubscriptionStateAccountUser {
return v.Account
}
// GetIsArchived returns ProductSubscriptionState.IsArchived, and is useful for accessing the field via an interface.
func (v *ProductSubscriptionState) GetIsArchived() bool { return v.IsArchived }
@ -934,6 +978,18 @@ func (v *ProductSubscriptionState) GetActiveLicense() *ProductSubscriptionStateA
return v.ActiveLicense
}
// ProductSubscriptionStateAccountUser includes the requested fields of the GraphQL type User.
// The GraphQL type's documentation follows.
//
// A user.
type ProductSubscriptionStateAccountUser struct {
// The user's username.
Username string `json:"username"`
}
// GetUsername returns ProductSubscriptionStateAccountUser.Username, and is useful for accessing the field via an interface.
func (v *ProductSubscriptionStateAccountUser) GetUsername() string { return v.Username }
// ProductSubscriptionStateActiveLicenseProductLicense includes the requested fields of the GraphQL type ProductLicense.
// The GraphQL type's documentation follows.
//
@ -1119,6 +1175,9 @@ query CheckAccessToken ($token: String!) {
fragment ProductSubscriptionState on ProductSubscription {
id
uuid
account {
username
}
isArchived
codyGatewayAccess {
... CodyGatewayAccessFields
@ -1185,6 +1244,7 @@ query CheckDotcomUserAccessToken ($token: String!) {
}
fragment DotcomUserState on CodyGatewayDotcomUser {
id
username
codyGatewayAccess {
... CodyGatewayAccessFields
}
@ -1254,6 +1314,9 @@ fragment ListProductSubscriptionFields on ProductSubscription {
fragment ProductSubscriptionState on ProductSubscription {
id
uuid
account {
username
}
isArchived
codyGatewayAccess {
... CodyGatewayAccessFields

View File

@ -21,6 +21,9 @@ fragment CodyGatewayAccessFields on CodyGatewayAccess {
fragment ProductSubscriptionState on ProductSubscription {
id
uuid
account {
username
}
isArchived
codyGatewayAccess {
...CodyGatewayAccessFields
@ -64,6 +67,7 @@ query ListProductSubscriptions {
fragment DotcomUserState on CodyGatewayDotcomUser {
id
username
codyGatewayAccess {
...CodyGatewayAccessFields
}

View File

@ -205,6 +205,6 @@ type listLimitElement struct {
Expiry *time.Time `json:"expiry,omitempty"`
}
func noopRateLimitNotifier(ctx context.Context, actorID string, actorSource codygateway.ActorSource, feature codygateway.Feature, usageRatio float32, ttl time.Duration) {
func noopRateLimitNotifier(ctx context.Context, actor codygateway.Actor, feature codygateway.Feature, usageRatio float32, ttl time.Duration) {
// nothing
}

View File

@ -26,7 +26,7 @@ var tracer = otel.Tracer("internal/notify")
// given thresholds. At most one notification will be sent per actor per
// threshold until the TTL is reached (that clears the counter). It is best to
// align the TTL with the rate limit window.
type RateLimitNotifier func(ctx context.Context, actorID string, actorSource codygateway.ActorSource, feature codygateway.Feature, usageRatio float32, ttl time.Duration)
type RateLimitNotifier func(ctx context.Context, actor codygateway.Actor, feature codygateway.Feature, usageRatio float32, ttl time.Duration)
// Thresholds map actor sources to percentage rate limit usage increments
// to notify on. Each threshold will only trigger the notification once during
@ -55,8 +55,8 @@ func NewSlackRateLimitNotifier(
) RateLimitNotifier {
baseLogger = baseLogger.Scoped("slackRateLimitNotifier", "notifications for usage rate limit approaching thresholds")
return func(ctx context.Context, actorID string, actorSource codygateway.ActorSource, feature codygateway.Feature, usageRatio float32, ttl time.Duration) {
thresholds := actorSourceThresholds.Get(actorSource)
return func(ctx context.Context, actor codygateway.Actor, feature codygateway.Feature, usageRatio float32, ttl time.Duration) {
thresholds := actorSourceThresholds.Get(actor.GetSource())
if len(thresholds) == 0 {
return
}
@ -73,7 +73,7 @@ func NewSlackRateLimitNotifier(
attribute.Float64("alert.ttlSeconds", ttl.Seconds())))
logger := sgtrace.Logger(ctx, baseLogger)
if err := handleNotify(ctx, logger, rs, dotcomURL, thresholds, slackWebhookURL, slackSender, actorID, actorSource, feature, usagePercentage, ttl); err != nil {
if err := handleNotify(ctx, logger, rs, dotcomURL, thresholds, slackWebhookURL, slackSender, actor, feature, usagePercentage, ttl); err != nil {
span.RecordError(err)
logger.Error("failed to notification", log.Error(err))
}
@ -92,15 +92,14 @@ func handleNotify(
slackWebhookURL string,
slackSender func(ctx context.Context, url string, msg *slack.WebhookMessage) error,
actorID string,
actorSource codygateway.ActorSource,
actor codygateway.Actor,
feature codygateway.Feature,
usagePercentage int,
ttl time.Duration,
) error {
span := trace.SpanFromContext(ctx)
lockKey := fmt.Sprintf("rate_limit:%s:alert:lock:%s", feature, actorID)
lockKey := fmt.Sprintf("rate_limit:%s:alert:lock:%s", feature, actor.GetID())
acquired, release, err := redislock.TryAcquire(rs, lockKey, 30*time.Second)
span.SetAttributes(attribute.Bool("lock.acquired", acquired))
if err != nil {
@ -119,7 +118,7 @@ func handleNotify(
}
span.SetAttributes(attribute.Int("bucket", bucket))
key := fmt.Sprintf("rate_limit:%s:alert:%s", feature, actorID)
key := fmt.Sprintf("rate_limit:%s:alert:%s", feature, actor.GetID())
lastBucket, err := rs.Get(key).Int()
if err != nil && err != redis.ErrNil {
return errors.Wrap(err, "failed to get last alert bucket")
@ -140,8 +139,8 @@ func handleNotify(
if slackWebhookURL == "" {
logger.Debug("new usage alert",
log.Object("actor",
log.String("id", actorID),
log.String("source", string(actorSource)),
log.String("id", actor.GetID()),
log.String("source", string(actor.GetSource())),
),
log.String("feature", string(feature)),
log.Int("usagePercentage", usagePercentage),
@ -150,18 +149,18 @@ func handleNotify(
}
var actorLink string
switch actorSource {
switch actor.GetSource() {
case codygateway.ActorSourceProductSubscription:
actorLink = fmt.Sprintf("<%[1]s/site-admin/dotcom/product/subscriptions/%[2]s|%[2]s>", dotcomURL, actorID)
actorLink = fmt.Sprintf("<%s/site-admin/dotcom/product/subscriptions/%s|%s>", dotcomURL, actor.GetID(), actor.GetName())
default:
actorLink = fmt.Sprintf("`%s`", actorID)
actorLink = fmt.Sprintf("`%s`", actor.GetID())
}
span.SetAttributes(
attribute.String("actor.link", actorLink),
attribute.Bool("sendToSlack", true))
text := fmt.Sprintf("The actor %s from %q has exceeded *%d%%* of its rate limit quota for `%s`. The quota will reset in `%s` at `%s`.",
actorLink, actorSource, usagePercentage, feature, ttl.String(), time.Now().Add(ttl).Format(time.RFC3339))
actorLink, actor.GetSource(), usagePercentage, feature, ttl.String(), time.Now().Add(ttl).Format(time.RFC3339))
// NOTE: The context timeout must below the lock timeout we set above (30 seconds
// ) to make sure the lock doesn't expire when we release it, i.e. avoid

View File

@ -27,6 +27,16 @@ func TestThresholds(t *testing.T) {
autogold.Expect([]int{}).Equal(t, th.Get(codygateway.ActorSource("anonymous")))
}
type mockActor struct {
id string
name string
source codygateway.ActorSource
}
func (m *mockActor) GetID() string { return m.id }
func (m *mockActor) GetName() string { return m.name }
func (m *mockActor) GetSource() codygateway.ActorSource { return m.source }
func TestSlackRateLimitNotifier(t *testing.T) {
logger := logtest.NoOp(t)
@ -91,8 +101,11 @@ func TestSlackRateLimitNotifier(t *testing.T) {
)
alerter(context.Background(),
"alice",
codygateway.ActorSourceProductSubscription,
&mockActor{
id: "foobar",
name: "alice",
source: codygateway.ActorSourceProductSubscription,
},
codygateway.FeatureChatCompletions,
test.usageRatio,
time.Minute)

View File

@ -4,6 +4,7 @@ go_library(
name = "codygateway",
srcs = [
"consts.go",
"ifaces.go",
"types.go",
],
importpath = "github.com/sourcegraph/sourcegraph/internal/codygateway",

View File

@ -0,0 +1,11 @@
package codygateway
// Actor represents an actor that is making requests to the Cody Gateway.
type Actor interface {
// GetID returns the unique identifier for this actor.
GetID() string
// GetName returns the human-readable name for this actor.
GetName() string
// GetSource returns the source of this actor.
GetSource() ActorSource
}

View File

@ -16,8 +16,8 @@ var AllFeatures = []Feature{
// NOTE: When you add a new feature here, make sure to add it to the slice above as well.
const (
FeatureCodeCompletions Feature = Feature(types.CompletionsFeatureCode)
FeatureChatCompletions Feature = Feature(types.CompletionsFeatureChat)
FeatureCodeCompletions = Feature(types.CompletionsFeatureCode)
FeatureChatCompletions = Feature(types.CompletionsFeatureChat)
FeatureEmbeddings Feature = "embeddings"
)

View File

@ -385,6 +385,8 @@
- path: github.com/gregjones/httpcache
interfaces:
- Cache
- filename: enterprise/cmd/cody-gateway/internal/dotcom/mocks.go
sources:
- path: github.com/Khan/genqlient/graphql
interfaces:
- Client