diff --git a/enterprise/internal/licensing/check.go b/enterprise/internal/licensing/check.go index cc0d83d68d1..ea9b5ed9229 100644 --- a/enterprise/internal/licensing/check.go +++ b/enterprise/internal/licensing/check.go @@ -6,6 +6,7 @@ import ( "time" "net/http" + "net/url" "github.com/derision-test/glock" "github.com/sourcegraph/log" @@ -13,6 +14,7 @@ import ( "context" "github.com/sourcegraph/sourcegraph/internal/conf" + "github.com/sourcegraph/sourcegraph/internal/env" "github.com/sourcegraph/sourcegraph/internal/goroutine" "github.com/sourcegraph/sourcegraph/internal/httpcli" "github.com/sourcegraph/sourcegraph/internal/licensing" @@ -23,6 +25,7 @@ import ( var ( licenseCheckStarted = false store = redispool.Store + baseUrl = env.Get("SOURCEGRAPH_API_URL", "https://sourcegraph.com", "Base URL for license check API") ) const ( @@ -41,12 +44,16 @@ type licenseChecker struct { func (l *licenseChecker) Handle(ctx context.Context) error { l.logger.Debug("starting license check", log.String("siteID", l.siteID)) - store.Set(lastCalledAtStoreKey, time.Now().Format(time.RFC3339)) + if err := store.Set(lastCalledAtStoreKey, time.Now().Format(time.RFC3339)); err != nil { + return err + } // skip if has explicitly allowed air-gapped feature if err := Check(FeatureAllowAirGapped); err == nil { l.logger.Debug("license is air-gapped, skipping check", log.String("siteID", l.siteID)) - store.Set(licenseValidityStoreKey, true) + if err := store.Set(licenseValidityStoreKey, true); err != nil { + return err + } return nil } @@ -54,9 +61,11 @@ func (l *licenseChecker) Handle(ctx context.Context) error { if err != nil { return err } - if info.HasTag("dev") { - l.logger.Debug("dev license, skipping license verification check") - store.Set(licenseValidityStoreKey, true) + if info.HasTag("dev") || info.HasTag("internal") { + l.logger.Debug("internal or dev license, skipping license verification check") + if err := store.Set(licenseValidityStoreKey, true); err != nil { + return err + } return nil } @@ -68,7 +77,12 @@ func (l *licenseChecker) Handle(ctx context.Context) error { return err } - req, err := http.NewRequest(http.MethodPost, "https://sourcegraph.com/.api/license/check", bytes.NewBuffer(payload)) + u, err := url.JoinPath(baseUrl, "/.api/license/check") + if err != nil { + return err + } + + req, err := http.NewRequest(http.MethodPost, u, bytes.NewBuffer(payload)) if err != nil { return err } @@ -103,7 +117,9 @@ func (l *licenseChecker) Handle(ctx context.Context) error { return errors.New("No data returned from license check") } - store.Set(licenseValidityStoreKey, body.Data.IsValid) + if err := store.Set(licenseValidityStoreKey, body.Data.IsValid); err != nil { + return err + } l.logger.Debug("finished license check", log.String("siteID", l.siteID)) return nil } diff --git a/enterprise/internal/licensing/check_test.go b/enterprise/internal/licensing/check_test.go index b2cbd6f595a..7027028a9cc 100644 --- a/enterprise/internal/licensing/check_test.go +++ b/enterprise/internal/licensing/check_test.go @@ -6,6 +6,7 @@ import ( "encoding/json" "io" "net/http" + "net/url" "testing" "time" @@ -14,8 +15,10 @@ import ( "github.com/stretchr/testify/require" "github.com/sourcegraph/log/logtest" + "github.com/sourcegraph/sourcegraph/enterprise/internal/license" "github.com/sourcegraph/sourcegraph/internal/redispool" + "github.com/sourcegraph/sourcegraph/lib/pointers" ) func Test_calcDurationToWaitForNextHandle(t *testing.T) { @@ -26,8 +29,8 @@ func Test_calcDurationToWaitForNextHandle(t *testing.T) { }) cleanupStore := func() { - store.Del(licenseValidityStoreKey) - store.Del(lastCalledAtStoreKey) + _ = store.Del(licenseValidityStoreKey) + _ = store.Del(lastCalledAtStoreKey) } now := time.Now().Round(time.Second) @@ -74,7 +77,7 @@ func Test_calcDurationToWaitForNextHandle(t *testing.T) { t.Run(name, func(t *testing.T) { cleanupStore() if test.lastCalledAt != "" { - store.Set(lastCalledAtStoreKey, test.lastCalledAt) + _ = store.Set(lastCalledAtStoreKey, test.lastCalledAt) } got, err := calcDurationSinceLastCalled(clock) @@ -88,6 +91,19 @@ func Test_calcDurationToWaitForNextHandle(t *testing.T) { } } +func mockDotcomURL(t *testing.T, u *string) { + t.Helper() + + origBaseURL := baseUrl + t.Cleanup(func() { + baseUrl = origBaseURL + }) + + if u != nil { + baseUrl = *u + } +} + func Test_licenseChecker(t *testing.T) { // Connect to local redis for testing, this is the same URL used in rcache.SetupForTest store = redispool.NewKeyValue("127.0.0.1:6379", &redis.Pool{ @@ -96,8 +112,8 @@ func Test_licenseChecker(t *testing.T) { }) cleanupStore := func() { - store.Del(licenseValidityStoreKey) - store.Del(lastCalledAtStoreKey) + _ = store.Del(licenseValidityStoreKey) + _ = store.Del(lastCalledAtStoreKey) } siteID := "some-site-id" @@ -159,8 +175,8 @@ func Test_licenseChecker(t *testing.T) { MockGetConfiguredProductLicenseInfo = defaultMockGetLicense }) - store.Del(licenseValidityStoreKey) - store.Del(lastCalledAtStoreKey) + _ = store.Del(licenseValidityStoreKey) + _ = store.Del(lastCalledAtStoreKey) doer := &mockDoer{ status: '1', @@ -195,6 +211,7 @@ func Test_licenseChecker(t *testing.T) { status int want bool err bool + baseUrl *string }{ "returns error if unable to make a request to license server": { response: []byte(`{"error": "some error"}`), @@ -216,12 +233,20 @@ func Test_licenseChecker(t *testing.T) { status: http.StatusOK, want: false, }, + `uses sourcegraph baseURL from env`: { + response: []byte(`{"data": {"is_valid": true}}`), + status: http.StatusOK, + want: true, + baseUrl: pointers.Ptr("https://foo.bar"), + }, } for name, test := range tests { t.Run(name, func(t *testing.T) { cleanupStore() + mockDotcomURL(t, test.baseUrl) + doer := &mockDoer{ status: test.status, response: test.response, @@ -254,9 +279,10 @@ func Test_licenseChecker(t *testing.T) { require.NotEmpty(t, lastCalledAt) // check doer with proper parameters + rUrl, _ := url.JoinPath(baseUrl, "/.api/license/check") require.True(t, doer.DoCalled) require.Equal(t, "POST", doer.Request.Method) - require.Equal(t, "https://sourcegraph.com/.api/license/check", doer.Request.URL.String()) + require.Equal(t, rUrl, doer.Request.URL.String()) require.Equal(t, "application/json", doer.Request.Header.Get("Content-Type")) require.Equal(t, "Bearer "+token, doer.Request.Header.Get("Authorization")) var body struct { @@ -269,8 +295,6 @@ func Test_licenseChecker(t *testing.T) { } } -var strPtr = func(s string) *string { return &s } - type mockDoer struct { DoCalled bool Request *http.Request