[license-checks] use dotcom url from env (#53936)

## Test plan

Unit tests
This commit is contained in:
Milan Freml 2023-06-22 14:17:05 +02:00 committed by GitHub
parent 122cb76558
commit 2dbfaa16c2
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 57 additions and 17 deletions

View File

@ -6,6 +6,7 @@ import (
"time"
"net/http"
"net/url"
"github.com/derision-test/glock"
"github.com/sourcegraph/log"
@ -13,6 +14,7 @@ import (
"context"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/env"
"github.com/sourcegraph/sourcegraph/internal/goroutine"
"github.com/sourcegraph/sourcegraph/internal/httpcli"
"github.com/sourcegraph/sourcegraph/internal/licensing"
@ -23,6 +25,7 @@ import (
var (
licenseCheckStarted = false
store = redispool.Store
baseUrl = env.Get("SOURCEGRAPH_API_URL", "https://sourcegraph.com", "Base URL for license check API")
)
const (
@ -41,12 +44,16 @@ type licenseChecker struct {
func (l *licenseChecker) Handle(ctx context.Context) error {
l.logger.Debug("starting license check", log.String("siteID", l.siteID))
store.Set(lastCalledAtStoreKey, time.Now().Format(time.RFC3339))
if err := store.Set(lastCalledAtStoreKey, time.Now().Format(time.RFC3339)); err != nil {
return err
}
// skip if has explicitly allowed air-gapped feature
if err := Check(FeatureAllowAirGapped); err == nil {
l.logger.Debug("license is air-gapped, skipping check", log.String("siteID", l.siteID))
store.Set(licenseValidityStoreKey, true)
if err := store.Set(licenseValidityStoreKey, true); err != nil {
return err
}
return nil
}
@ -54,9 +61,11 @@ func (l *licenseChecker) Handle(ctx context.Context) error {
if err != nil {
return err
}
if info.HasTag("dev") {
l.logger.Debug("dev license, skipping license verification check")
store.Set(licenseValidityStoreKey, true)
if info.HasTag("dev") || info.HasTag("internal") {
l.logger.Debug("internal or dev license, skipping license verification check")
if err := store.Set(licenseValidityStoreKey, true); err != nil {
return err
}
return nil
}
@ -68,7 +77,12 @@ func (l *licenseChecker) Handle(ctx context.Context) error {
return err
}
req, err := http.NewRequest(http.MethodPost, "https://sourcegraph.com/.api/license/check", bytes.NewBuffer(payload))
u, err := url.JoinPath(baseUrl, "/.api/license/check")
if err != nil {
return err
}
req, err := http.NewRequest(http.MethodPost, u, bytes.NewBuffer(payload))
if err != nil {
return err
}
@ -103,7 +117,9 @@ func (l *licenseChecker) Handle(ctx context.Context) error {
return errors.New("No data returned from license check")
}
store.Set(licenseValidityStoreKey, body.Data.IsValid)
if err := store.Set(licenseValidityStoreKey, body.Data.IsValid); err != nil {
return err
}
l.logger.Debug("finished license check", log.String("siteID", l.siteID))
return nil
}

View File

@ -6,6 +6,7 @@ import (
"encoding/json"
"io"
"net/http"
"net/url"
"testing"
"time"
@ -14,8 +15,10 @@ import (
"github.com/stretchr/testify/require"
"github.com/sourcegraph/log/logtest"
"github.com/sourcegraph/sourcegraph/enterprise/internal/license"
"github.com/sourcegraph/sourcegraph/internal/redispool"
"github.com/sourcegraph/sourcegraph/lib/pointers"
)
func Test_calcDurationToWaitForNextHandle(t *testing.T) {
@ -26,8 +29,8 @@ func Test_calcDurationToWaitForNextHandle(t *testing.T) {
})
cleanupStore := func() {
store.Del(licenseValidityStoreKey)
store.Del(lastCalledAtStoreKey)
_ = store.Del(licenseValidityStoreKey)
_ = store.Del(lastCalledAtStoreKey)
}
now := time.Now().Round(time.Second)
@ -74,7 +77,7 @@ func Test_calcDurationToWaitForNextHandle(t *testing.T) {
t.Run(name, func(t *testing.T) {
cleanupStore()
if test.lastCalledAt != "" {
store.Set(lastCalledAtStoreKey, test.lastCalledAt)
_ = store.Set(lastCalledAtStoreKey, test.lastCalledAt)
}
got, err := calcDurationSinceLastCalled(clock)
@ -88,6 +91,19 @@ func Test_calcDurationToWaitForNextHandle(t *testing.T) {
}
}
func mockDotcomURL(t *testing.T, u *string) {
t.Helper()
origBaseURL := baseUrl
t.Cleanup(func() {
baseUrl = origBaseURL
})
if u != nil {
baseUrl = *u
}
}
func Test_licenseChecker(t *testing.T) {
// Connect to local redis for testing, this is the same URL used in rcache.SetupForTest
store = redispool.NewKeyValue("127.0.0.1:6379", &redis.Pool{
@ -96,8 +112,8 @@ func Test_licenseChecker(t *testing.T) {
})
cleanupStore := func() {
store.Del(licenseValidityStoreKey)
store.Del(lastCalledAtStoreKey)
_ = store.Del(licenseValidityStoreKey)
_ = store.Del(lastCalledAtStoreKey)
}
siteID := "some-site-id"
@ -159,8 +175,8 @@ func Test_licenseChecker(t *testing.T) {
MockGetConfiguredProductLicenseInfo = defaultMockGetLicense
})
store.Del(licenseValidityStoreKey)
store.Del(lastCalledAtStoreKey)
_ = store.Del(licenseValidityStoreKey)
_ = store.Del(lastCalledAtStoreKey)
doer := &mockDoer{
status: '1',
@ -195,6 +211,7 @@ func Test_licenseChecker(t *testing.T) {
status int
want bool
err bool
baseUrl *string
}{
"returns error if unable to make a request to license server": {
response: []byte(`{"error": "some error"}`),
@ -216,12 +233,20 @@ func Test_licenseChecker(t *testing.T) {
status: http.StatusOK,
want: false,
},
`uses sourcegraph baseURL from env`: {
response: []byte(`{"data": {"is_valid": true}}`),
status: http.StatusOK,
want: true,
baseUrl: pointers.Ptr("https://foo.bar"),
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
cleanupStore()
mockDotcomURL(t, test.baseUrl)
doer := &mockDoer{
status: test.status,
response: test.response,
@ -254,9 +279,10 @@ func Test_licenseChecker(t *testing.T) {
require.NotEmpty(t, lastCalledAt)
// check doer with proper parameters
rUrl, _ := url.JoinPath(baseUrl, "/.api/license/check")
require.True(t, doer.DoCalled)
require.Equal(t, "POST", doer.Request.Method)
require.Equal(t, "https://sourcegraph.com/.api/license/check", doer.Request.URL.String())
require.Equal(t, rUrl, doer.Request.URL.String())
require.Equal(t, "application/json", doer.Request.Header.Get("Content-Type"))
require.Equal(t, "Bearer "+token, doer.Request.Header.Get("Authorization"))
var body struct {
@ -269,8 +295,6 @@ func Test_licenseChecker(t *testing.T) {
}
}
var strPtr = func(s string) *string { return &s }
type mockDoer struct {
DoCalled bool
Request *http.Request