diff --git a/cmd/frontend/graphqlbackend/product_license_info.go b/cmd/frontend/graphqlbackend/product_license_info.go index 834b787a536..cf822e4664d 100644 --- a/cmd/frontend/graphqlbackend/product_license_info.go +++ b/cmd/frontend/graphqlbackend/product_license_info.go @@ -48,7 +48,7 @@ func (r ProductLicenseInfo) ProductNameWithBrand() string { } func (r ProductLicenseInfo) IsFreePlan() bool { - return r.Plan.IsFree() + return r.Plan.IsFreePlan() } func (r ProductLicenseInfo) Tags() []string { return r.TagsValue } diff --git a/cmd/frontend/graphqlbackend/product_subscription_status.go b/cmd/frontend/graphqlbackend/product_subscription_status.go index c3d99e6e197..5603af38802 100644 --- a/cmd/frontend/graphqlbackend/product_subscription_status.go +++ b/cmd/frontend/graphqlbackend/product_subscription_status.go @@ -34,7 +34,7 @@ func (productSubscriptionStatus) NoLicenseWarningUserCount(ctx context.Context) } // We only show this warning to free license instances. - if !info.Plan.IsFree() { + if !info.Plan.IsFreePlan() { return nil, nil } @@ -46,7 +46,7 @@ func (productSubscriptionStatus) MaximumAllowedUserCount(ctx context.Context) (* if err != nil { return nil, err } - if !info.Plan.IsFree() { + if !info.Plan.IsFreePlan() { tmp := info.UserCount() return &tmp, nil } diff --git a/cmd/frontend/graphqlbackend/site_flags.go b/cmd/frontend/graphqlbackend/site_flags.go index 2a38c75db17..79d32b7ecb5 100644 --- a/cmd/frontend/graphqlbackend/site_flags.go +++ b/cmd/frontend/graphqlbackend/site_flags.go @@ -55,7 +55,7 @@ func (r *siteResolver) FreeUsersExceeded(ctx context.Context) (bool, error) { return false, err } // Only show alert if the license is a free plan. - if !info.Plan.IsFree() { + if !info.Plan.IsFreePlan() { return false, nil } diff --git a/cmd/frontend/internal/licensing/enforcement/users.go b/cmd/frontend/internal/licensing/enforcement/users.go index 7f9c3f76ba0..723a9397779 100644 --- a/cmd/frontend/internal/licensing/enforcement/users.go +++ b/cmd/frontend/internal/licensing/enforcement/users.go @@ -88,7 +88,7 @@ func NewAfterCreateUserHook() func(context.Context, database.DB, *types.User) er return err } - if info.Plan().IsFree() { + if info.Plan().IsFreePlan() { store := tx.Users() user.SiteAdmin = true if err := store.SetIsSiteAdmin(ctx, user.ID, user.SiteAdmin); err != nil { @@ -122,12 +122,12 @@ func NewBeforeSetUserIsSiteAdmin() func(ctx context.Context, isSiteAdmin bool) e if info.IsExpired() { return errors.New("The Sourcegraph license has expired. No site-admins can be created until the license is updated.") } - if !info.Plan().IsFree() { + if !info.Plan().IsFreePlan() { return nil } // Allow users to be promoted to site admins on the Free plan. - if info.Plan().IsFree() && isSiteAdmin { + if info.Plan().IsFreePlan() && isSiteAdmin { return nil } } diff --git a/cmd/worker/internal/licensecheck/check.go b/cmd/worker/internal/licensecheck/check.go index 7c48fe8feba..bc6d510f079 100644 --- a/cmd/worker/internal/licensecheck/check.go +++ b/cmd/worker/internal/licensecheck/check.go @@ -59,8 +59,8 @@ func (l *licenseChecker) Handle(ctx context.Context) error { if err != nil { return err } - if info.HasTag("dev") || info.HasTag("internal") { - l.logger.Debug("internal or dev license, skipping license verification check") + if info.HasTag("dev") || info.HasTag("internal") || info.Plan().IsFreePlan() { + l.logger.Debug("internal, dev, or free license, skipping license verification check") if err := store.Set(licensing.LicenseValidityStoreKey, true); err != nil { return err } @@ -70,7 +70,6 @@ func (l *licenseChecker) Handle(ctx context.Context) error { payload, err := json.Marshal(struct { ClientSiteID string `json:"siteID"` }{ClientSiteID: l.siteID}) - if err != nil { return err } @@ -155,7 +154,6 @@ func calcDurationSinceLastCalled(clock glock.Clock) (time.Duration, error) { // license validity from dotcom and stores the result in redis. // It re-runs the check if the license key changes. func StartLicenseCheck(originalCtx context.Context, logger log.Logger, db database.DB) { - if licenseCheckStarted { logger.Info("license check already started") return diff --git a/cmd/worker/internal/licensecheck/check_test.go b/cmd/worker/internal/licensecheck/check_test.go index 52191a61e8c..38bac4095eb 100644 --- a/cmd/worker/internal/licensecheck/check_test.go +++ b/cmd/worker/internal/licensecheck/check_test.go @@ -120,92 +120,64 @@ func Test_licenseChecker(t *testing.T) { siteID := "some-site-id" token := "test-token" - t.Run("skips check if license is air-gapped", func(t *testing.T) { - cleanupStore() - var featureChecked licensing.Feature - defaultMock := licensing.MockCheckFeature - licensing.MockCheckFeature = func(feature licensing.Feature) error { - featureChecked = feature - return nil - } - - t.Cleanup(func() { - licensing.MockCheckFeature = defaultMock - }) - - doer := &mockDoer{ - status: '1', - response: []byte(``), - } - handler := licenseChecker{ - siteID: siteID, - token: token, - doer: doer, - logger: logtest.NoOp(t), - } - - err := handler.Handle(context.Background()) - require.NoError(t, err) - - // check feature was checked - require.Equal(t, licensing.FeatureAllowAirGapped, featureChecked) - - // check doer NOT called - require.False(t, doer.DoCalled) - - // check result was set to true - valid, err := store.Get(licensing.LicenseValidityStoreKey).Bool() - require.NoError(t, err) - require.True(t, valid) - - // check last called at was set - lastCalledAt, err := store.Get(lastCalledAtStoreKey).String() - require.NoError(t, err) - require.NotEmpty(t, lastCalledAt) - }) - - t.Run("skips check if license has dev tag", func(t *testing.T) { - defaultMockGetLicense := licensing.MockGetConfiguredProductLicenseInfo - licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, string, error) { - return &license.Info{ + skipTests := map[string]struct { + license *license.Info + }{ + "skips check if license is air gapped": { + license: &license.Info{ + Tags: []string{string(licensing.FeatureAllowAirGapped)}, + }, + }, + "skips check on dev license": { + license: &license.Info{ Tags: []string{"dev"}, - }, "", nil - } + }, + }, + "skips check on free license": { + license: &licensing.GetFreeLicenseInfo().Info, + }, + } - t.Cleanup(func() { - licensing.MockGetConfiguredProductLicenseInfo = defaultMockGetLicense + for name, test := range skipTests { + t.Run(name, func(t *testing.T) { + cleanupStore() + defaultMockGetLicense := licensing.MockGetConfiguredProductLicenseInfo + licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, string, error) { + return test.license, "", nil + } + + t.Cleanup(func() { + licensing.MockGetConfiguredProductLicenseInfo = defaultMockGetLicense + }) + + doer := &mockDoer{ + status: '1', + response: []byte(``), + } + handler := licenseChecker{ + siteID: siteID, + token: token, + doer: doer, + logger: logtest.NoOp(t), + } + + err := handler.Handle(context.Background()) + require.NoError(t, err) + + // check doer NOT called + require.False(t, doer.DoCalled) + + // check result was set to true + valid, err := store.Get(licensing.LicenseValidityStoreKey).Bool() + require.NoError(t, err) + require.True(t, valid) + + // check last called at was set + lastCalledAt, err := store.Get(lastCalledAtStoreKey).String() + require.NoError(t, err) + require.NotEmpty(t, lastCalledAt) }) - - _ = store.Del(licensing.LicenseValidityStoreKey) - _ = store.Del(lastCalledAtStoreKey) - - doer := &mockDoer{ - status: '1', - response: []byte(``), - } - handler := licenseChecker{ - siteID: siteID, - token: token, - doer: doer, - logger: logtest.NoOp(t), - } - - err := handler.Handle(context.Background()) - require.NoError(t, err) - - // check doer NOT called - require.False(t, doer.DoCalled) - - // check result was set to true - valid, err := store.Get(licensing.LicenseValidityStoreKey).Bool() - require.NoError(t, err) - require.True(t, valid) - - // check last called at was set - lastCalledAt, err := store.Get(lastCalledAtStoreKey).String() - require.NoError(t, err) - require.NotEmpty(t, lastCalledAt) - }) + } tests := map[string]struct { response []byte @@ -247,6 +219,13 @@ func Test_licenseChecker(t *testing.T) { for name, test := range tests { t.Run(name, func(t *testing.T) { cleanupStore() + defaultMockGetLicense := licensing.MockGetConfiguredProductLicenseInfo + licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, string, error) { + return &license.Info{Tags: []string{"plan:enterprise-0"}}, "", nil + } + t.Cleanup(func() { + licensing.MockGetConfiguredProductLicenseInfo = defaultMockGetLicense + }) mockDotcomURL(t, test.baseUrl) diff --git a/internal/licensing/plans.go b/internal/licensing/plans.go index 4662b09f98a..4c5f21a2907 100644 --- a/internal/licensing/plans.go +++ b/internal/licensing/plans.go @@ -58,10 +58,6 @@ func (p Plan) isKnown() bool { return false } -func (p Plan) IsFree() bool { - return p == PlanFree0 || p == PlanFree1 -} - // Plan is the pricing plan of the license. func (info *Info) Plan() Plan { return PlanFromTags(info.Tags)