mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 18:51:59 +00:00
[license-checks] use dotcom url from env (#53936)
## Test plan Unit tests
This commit is contained in:
parent
122cb76558
commit
2dbfaa16c2
@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/derision-test/glock"
|
||||
"github.com/sourcegraph/log"
|
||||
@ -13,6 +14,7 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/internal/conf"
|
||||
"github.com/sourcegraph/sourcegraph/internal/env"
|
||||
"github.com/sourcegraph/sourcegraph/internal/goroutine"
|
||||
"github.com/sourcegraph/sourcegraph/internal/httpcli"
|
||||
"github.com/sourcegraph/sourcegraph/internal/licensing"
|
||||
@ -23,6 +25,7 @@ import (
|
||||
var (
|
||||
licenseCheckStarted = false
|
||||
store = redispool.Store
|
||||
baseUrl = env.Get("SOURCEGRAPH_API_URL", "https://sourcegraph.com", "Base URL for license check API")
|
||||
)
|
||||
|
||||
const (
|
||||
@ -41,12 +44,16 @@ type licenseChecker struct {
|
||||
|
||||
func (l *licenseChecker) Handle(ctx context.Context) error {
|
||||
l.logger.Debug("starting license check", log.String("siteID", l.siteID))
|
||||
store.Set(lastCalledAtStoreKey, time.Now().Format(time.RFC3339))
|
||||
if err := store.Set(lastCalledAtStoreKey, time.Now().Format(time.RFC3339)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// skip if has explicitly allowed air-gapped feature
|
||||
if err := Check(FeatureAllowAirGapped); err == nil {
|
||||
l.logger.Debug("license is air-gapped, skipping check", log.String("siteID", l.siteID))
|
||||
store.Set(licenseValidityStoreKey, true)
|
||||
if err := store.Set(licenseValidityStoreKey, true); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -54,9 +61,11 @@ func (l *licenseChecker) Handle(ctx context.Context) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if info.HasTag("dev") {
|
||||
l.logger.Debug("dev license, skipping license verification check")
|
||||
store.Set(licenseValidityStoreKey, true)
|
||||
if info.HasTag("dev") || info.HasTag("internal") {
|
||||
l.logger.Debug("internal or dev license, skipping license verification check")
|
||||
if err := store.Set(licenseValidityStoreKey, true); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -68,7 +77,12 @@ func (l *licenseChecker) Handle(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, "https://sourcegraph.com/.api/license/check", bytes.NewBuffer(payload))
|
||||
u, err := url.JoinPath(baseUrl, "/.api/license/check")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req, err := http.NewRequest(http.MethodPost, u, bytes.NewBuffer(payload))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -103,7 +117,9 @@ func (l *licenseChecker) Handle(ctx context.Context) error {
|
||||
return errors.New("No data returned from license check")
|
||||
}
|
||||
|
||||
store.Set(licenseValidityStoreKey, body.Data.IsValid)
|
||||
if err := store.Set(licenseValidityStoreKey, body.Data.IsValid); err != nil {
|
||||
return err
|
||||
}
|
||||
l.logger.Debug("finished license check", log.String("siteID", l.siteID))
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -6,6 +6,7 @@ import (
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -14,8 +15,10 @@ import (
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/sourcegraph/log/logtest"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/license"
|
||||
"github.com/sourcegraph/sourcegraph/internal/redispool"
|
||||
"github.com/sourcegraph/sourcegraph/lib/pointers"
|
||||
)
|
||||
|
||||
func Test_calcDurationToWaitForNextHandle(t *testing.T) {
|
||||
@ -26,8 +29,8 @@ func Test_calcDurationToWaitForNextHandle(t *testing.T) {
|
||||
})
|
||||
|
||||
cleanupStore := func() {
|
||||
store.Del(licenseValidityStoreKey)
|
||||
store.Del(lastCalledAtStoreKey)
|
||||
_ = store.Del(licenseValidityStoreKey)
|
||||
_ = store.Del(lastCalledAtStoreKey)
|
||||
}
|
||||
|
||||
now := time.Now().Round(time.Second)
|
||||
@ -74,7 +77,7 @@ func Test_calcDurationToWaitForNextHandle(t *testing.T) {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
cleanupStore()
|
||||
if test.lastCalledAt != "" {
|
||||
store.Set(lastCalledAtStoreKey, test.lastCalledAt)
|
||||
_ = store.Set(lastCalledAtStoreKey, test.lastCalledAt)
|
||||
}
|
||||
|
||||
got, err := calcDurationSinceLastCalled(clock)
|
||||
@ -88,6 +91,19 @@ func Test_calcDurationToWaitForNextHandle(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func mockDotcomURL(t *testing.T, u *string) {
|
||||
t.Helper()
|
||||
|
||||
origBaseURL := baseUrl
|
||||
t.Cleanup(func() {
|
||||
baseUrl = origBaseURL
|
||||
})
|
||||
|
||||
if u != nil {
|
||||
baseUrl = *u
|
||||
}
|
||||
}
|
||||
|
||||
func Test_licenseChecker(t *testing.T) {
|
||||
// Connect to local redis for testing, this is the same URL used in rcache.SetupForTest
|
||||
store = redispool.NewKeyValue("127.0.0.1:6379", &redis.Pool{
|
||||
@ -96,8 +112,8 @@ func Test_licenseChecker(t *testing.T) {
|
||||
})
|
||||
|
||||
cleanupStore := func() {
|
||||
store.Del(licenseValidityStoreKey)
|
||||
store.Del(lastCalledAtStoreKey)
|
||||
_ = store.Del(licenseValidityStoreKey)
|
||||
_ = store.Del(lastCalledAtStoreKey)
|
||||
}
|
||||
|
||||
siteID := "some-site-id"
|
||||
@ -159,8 +175,8 @@ func Test_licenseChecker(t *testing.T) {
|
||||
MockGetConfiguredProductLicenseInfo = defaultMockGetLicense
|
||||
})
|
||||
|
||||
store.Del(licenseValidityStoreKey)
|
||||
store.Del(lastCalledAtStoreKey)
|
||||
_ = store.Del(licenseValidityStoreKey)
|
||||
_ = store.Del(lastCalledAtStoreKey)
|
||||
|
||||
doer := &mockDoer{
|
||||
status: '1',
|
||||
@ -195,6 +211,7 @@ func Test_licenseChecker(t *testing.T) {
|
||||
status int
|
||||
want bool
|
||||
err bool
|
||||
baseUrl *string
|
||||
}{
|
||||
"returns error if unable to make a request to license server": {
|
||||
response: []byte(`{"error": "some error"}`),
|
||||
@ -216,12 +233,20 @@ func Test_licenseChecker(t *testing.T) {
|
||||
status: http.StatusOK,
|
||||
want: false,
|
||||
},
|
||||
`uses sourcegraph baseURL from env`: {
|
||||
response: []byte(`{"data": {"is_valid": true}}`),
|
||||
status: http.StatusOK,
|
||||
want: true,
|
||||
baseUrl: pointers.Ptr("https://foo.bar"),
|
||||
},
|
||||
}
|
||||
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
cleanupStore()
|
||||
|
||||
mockDotcomURL(t, test.baseUrl)
|
||||
|
||||
doer := &mockDoer{
|
||||
status: test.status,
|
||||
response: test.response,
|
||||
@ -254,9 +279,10 @@ func Test_licenseChecker(t *testing.T) {
|
||||
require.NotEmpty(t, lastCalledAt)
|
||||
|
||||
// check doer with proper parameters
|
||||
rUrl, _ := url.JoinPath(baseUrl, "/.api/license/check")
|
||||
require.True(t, doer.DoCalled)
|
||||
require.Equal(t, "POST", doer.Request.Method)
|
||||
require.Equal(t, "https://sourcegraph.com/.api/license/check", doer.Request.URL.String())
|
||||
require.Equal(t, rUrl, doer.Request.URL.String())
|
||||
require.Equal(t, "application/json", doer.Request.Header.Get("Content-Type"))
|
||||
require.Equal(t, "Bearer "+token, doer.Request.Header.Get("Authorization"))
|
||||
var body struct {
|
||||
@ -269,8 +295,6 @@ func Test_licenseChecker(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
var strPtr = func(s string) *string { return &s }
|
||||
|
||||
type mockDoer struct {
|
||||
DoCalled bool
|
||||
Request *http.Request
|
||||
|
||||
Loading…
Reference in New Issue
Block a user