mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 16:31:47 +00:00
add enterprise/ directory
This commit is contained in:
parent
200fb1bed9
commit
ffd2ccfc84
18
enterprise/.gitignore
vendored
Normal file
18
enterprise/.gitignore
vendored
Normal file
@ -0,0 +1,18 @@
|
||||
cmd/server/dockerfile.go
|
||||
|
||||
# Web
|
||||
.nyc_output/
|
||||
coverage/
|
||||
node_modules/
|
||||
out/
|
||||
package-lock.json
|
||||
package-lock.json
|
||||
puppeteer/
|
||||
src/backend/graphqlschema.ts
|
||||
yarn-error.log
|
||||
/ui/assets/*
|
||||
!/ui/assets/img
|
||||
!/ui/assets/img/*
|
||||
/.bin
|
||||
/cmd/frontend/internal/assets/distassets_vfsdata.go
|
||||
/vendor/.bin
|
||||
17
enterprise/.prettierignore
Normal file
17
enterprise/.prettierignore
Normal file
@ -0,0 +1,17 @@
|
||||
.bin/
|
||||
*.bundle.*
|
||||
client/phabricator/scripts/loader.js
|
||||
cmd/frontend/internal/db/schema.md
|
||||
cmd/xlang-python/python-langserver/
|
||||
package-lock.json
|
||||
package.json
|
||||
ui/assets/
|
||||
vendor/
|
||||
.nyc_output/
|
||||
coverage/
|
||||
out/
|
||||
src/backend/graphqlschema.ts
|
||||
src/schema/
|
||||
ts-node-*
|
||||
xlang/testdata
|
||||
cmd/xlang-go/internal/server/testdata/
|
||||
3
enterprise/.stylelintrc.json
Normal file
3
enterprise/.stylelintrc.json
Normal file
@ -0,0 +1,3 @@
|
||||
{
|
||||
"extends": ["@sourcegraph/stylelint-config"]
|
||||
}
|
||||
12
enterprise/.vscode/settings.json
vendored
Normal file
12
enterprise/.vscode/settings.json
vendored
Normal file
@ -0,0 +1,12 @@
|
||||
{
|
||||
"typescript.tsdk": "node_modules/typescript/lib",
|
||||
"files.associations": {
|
||||
"**/dev/config.json": "jsonc"
|
||||
},
|
||||
"json.schemas": [
|
||||
{
|
||||
"fileMatch": ["dev/config.json"],
|
||||
"url": "/vendor/github.com/sourcegraph/sourcegraph/schema/site.schema.json"
|
||||
}
|
||||
]
|
||||
}
|
||||
78
enterprise/README.dev.md
Normal file
78
enterprise/README.dev.md
Normal file
@ -0,0 +1,78 @@
|
||||
# Sourcegraph development process
|
||||
|
||||
1. Configure your repository with two remotes, `oss` and `ent`, by running the following:
|
||||
|
||||
```bash
|
||||
enterprise/dev/init-repo.sh
|
||||
```
|
||||
|
||||
After running this script, the following should hold:
|
||||
|
||||
- `oss` should point to `https://github.com/sourcegraph/sourcegraph`.
|
||||
- `ent` should point to `https://github.com/sourcegraph/enterprise`.
|
||||
- There should be no `origin` remote.
|
||||
|
||||
2. Decide whether you will branch off the open-source (`oss`) or enterprise (`ent`) repo. When in
|
||||
doubt, prefer `oss`.
|
||||
|
||||
## Developing off `oss`
|
||||
|
||||
1. Run `dev/launch.sh` from the root of this repository.
|
||||
1. Push your branch up to `oss` and open a PR.
|
||||
1. Wait for CI to pass, resolve any merge conflicts.
|
||||
1. Merge the PR into `oss` `master`. This will trigger another PR for you in `ent` that contains the
|
||||
commits you just merged. This PR will be automatically merged if `ent` CI passes and there are no
|
||||
merge conflicts.
|
||||
1. If the `ent` PR cannot be automatically merged, update the PR to resolve any CI errors or merge
|
||||
conflicts.
|
||||
- If you make any changes to OSS code in the `ent` repository, you are responsible for ensuring
|
||||
these changes make it into `oss`.
|
||||
|
||||
## Developing off `ent`
|
||||
|
||||
If a substantial subset of your change can be made as an independent change to `oss`, prefer to do
|
||||
that.
|
||||
|
||||
1. Run `dev/start.sh` from `enterprise` directory.
|
||||
1. Push your branch up to `ent` and open a PR.
|
||||
1. Wait for CI to pass, resolve any merge conflicts.
|
||||
1. Merge the PR into `ent` `master`.
|
||||
1. Cherry-pick your changes onto `oss/master` using `dev/prune-pick.sh`.
|
||||
1. If there are no conflicts, push directly to `oss/master`. If there are conflicts, resolve them,
|
||||
push to a `oss` branch, wait for CI, and then merge into `oss/master`.
|
||||
- If you make any additional changes, you are responsible for syncing these into `ent`.
|
||||
|
||||
### Build notes
|
||||
|
||||
**IMPORTANT:** Commands that build enterprise targets (e.g., `go build`, `yarn`,
|
||||
`enterprise/dev/go-install.sh`) should always be run with the `enterprise` directory as the current
|
||||
working directory. Otherwise, build tools like `yarn` and `go` may try to update the root
|
||||
`package.json` and `go.mod` files as a side effect, instead of updating `enterprise/package.json`
|
||||
and `enterprise/go.mod`.
|
||||
|
||||
The OSS web app is `yarn link`ed into `enterprise/node_modules`. It will run both the build of the
|
||||
enterprise webapp as well as the part of the build for the OSS repo that generates the distributed
|
||||
files for the npm package.
|
||||
|
||||
## Fallback syncing
|
||||
|
||||
Following the above instructions should prevent most long-term divergence between `oss` and
|
||||
`ent`. Divergence between the two can easily be tested by running `dev/git-diff-no-enterprise.sh oss/master ent/master`.
|
||||
|
||||
If `oss` and `ent` diverge too severely to the point where it becomes onerous to cherry-pick
|
||||
specific commits between the two, syncing can be accomplished by merging `oss/master` into
|
||||
`ent/master`:
|
||||
|
||||
```
|
||||
git fetch oss
|
||||
git fetch ent
|
||||
git checkout ent/master -b ent-master
|
||||
git merge oss/master --no-commit -X theirs # this allows us to view an accurate history of oss code in the enterprise repo
|
||||
git checkout oss/master -- . # overwrite all oss code in the enterprise repo with the state from oss repo
|
||||
git reset enterprise && git checkout enterprise
|
||||
git commit -m"Sync oss to ent $(date '+%Y-%m-%d')"
|
||||
git push ent HEAD:master
|
||||
```
|
||||
|
||||
Note that this means that `oss` is the source of truth for all code outside the `enterprise`
|
||||
directory.
|
||||
25
enterprise/README.md
Normal file
25
enterprise/README.md
Normal file
@ -0,0 +1,25 @@
|
||||
# Sourcegraph Enterprise
|
||||
|
||||
[](https://buildkite.com/sourcegraph/enterprise)
|
||||
[](https://codecov.io/gh/sourcegraph/enterprise)
|
||||
[](https://github.com/prettier/prettier)
|
||||
|
||||
This repository contains all of the Sourcegraph Enterprise code.
|
||||
|
||||
## Project layout
|
||||
|
||||
- The main Sourcegraph codebase is open source, see [github.com/sourcegraph/sourcegraph](https://github.com/sourcegraph/sourcegraph).
|
||||
- This codebase just wraps the open source codebase and links in some private code for enterprise features.
|
||||
- Only the enterprise codebase is published to e.g. Docker Hub. Enterprise features are behind paywalls (or good-faith). The open-source codebase is not published on Docker Hub (this avoids confusion and keeps the upgrade/downgrade process from open source <-> enterprise easy).
|
||||
|
||||
## Dev
|
||||
|
||||
See [README.dev.md](README.dev.md).
|
||||
|
||||
### Updating dependencies
|
||||
|
||||
- `go get -u $MODULE` to update `$MODULE` and all its transitive dependencies to their latest version.
|
||||
- `go mod edit -replace $MODULE@$VERSION` to update `$MODULE` to a specific version.
|
||||
- `go mod tidy` if updates to `go.mod` or `go.sum` have been made as a result of other build
|
||||
invocations during development and you wish now to update `go.mod` and `go.sum` to be consistent
|
||||
with how the build will run in CI.
|
||||
17
enterprise/babel.config.js
Normal file
17
enterprise/babel.config.js
Normal file
@ -0,0 +1,17 @@
|
||||
// @ts-check
|
||||
|
||||
/** @type {import('@babel/core').TransformOptions} */
|
||||
const config = {
|
||||
plugins: ['@babel/plugin-syntax-dynamic-import', 'babel-plugin-lodash'],
|
||||
presets: [
|
||||
[
|
||||
'@babel/preset-env',
|
||||
{
|
||||
useBuiltIns: 'entry',
|
||||
modules: false,
|
||||
},
|
||||
],
|
||||
],
|
||||
}
|
||||
|
||||
module.exports = config
|
||||
38
enterprise/cmd/frontend/auth/httpheader/config.go
Normal file
38
enterprise/cmd/frontend/auth/httpheader/config.go
Normal file
@ -0,0 +1,38 @@
|
||||
package httpheader
|
||||
|
||||
import (
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
// getProviderConfig returns the HTTP header auth provider config. At most 1 can be specified in
|
||||
// site config; if there is more than 1, it returns multiple == true (which the caller should handle
|
||||
// by returning an error and refusing to proceed with auth).
|
||||
func getProviderConfig() (pc *schema.HTTPHeaderAuthProvider, multiple bool) {
|
||||
for _, p := range conf.Get().AuthProviders {
|
||||
if p.HttpHeader != nil {
|
||||
if pc != nil {
|
||||
return pc, true // multiple http-header auth providers
|
||||
}
|
||||
pc = p.HttpHeader
|
||||
}
|
||||
}
|
||||
return pc, false
|
||||
}
|
||||
|
||||
func init() {
|
||||
conf.ContributeValidator(validateConfig)
|
||||
}
|
||||
|
||||
func validateConfig(c schema.SiteConfiguration) (problems []string) {
|
||||
var httpHeaderAuthProviders int
|
||||
for _, p := range c.AuthProviders {
|
||||
if p.HttpHeader != nil {
|
||||
httpHeaderAuthProviders++
|
||||
}
|
||||
}
|
||||
if httpHeaderAuthProviders >= 2 {
|
||||
problems = append(problems, `at most 1 http-header auth provider may be used`)
|
||||
}
|
||||
return problems
|
||||
}
|
||||
38
enterprise/cmd/frontend/auth/httpheader/config_test.go
Normal file
38
enterprise/cmd/frontend/auth/httpheader/config_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package httpheader
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
func TestValidateCustom(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
input schema.SiteConfiguration
|
||||
wantProblems []string
|
||||
}{
|
||||
"single": {
|
||||
input: schema.SiteConfiguration{
|
||||
AuthProviders: []schema.AuthProviders{
|
||||
{HttpHeader: &schema.HTTPHeaderAuthProvider{Type: "http-header"}},
|
||||
},
|
||||
},
|
||||
wantProblems: nil,
|
||||
},
|
||||
"multiple": {
|
||||
input: schema.SiteConfiguration{
|
||||
AuthProviders: []schema.AuthProviders{
|
||||
{HttpHeader: &schema.HTTPHeaderAuthProvider{Type: "http-header"}},
|
||||
{HttpHeader: &schema.HTTPHeaderAuthProvider{Type: "http-header"}},
|
||||
},
|
||||
},
|
||||
wantProblems: []string{"at most 1"},
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
conf.TestValidator(t, test.input, validateConfig, test.wantProblems)
|
||||
})
|
||||
}
|
||||
}
|
||||
49
enterprise/cmd/frontend/auth/httpheader/config_watch.go
Normal file
49
enterprise/cmd/frontend/auth/httpheader/config_watch.go
Normal file
@ -0,0 +1,49 @@
|
||||
package httpheader
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sync"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
// Watch for configuration changes related to the http-header auth provider.
|
||||
func init() {
|
||||
var (
|
||||
init = true
|
||||
|
||||
mu sync.Mutex
|
||||
pc *schema.HTTPHeaderAuthProvider
|
||||
pi auth.Provider
|
||||
)
|
||||
conf.Watch(func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Only react when the config changes.
|
||||
newPC, _ := getProviderConfig()
|
||||
if reflect.DeepEqual(newPC, pc) {
|
||||
return
|
||||
}
|
||||
|
||||
if !init {
|
||||
log15.Info("Reloading changed http-header authentication provider configuration.")
|
||||
}
|
||||
updates := map[auth.Provider]bool{}
|
||||
var newPI auth.Provider
|
||||
if newPC != nil {
|
||||
newPI = &provider{c: newPC}
|
||||
updates[newPI] = true
|
||||
}
|
||||
if pi != nil {
|
||||
updates[pi] = false
|
||||
}
|
||||
auth.UpdateProviders(updates)
|
||||
pc = newPC
|
||||
pi = newPI
|
||||
})
|
||||
init = false
|
||||
}
|
||||
96
enterprise/cmd/frontend/auth/httpheader/middleware.go
Normal file
96
enterprise/cmd/frontend/auth/httpheader/middleware.go
Normal file
@ -0,0 +1,96 @@
|
||||
package httpheader
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/actor"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
const providerType = "http-header"
|
||||
|
||||
// Middleware is the same for both the app and API because the HTTP proxy is assumed to wrap
|
||||
// requests to both the app and API and add headers.
|
||||
//
|
||||
// See the "func middleware" docs for more information.
|
||||
var Middleware = &auth.Middleware{
|
||||
API: middleware,
|
||||
App: middleware,
|
||||
}
|
||||
|
||||
// middleware is middleware that checks for an HTTP header from an auth proxy that specifies the
|
||||
// client's authenticated username. It's for use with auth proxies like
|
||||
// https://github.com/bitly/oauth2_proxy and is configured with the http-header auth provider in
|
||||
// site config.
|
||||
//
|
||||
// TESTING: Use the testproxy test program to test HTTP auth proxy behavior. For example, run `go
|
||||
// run cmd/frontend/external/auth/httpheader/testproxy.go -username=alice` then go to
|
||||
// http://localhost:4080. See `-h` for flag help.
|
||||
//
|
||||
// 🚨 SECURITY
|
||||
func middleware(next http.Handler) http.Handler {
|
||||
const misconfiguredMessage = "Misconfigured http-header auth provider."
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authProvider, multiple := getProviderConfig()
|
||||
if multiple {
|
||||
log15.Error("At most 1 HTTP header auth provider may be set in site config.")
|
||||
http.Error(w, misconfiguredMessage, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if authProvider == nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
if authProvider.UsernameHeader == "" {
|
||||
log15.Error("No HTTP header set for username (set the http-header auth provider's usernameHeader property).")
|
||||
http.Error(w, "misconfigured http-header auth provider", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
headerValue := r.Header.Get(authProvider.UsernameHeader)
|
||||
// Continue onto next auth provider if no header is set (in case the auth proxy allows
|
||||
// unauthenticated users to bypass it, which some do). Also respect already authenticated
|
||||
// actors (e.g., via access token).
|
||||
//
|
||||
// It would NOT add any additional security to return an error here, because a user who can
|
||||
// access this HTTP endpoint directly can just as easily supply a fake username whose
|
||||
// identity to assume.
|
||||
if headerValue == "" || actor.FromContext(r.Context()).IsAuthenticated() {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// License check.
|
||||
if !licensing.IsFeatureEnabledLenient(licensing.FeatureExternalAuthProvider) {
|
||||
licensing.WriteSubscriptionErrorResponseForFeature(w, "http-header user authentication (SSO)")
|
||||
return
|
||||
}
|
||||
|
||||
// Otherwise, get or create the user and proceed with the authenticated request.
|
||||
username, err := auth.NormalizeUsername(headerValue)
|
||||
if err != nil {
|
||||
log15.Error("Error normalizing username from HTTP auth proxy.", "username", headerValue, "err", err)
|
||||
http.Error(w, "unable to normalize username", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
userID, safeErrMsg, err := auth.CreateOrUpdateUser(r.Context(), db.NewUser{Username: username}, db.ExternalAccountSpec{
|
||||
ServiceType: providerType,
|
||||
|
||||
// Store headerValue, not normalized username, to prevent two users with distinct
|
||||
// pre-normalization usernames from being merged into the same normalized username
|
||||
// (and therefore letting them each impersonate the other).
|
||||
AccountID: headerValue,
|
||||
}, db.ExternalAccountData{})
|
||||
if err != nil {
|
||||
log15.Error("unable to get/create user from SSO header", "header", authProvider.UsernameHeader, "headerValue", headerValue, "err", err, "userErr", safeErrMsg)
|
||||
http.Error(w, safeErrMsg, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
r = r.WithContext(actor.WithActor(r.Context(), &actor.Actor{UID: userID}))
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
117
enterprise/cmd/frontend/auth/httpheader/middleware_test.go
Normal file
117
enterprise/cmd/frontend/auth/httpheader/middleware_test.go
Normal file
@ -0,0 +1,117 @@
|
||||
package httpheader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
"github.com/sourcegraph/enterprise/pkg/license"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/actor"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
// SEE ALSO FOR MANUAL TESTING: See the Middleware docstring for information about the testproxy
|
||||
// helper program, which helps with manual testing of the HTTP auth proxy behavior.
|
||||
func TestMiddleware(t *testing.T) {
|
||||
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
|
||||
return &license.Info{Tags: licensing.EnterpriseTags}, nil
|
||||
}
|
||||
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
|
||||
|
||||
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
actor := actor.FromContext(r.Context())
|
||||
if actor.IsAuthenticated() {
|
||||
fmt.Fprintf(w, "user %v", actor.UID)
|
||||
} else {
|
||||
fmt.Fprint(w, "no user")
|
||||
}
|
||||
}))
|
||||
|
||||
const headerName = "x-sso-user-header"
|
||||
conf.Mock(&schema.SiteConfiguration{AuthProviders: []schema.AuthProviders{{HttpHeader: &schema.HTTPHeaderAuthProvider{UsernameHeader: headerName}}}})
|
||||
defer conf.Mock(nil)
|
||||
|
||||
t.Run("not sent", func(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if got, want := rr.Body.String(), "no user"; got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("not sent, actor present", func(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req = req.WithContext(actor.WithActor(context.Background(), &actor.Actor{UID: 123}))
|
||||
handler.ServeHTTP(rr, req)
|
||||
if got, want := rr.Body.String(), "user 123"; got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sent, user", func(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.Header.Set(headerName, "alice")
|
||||
var calledMock bool
|
||||
auth.SetMockCreateOrUpdateUser(func(u db.NewUser, a db.ExternalAccountSpec) (userID int32, err error) {
|
||||
calledMock = true
|
||||
if a.ServiceType == "http-header" && a.ServiceID == "" && a.ClientID == "" && a.AccountID == "alice" {
|
||||
return 1, nil
|
||||
}
|
||||
return 0, fmt.Errorf("account %v not found in mock", a)
|
||||
})
|
||||
defer auth.SetMockCreateOrUpdateUser(nil)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if got, want := rr.Body.String(), "user 1"; got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
if !calledMock {
|
||||
t.Error("!calledMock")
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sent, actor already set", func(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.Header.Set(headerName, "alice")
|
||||
req = req.WithContext(actor.WithActor(context.Background(), &actor.Actor{UID: 123}))
|
||||
handler.ServeHTTP(rr, req)
|
||||
if got, want := rr.Body.String(), "user 123"; got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("sent, with un-normalized username", func(t *testing.T) {
|
||||
rr := httptest.NewRecorder()
|
||||
req, _ := http.NewRequest("GET", "/", nil)
|
||||
req.Header.Set(headerName, "alice.zhao")
|
||||
const wantNormalizedUsername = "alice-zhao"
|
||||
var calledMock bool
|
||||
auth.SetMockCreateOrUpdateUser(func(u db.NewUser, a db.ExternalAccountSpec) (userID int32, err error) {
|
||||
calledMock = true
|
||||
if u.Username != wantNormalizedUsername {
|
||||
t.Errorf("got %q, want %q", u.Username, wantNormalizedUsername)
|
||||
}
|
||||
if a.ServiceType == "http-header" && a.ServiceID == "" && a.ClientID == "" && a.AccountID == "alice.zhao" {
|
||||
return 1, nil
|
||||
}
|
||||
return 0, fmt.Errorf("account %v not found in mock", a)
|
||||
})
|
||||
defer auth.SetMockCreateOrUpdateUser(nil)
|
||||
handler.ServeHTTP(rr, req)
|
||||
if got, want := rr.Body.String(), "user 1"; got != want {
|
||||
t.Errorf("got %q, want %q", got, want)
|
||||
}
|
||||
if !calledMock {
|
||||
t.Error("!calledMock")
|
||||
}
|
||||
})
|
||||
}
|
||||
30
enterprise/cmd/frontend/auth/httpheader/provider.go
Normal file
30
enterprise/cmd/frontend/auth/httpheader/provider.go
Normal file
@ -0,0 +1,30 @@
|
||||
package httpheader
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/textproto"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
type provider struct {
|
||||
c *schema.HTTPHeaderAuthProvider
|
||||
}
|
||||
|
||||
// ConfigID implements auth.Provider.
|
||||
func (provider) ConfigID() auth.ProviderConfigID { return auth.ProviderConfigID{Type: providerType} }
|
||||
|
||||
// Config implements auth.Provider.
|
||||
func (p provider) Config() schema.AuthProviders { return schema.AuthProviders{HttpHeader: p.c} }
|
||||
|
||||
// Refresh implements auth.Provider.
|
||||
func (p provider) Refresh(context.Context) error { return nil }
|
||||
|
||||
// CachedInfo implements auth.Provider.
|
||||
func (p provider) CachedInfo() *auth.ProviderInfo {
|
||||
return &auth.ProviderInfo{
|
||||
DisplayName: fmt.Sprintf("HTTP authentication proxy (%q header)", textproto.CanonicalMIMEHeaderKey(p.c.UsernameHeader)),
|
||||
}
|
||||
}
|
||||
47
enterprise/cmd/frontend/auth/httpheader/testproxy.go
Normal file
47
enterprise/cmd/frontend/auth/httpheader/testproxy.go
Normal file
@ -0,0 +1,47 @@
|
||||
// The testproxy command runs a simple HTTP proxy that wraps a Sourcegraph server running with the
|
||||
// http-header auth provider to test the authentication HTTP proxy support.
|
||||
|
||||
// +build ignore
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"os"
|
||||
)
|
||||
|
||||
var (
|
||||
addr = flag.String("addr", ":4080", "HTTP listen address")
|
||||
urlStr = flag.String("url", "http://localhost:3080", "proxy origin URL (Sourcegraph HTTP/HTTPS URL)") // CI:LOCALHOST_OK
|
||||
username = flag.String("username", os.Getenv("USER"), "username to report to Sourcegraph")
|
||||
httpHeader = flag.String("header", "X-Forwarded-User", "name of HTTP header to add to request")
|
||||
)
|
||||
|
||||
func main() {
|
||||
flag.Parse()
|
||||
log.SetFlags(0)
|
||||
|
||||
url, err := url.Parse(*urlStr)
|
||||
if err != nil {
|
||||
log.Fatalf("Error: Invalid -url: %s.", err)
|
||||
}
|
||||
if *username == "" {
|
||||
log.Fatal("Error: No -username specified.")
|
||||
}
|
||||
if *httpHeader == "" {
|
||||
log.Fatal("Error: No -header specified.")
|
||||
}
|
||||
log.Printf(`Listening on %s, forwarding requests to %s with added header "%s: %s"`, *addr, url, *httpHeader, *username)
|
||||
p := httputil.NewSingleHostReverseProxy(url)
|
||||
log.Fatalf("Server error: %s.", http.ListenAndServe(*addr, &httputil.ReverseProxy{
|
||||
Director: func(r *http.Request) {
|
||||
r.Header.Set(*httpHeader, *username)
|
||||
r.Host = url.Host
|
||||
p.Director(r)
|
||||
},
|
||||
}))
|
||||
}
|
||||
85
enterprise/cmd/frontend/auth/init.go
Normal file
85
enterprise/cmd/frontend/auth/init.go
Normal file
@ -0,0 +1,85 @@
|
||||
// Package auth is imported for side-effects to enable enterprise-only SSO.
|
||||
package auth
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/auth/httpheader"
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/auth/openidconnect"
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/auth/saml"
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/app"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Register enterprise auth middleware
|
||||
auth.RegisterMiddlewares(
|
||||
openidconnect.Middleware,
|
||||
saml.Middleware,
|
||||
httpheader.Middleware,
|
||||
)
|
||||
// Register app-level sign-out handler
|
||||
app.RegisterSSOSignOutHandler(ssoSignOutHandler)
|
||||
}
|
||||
|
||||
func ssoSignOutHandler(w http.ResponseWriter, r *http.Request) (signOutURLs []app.SignOutURL) {
|
||||
for _, p := range conf.Get().AuthProviders {
|
||||
var e app.SignOutURL
|
||||
var err error
|
||||
switch {
|
||||
case p.Openidconnect != nil:
|
||||
e.ProviderDisplayName = p.Openidconnect.DisplayName
|
||||
e.ProviderServiceType = p.Openidconnect.Type
|
||||
e.URL, err = openidconnect.SignOut(w, r)
|
||||
case p.Saml != nil:
|
||||
e.ProviderDisplayName = p.Saml.DisplayName
|
||||
e.ProviderServiceType = p.Saml.Type
|
||||
e.URL, err = saml.SignOut(w, r)
|
||||
}
|
||||
if e.URL != "" {
|
||||
signOutURLs = append(signOutURLs, e)
|
||||
}
|
||||
if err != nil {
|
||||
log15.Error("Error clearing auth provider session data.", "err", err)
|
||||
}
|
||||
}
|
||||
|
||||
return signOutURLs
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Warn about usage of auth providers that are not enabled by the license.
|
||||
graphqlbackend.AlertFuncs = append(graphqlbackend.AlertFuncs, func(args graphqlbackend.AlertFuncArgs) []*graphqlbackend.Alert {
|
||||
// Only site admins can act on this alert, so only show it to site admins.
|
||||
if !args.IsSiteAdmin {
|
||||
return nil
|
||||
}
|
||||
|
||||
if licensing.IsFeatureEnabledLenient(licensing.FeatureExternalAuthProvider) {
|
||||
return nil
|
||||
}
|
||||
|
||||
var externalAuthProviderTypes []string
|
||||
for _, p := range conf.Get().AuthProviders {
|
||||
if p.Builtin == nil {
|
||||
externalAuthProviderTypes = append(externalAuthProviderTypes, conf.AuthProviderType(p))
|
||||
}
|
||||
}
|
||||
if len(externalAuthProviderTypes) > 0 {
|
||||
return []*graphqlbackend.Alert{
|
||||
{
|
||||
TypeValue: graphqlbackend.AlertTypeError,
|
||||
MessageValue: fmt.Sprintf("A Sourcegraph license is required for user authentication providers (SSO): %s. [**Get a license.**](/site-admin/license)", strings.Join(externalAuthProviderTypes, ", ")),
|
||||
},
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
94
enterprise/cmd/frontend/auth/openidconnect/config.go
Normal file
94
enterprise/cmd/frontend/auth/openidconnect/config.go
Normal file
@ -0,0 +1,94 @@
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
var mockGetProviderValue *provider
|
||||
|
||||
// getProvider looks up the registered openidconnect auth provider with the given ID.
|
||||
func getProvider(id string) *provider {
|
||||
if mockGetProviderValue != nil {
|
||||
return mockGetProviderValue
|
||||
}
|
||||
p, _ := auth.GetProviderByConfigID(auth.ProviderConfigID{Type: providerType, ID: id}).(*provider)
|
||||
return p
|
||||
}
|
||||
|
||||
func handleGetProvider(ctx context.Context, w http.ResponseWriter, id string) (p *provider, handled bool) {
|
||||
handled = true // safer default
|
||||
|
||||
// License check.
|
||||
if !licensing.IsFeatureEnabledLenient(licensing.FeatureExternalAuthProvider) {
|
||||
licensing.WriteSubscriptionErrorResponseForFeature(w, "OpenID Connect user authentication (SSO)")
|
||||
return nil, true
|
||||
}
|
||||
|
||||
p = getProvider(id)
|
||||
if p == nil {
|
||||
log15.Error("No OpenID Connect auth provider found with ID.", "id", id)
|
||||
http.Error(w, "Misconfigured OpenID Connect auth provider.", http.StatusInternalServerError)
|
||||
return nil, true
|
||||
}
|
||||
if p.config.Issuer == "" {
|
||||
log15.Error("No issuer set for OpenID Connect auth provider (set the openidconnect auth provider's issuer property).", "id", p.ConfigID())
|
||||
http.Error(w, "Misconfigured OpenID Connect auth provider.", http.StatusInternalServerError)
|
||||
return nil, true
|
||||
}
|
||||
if err := p.Refresh(ctx); err != nil {
|
||||
log15.Error("Error refreshing OpenID Connect auth provider.", "id", p.ConfigID(), "error", err)
|
||||
http.Error(w, "Unexpected error refreshing OpenID Connect authentication provider.", http.StatusInternalServerError)
|
||||
return nil, true
|
||||
}
|
||||
return p, false
|
||||
}
|
||||
|
||||
func init() {
|
||||
conf.ContributeValidator(validateConfig)
|
||||
}
|
||||
|
||||
func validateConfig(c schema.SiteConfiguration) (problems []string) {
|
||||
var loggedNeedsAppURL bool
|
||||
for _, p := range c.AuthProviders {
|
||||
if p.Openidconnect != nil && c.AppURL == "" && !loggedNeedsAppURL {
|
||||
problems = append(problems, `openidconnect auth provider requires appURL to be set to the external URL of your site (example: https://sourcegraph.example.com)`)
|
||||
loggedNeedsAppURL = true
|
||||
}
|
||||
}
|
||||
|
||||
seen := map[schema.OpenIDConnectAuthProvider]int{}
|
||||
for i, p := range c.AuthProviders {
|
||||
if p.Openidconnect != nil {
|
||||
if j, ok := seen[*p.Openidconnect]; ok {
|
||||
problems = append(problems, fmt.Sprintf("OpenID Connect auth provider at index %d is duplicate of index %d, ignoring", i, j))
|
||||
} else {
|
||||
seen[*p.Openidconnect] = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return problems
|
||||
}
|
||||
|
||||
// providerConfigID produces a semi-stable identifier for an openidconnect auth provider config
|
||||
// object. It is used to distinguish between multiple auth providers of the same type when in
|
||||
// multi-step auth flows. Its value is never persisted, and it must be deterministic.
|
||||
func providerConfigID(pc *schema.OpenIDConnectAuthProvider) string {
|
||||
data, err := json.Marshal(pc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
b := sha256.Sum256(data)
|
||||
return base64.RawURLEncoding.EncodeToString(b[:16])
|
||||
}
|
||||
40
enterprise/cmd/frontend/auth/openidconnect/config_test.go
Normal file
40
enterprise/cmd/frontend/auth/openidconnect/config_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
func TestValidateCustom(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
input schema.SiteConfiguration
|
||||
wantProblems []string
|
||||
}{
|
||||
"duplicates": {
|
||||
input: schema.SiteConfiguration{
|
||||
AppURL: "x",
|
||||
AuthProviders: []schema.AuthProviders{
|
||||
{Openidconnect: &schema.OpenIDConnectAuthProvider{Type: "openidconnect", Issuer: "x"}},
|
||||
{Openidconnect: &schema.OpenIDConnectAuthProvider{Type: "openidconnect", Issuer: "x"}},
|
||||
},
|
||||
},
|
||||
wantProblems: []string{"OpenID Connect auth provider at index 1 is duplicate of index 0"},
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
conf.TestValidator(t, test.input, validateConfig, test.wantProblems)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigID(t *testing.T) {
|
||||
p := schema.OpenIDConnectAuthProvider{Issuer: "x"}
|
||||
id1 := providerConfigID(&p)
|
||||
id2 := providerConfigID(&p)
|
||||
if id1 != id2 {
|
||||
t.Errorf("id1 (%q) != id2 (%q)", id1, id2)
|
||||
}
|
||||
}
|
||||
84
enterprise/cmd/frontend/auth/openidconnect/config_watch.go
Normal file
84
enterprise/cmd/frontend/auth/openidconnect/config_watch.go
Normal file
@ -0,0 +1,84 @@
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
// Start trying to populate the cache of issuer metadata (given the configured OpenID Connect issuer
|
||||
// URL) immediately upon server startup and site config changes so users don't incur the wait on the
|
||||
// first auth flow request.
|
||||
func init() {
|
||||
providersOfType := func(ps []schema.AuthProviders) []*schema.OpenIDConnectAuthProvider {
|
||||
var pcs []*schema.OpenIDConnectAuthProvider
|
||||
for _, p := range ps {
|
||||
if p.Openidconnect != nil {
|
||||
pcs = append(pcs, p.Openidconnect)
|
||||
}
|
||||
}
|
||||
return pcs
|
||||
}
|
||||
|
||||
var (
|
||||
init = true
|
||||
|
||||
mu sync.Mutex
|
||||
cur []*schema.OpenIDConnectAuthProvider
|
||||
reg = map[schema.OpenIDConnectAuthProvider]auth.Provider{}
|
||||
)
|
||||
conf.Watch(func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Only react when the config changes.
|
||||
new := providersOfType(conf.Get().AuthProviders)
|
||||
diff := diffProviderConfig(cur, new)
|
||||
if len(diff) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if !init {
|
||||
log15.Info("Reloading changed OpenID Connect authentication provider configuration.")
|
||||
}
|
||||
updates := make(map[auth.Provider]bool, len(diff))
|
||||
for pc, op := range diff {
|
||||
if old, ok := reg[pc]; ok {
|
||||
delete(reg, pc)
|
||||
updates[old] = false
|
||||
}
|
||||
if op {
|
||||
new := &provider{config: pc}
|
||||
reg[pc] = new
|
||||
updates[new] = true
|
||||
go func(p *provider) {
|
||||
if err := p.Refresh(context.Background()); err != nil {
|
||||
log15.Error("Error prefetching OpenID Connect service provider metadata.", "error", err)
|
||||
}
|
||||
}(new)
|
||||
}
|
||||
}
|
||||
auth.UpdateProviders(updates)
|
||||
cur = new
|
||||
})
|
||||
init = false
|
||||
}
|
||||
|
||||
func diffProviderConfig(old, new []*schema.OpenIDConnectAuthProvider) map[schema.OpenIDConnectAuthProvider]bool {
|
||||
diff := map[schema.OpenIDConnectAuthProvider]bool{}
|
||||
for _, oldPC := range old {
|
||||
diff[*oldPC] = false
|
||||
}
|
||||
for _, newPC := range new {
|
||||
if _, ok := diff[*newPC]; ok {
|
||||
delete(diff, *newPC)
|
||||
} else {
|
||||
diff[*newPC] = true
|
||||
}
|
||||
}
|
||||
return diff
|
||||
}
|
||||
@ -0,0 +1,46 @@
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
func TestDiffProviderConfig(t *testing.T) {
|
||||
var (
|
||||
pc0 = &schema.OpenIDConnectAuthProvider{Issuer: "0"}
|
||||
pc0c = &schema.OpenIDConnectAuthProvider{Issuer: "0", ClientSecret: "x"}
|
||||
pc1 = &schema.OpenIDConnectAuthProvider{Issuer: "1"}
|
||||
)
|
||||
|
||||
tests := map[string]struct {
|
||||
old, new []*schema.OpenIDConnectAuthProvider
|
||||
want map[schema.OpenIDConnectAuthProvider]bool
|
||||
}{
|
||||
"empty": {want: map[schema.OpenIDConnectAuthProvider]bool{}},
|
||||
"added": {
|
||||
old: nil,
|
||||
new: []*schema.OpenIDConnectAuthProvider{pc0, pc1},
|
||||
want: map[schema.OpenIDConnectAuthProvider]bool{*pc0: true, *pc1: true},
|
||||
},
|
||||
"changed": {
|
||||
old: []*schema.OpenIDConnectAuthProvider{pc0, pc1},
|
||||
new: []*schema.OpenIDConnectAuthProvider{pc0c, pc1},
|
||||
want: map[schema.OpenIDConnectAuthProvider]bool{*pc0: false, *pc0c: true},
|
||||
},
|
||||
"removed": {
|
||||
old: []*schema.OpenIDConnectAuthProvider{pc0, pc1},
|
||||
new: []*schema.OpenIDConnectAuthProvider{pc1},
|
||||
want: map[schema.OpenIDConnectAuthProvider]bool{*pc0: false},
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
diff := diffProviderConfig(test.old, test.new)
|
||||
if !reflect.DeepEqual(diff, test.want) {
|
||||
t.Errorf("got != want\n got %+v\nwant %+v", diff, test.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
317
enterprise/cmd/frontend/auth/openidconnect/middleware.go
Normal file
317
enterprise/cmd/frontend/auth/openidconnect/middleware.go
Normal file
@ -0,0 +1,317 @@
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/csrf"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/session"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/actor"
|
||||
"golang.org/x/oauth2"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
)
|
||||
|
||||
const stateCookieName = "sg-oidc-state"
|
||||
|
||||
// All OpenID Connect endpoints are under this path prefix.
|
||||
const authPrefix = auth.AuthURLPrefix + "/openidconnect"
|
||||
|
||||
type userClaims struct {
|
||||
Name string `json:"name"`
|
||||
GivenName string `json:"given_name"`
|
||||
FamilyName string `json:"family_name"`
|
||||
PreferredUsername string `json:"preferred_username"`
|
||||
Picture string `json:"picture"`
|
||||
EmailVerified *bool `json:"email_verified"`
|
||||
}
|
||||
|
||||
// Middleware is middleware for OpenID Connect (OIDC) authentication, adding endpoints under the
|
||||
// auth path prefix ("/.auth") to enable the login flow and requiring login for all other endpoints.
|
||||
//
|
||||
// The OIDC spec (http://openid.net/specs/openid-connect-core-1_0.html) describes an authentication protocol
|
||||
// that involves 3 parties: the Relying Party (e.g., Sourcegraph), the OpenID Provider (e.g., Okta, OneLogin,
|
||||
// or another SSO provider), and the End User (e.g., a user's web browser).
|
||||
//
|
||||
// This middleware implements two things: (1) the OIDC Authorization Code Flow
|
||||
// (http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth) and (2) Sourcegraph-specific session management
|
||||
// (outside the scope of the OIDC spec). Upon successful completion of the OIDC login flow, the handler will create
|
||||
// a new session and session cookie. The expiration of the session is the expiration of the OIDC ID Token.
|
||||
//
|
||||
// 🚨 SECURITY
|
||||
var Middleware = &auth.Middleware{
|
||||
API: func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleOpenIDConnectAuth(w, r, next, true)
|
||||
})
|
||||
},
|
||||
App: func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
handleOpenIDConnectAuth(w, r, next, false)
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// handleOpenIDConnectAuth performs OpenID Connect authentication (if configured) for HTTP requests,
|
||||
// both API requests and non-API requests.
|
||||
func handleOpenIDConnectAuth(w http.ResponseWriter, r *http.Request, next http.Handler, isAPIRequest bool) {
|
||||
// Fixup URL path. We use "/.auth/callback" as the redirect URI for OpenID Connect, but the rest
|
||||
// of this middleware's handlers expect paths of "/.auth/openidconnect/...", so add the
|
||||
// "openidconnect" path component. We can't change the redirect URI because it is hardcoded in
|
||||
// instances' external auth providers.
|
||||
if r.URL.Path == auth.AuthURLPrefix+"/callback" {
|
||||
// Rewrite "/.auth/callback" -> "/.auth/openidconnect/callback".
|
||||
r.URL.Path = authPrefix + "/callback"
|
||||
}
|
||||
|
||||
// Delegate to the OpenID Connect auth handler.
|
||||
if !isAPIRequest && strings.HasPrefix(r.URL.Path, authPrefix+"/") {
|
||||
authHandler(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// If the actor is authenticated and not performing an OpenID Connect flow, then proceed to
|
||||
// next.
|
||||
if actor.FromContext(r.Context()).IsAuthenticated() {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// If there is only one auth provider configured, the single auth provider is OpenID Connect,
|
||||
// and it's an app request, redirect to signin immediately. The user wouldn't be able to do
|
||||
// anything else anyway; there's no point in showing them a signin screen with just a single
|
||||
// signin option.
|
||||
if ps := auth.Providers(); len(ps) == 1 && ps[0].Config().Openidconnect != nil && !isAPIRequest {
|
||||
p, handled := handleGetProvider(r.Context(), w, ps[0].ConfigID().ID)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
redirectToAuthRequest(w, r, p, auth.SafeRedirectURL(r.URL.String()))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
// mockVerifyIDToken mocks the OIDC ID Token verification step. It should only be set in tests.
|
||||
var mockVerifyIDToken func(rawIDToken string) *oidc.IDToken
|
||||
|
||||
// authHandler handles the OIDC Authentication Code Flow
|
||||
// (http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth) on the Relying Party's end.
|
||||
//
|
||||
// 🚨 SECURITY
|
||||
func authHandler(w http.ResponseWriter, r *http.Request) {
|
||||
switch strings.TrimPrefix(r.URL.Path, authPrefix) {
|
||||
case "/login":
|
||||
// Endpoint that starts the Authentication Request Code Flow.
|
||||
p, handled := handleGetProvider(r.Context(), w, r.URL.Query().Get("pc"))
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
redirectToAuthRequest(w, r, p, r.URL.Query().Get("redirect"))
|
||||
return
|
||||
|
||||
case "/callback":
|
||||
// Endpoint for the OIDC Authorization Response. See http://openid.net/specs/openid-connect-core-1_0.html#AuthResponse.
|
||||
ctx := r.Context()
|
||||
if authError := r.URL.Query().Get("error"); authError != "" {
|
||||
errorDesc := r.URL.Query().Get("error_description")
|
||||
log15.Error("OpenID Connect auth provider returned error to callback.", "error", authError, "description", errorDesc)
|
||||
http.Error(w, fmt.Sprintf("Authentication failed. Try signing in again (and clearing cookies for the current site). The authentication provider reported the following problems.\n\n%s\n\n%s", authError, errorDesc), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Validate state parameter to prevent CSRF attacks
|
||||
stateParam := r.URL.Query().Get("state")
|
||||
if stateParam == "" {
|
||||
http.Error(w, "Authentication failed. Try signing in again (and clearing cookies for the current site). No OpenID Connect state query parameter specified.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
stateCookie, err := r.Cookie(stateCookieName)
|
||||
if err == http.ErrNoCookie {
|
||||
log15.Error("OpenID Connect auth failed: no state cookie found (possible request forgery).")
|
||||
http.Error(w, fmt.Sprintf("Authentication failed. Try signing in again (and clearing cookies for the current site). The error was: no OpenID Connect state cookie found (possible request forgery, or more than %s elapsed since you started the authentication process).", stateCookieTimeout), http.StatusBadRequest)
|
||||
return
|
||||
} else if err != nil {
|
||||
log15.Error("OpenID Connect auth failed: could not read state cookie (possible request forgery).", "error", err)
|
||||
http.Error(w, "Authentication failed. Try signing in again (and clearing cookies for the current site). The error was: invalid OpenID Connect state cookie.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if stateCookie.Value != stateParam {
|
||||
log15.Error("OpenID Connect auth failed: state cookie mismatch (possible request forgery).")
|
||||
http.Error(w, "Authentication failed. Try signing in again (and clearing cookies for the current site). The error was: OpenID Connect state parameter did not match the expected value (possible request forgery).", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// Decode state param value
|
||||
var state authnState
|
||||
if err := state.Decode(stateParam); err != nil {
|
||||
log15.Error("OpenID Connect auth failed: state parameter was malformed.", "error", err)
|
||||
http.Error(w, "Authentication failed. OpenID Connect state parameter was malformed.", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
// 🚨 SECURITY: TODO(sqs): Do we need to check state.CSRFToken?
|
||||
|
||||
p, handled := handleGetProvider(r.Context(), w, state.ProviderID)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
verifier := p.oidc.Verifier(&oidc.Config{ClientID: p.config.ClientID})
|
||||
|
||||
// Exchange the code for an access token. See http://openid.net/specs/openid-connect-core-1_0.html#TokenRequest.
|
||||
oauth2Token, err := p.oauth2Config().Exchange(ctx, r.URL.Query().Get("code"))
|
||||
if err != nil {
|
||||
log15.Error("OpenID Connect auth failed: failed to obtain access token from OP.", "error", err)
|
||||
http.Error(w, "Authentication failed. Try signing in again. The error was: unable to obtain access token from issuer.", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract the ID Token from the Access Token. See http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse.
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
log15.Error("OpenID Connect auth failed: the issuer's authorization response did not contain an ID token.")
|
||||
http.Error(w, "Authentication failed. Try signing in again. The error was: the issuer's authorization response did not contain an ID token.", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
// Parse and verify ID Token payload. See http://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation.
|
||||
var idToken *oidc.IDToken
|
||||
if mockVerifyIDToken != nil {
|
||||
idToken = mockVerifyIDToken(rawIDToken)
|
||||
} else {
|
||||
idToken, err = verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
log15.Error("OpenID Connect auth failed: the ID token verification failed.", "error", err)
|
||||
http.Error(w, "Authentication failed. Try signing in again. The error was: OpenID Connect ID token could not be verified.", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// Validate the nonce. The Verify method explicitly doesn't handle nonce validation, so we do that here.
|
||||
// We set the nonce to be the same as the state in the Authentication Request state, so we check for equality
|
||||
// here.
|
||||
if idToken.Nonce != stateParam {
|
||||
log15.Error("OpenID Connect auth failed: nonce is incorrect (possible replay attach).")
|
||||
http.Error(w, "Authentication failed. Try signing in again (and clearing cookies for the current site). The error was: OpenID Connect nonce is incorrect (possible replay attack).", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userInfo, err := p.oidc.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
|
||||
if err != nil {
|
||||
log15.Error("Failed to get userinfo", "error", err)
|
||||
http.Error(w, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if p.config.RequireEmailDomain != "" && !strings.HasSuffix(userInfo.Email, "@"+p.config.RequireEmailDomain) {
|
||||
log15.Error("OpenID Connect auth failed: user's email is not from allowed domain.", "userEmail", userInfo.Email, "requireEmailDomain", p.config.RequireEmailDomain)
|
||||
http.Error(w, fmt.Sprintf("Authentication failed. Only users in %q are allowed.", p.config.RequireEmailDomain), http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
var claims userClaims
|
||||
if err := userInfo.Claims(&claims); err != nil {
|
||||
log15.Warn("OpenID Connect auth: could not parse userInfo claims.", "error", err)
|
||||
}
|
||||
actr, safeErrMsg, err := getOrCreateUser(ctx, p, idToken, userInfo, &claims)
|
||||
if err != nil {
|
||||
log15.Error("OpenID Connect auth failed: error looking up OpenID-authenticated user.", "error", err, "userErr", safeErrMsg)
|
||||
http.Error(w, safeErrMsg, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
var exp time.Duration
|
||||
// 🚨 SECURITY: TODO(sqs): We *should* uncomment the lines below to make our own sessions
|
||||
// only last for as long as the OP said the access token is active for. Unfortunately,
|
||||
// until we support refreshing access tokens in the background
|
||||
// (https://github.com/sourcegraph/sourcegraph/issues/11340), this provides a bad user
|
||||
// experience because users need to re-authenticate via OIDC every minute or so
|
||||
// (assuming their OIDC OP, like many, has a 1-minute access token validity period).
|
||||
//
|
||||
// if !idToken.Expiry.IsZero() {
|
||||
// exp = time.Until(idToken.Expiry)
|
||||
// }
|
||||
if err := session.SetActor(w, r, actr, exp); err != nil {
|
||||
log15.Error("OpenID Connect auth failed: could not initiate session.", "error", err)
|
||||
http.Error(w, "Authentication failed. Try signing in again (and clearing cookies for the current site). The error was: could not initiate session.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
data := sessionData{
|
||||
ID: p.ConfigID(),
|
||||
AccessToken: oauth2Token.AccessToken,
|
||||
TokenType: oauth2Token.TokenType,
|
||||
}
|
||||
if err := session.SetData(w, r, sessionKey, data); err != nil {
|
||||
// It's not fatal if this fails. It just means we won't be able to sign the user out of
|
||||
// the OP.
|
||||
log15.Warn("Failed to set OpenID Connect session data. The session is still secure, but Sourcegraph will be unable to revoke the user's token or redirect the user to the end-session endpoint after the user signs out of Sourcegraph.", "error", err)
|
||||
}
|
||||
|
||||
// 🚨 SECURITY: Call auth.SafeRedirectURL to avoid an open-redirect vuln.
|
||||
http.Redirect(w, r, auth.SafeRedirectURL(state.Redirect), http.StatusFound)
|
||||
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
// authnState is the state parameter passed to the Authn request and returned in the Authn response callback.
|
||||
type authnState struct {
|
||||
CSRFToken string `json:"csrfToken"`
|
||||
Redirect string `json:"redirect"`
|
||||
|
||||
// Allow /.auth/callback to demux callbacks from multiple OpenID Connect OPs.
|
||||
ProviderID string `json:"p"`
|
||||
}
|
||||
|
||||
// Encode returns the base64-encoded JSON representation of the authn state.
|
||||
func (s *authnState) Encode() string {
|
||||
b, _ := json.Marshal(s)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// Decode decodes the base64-encoded JSON representation of the authn state into the receiver.
|
||||
func (s *authnState) Decode(encoded string) error {
|
||||
b, err := base64.StdEncoding.DecodeString(encoded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return json.Unmarshal(b, s)
|
||||
}
|
||||
|
||||
const stateCookieTimeout = time.Minute * 15
|
||||
|
||||
func redirectToAuthRequest(w http.ResponseWriter, r *http.Request, p *provider, returnToURL string) {
|
||||
// The state parameter is an opaque value used to maintain state between the original Authentication Request
|
||||
// and the callback. We do not record any state beyond a CSRF token used to defend against CSRF attacks against the callback.
|
||||
// We use the CSRF token created by gorilla/csrf that is used for other app endpoints as the OIDC state parameter.
|
||||
//
|
||||
// See http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest of the OIDC spec.
|
||||
state := (&authnState{
|
||||
CSRFToken: csrf.Token(r),
|
||||
Redirect: returnToURL,
|
||||
ProviderID: p.ConfigID().ID,
|
||||
}).Encode()
|
||||
http.SetCookie(w, &http.Cookie{
|
||||
Name: stateCookieName,
|
||||
Value: state,
|
||||
Path: auth.AuthURLPrefix + "/", // include the OIDC redirect URI (/.auth/callback not /.auth/openidconnect/callback for BACKCOMPAT)
|
||||
Expires: time.Now().Add(stateCookieTimeout),
|
||||
})
|
||||
|
||||
// Redirect to the OP's Authorization Endpoint for authentication. The nonce is an optional
|
||||
// string value used to associate a Client session with an ID Token and to mitigate replay attacks.
|
||||
// Whereas the state parameter is used in validating the Authentication Request
|
||||
// callback, the nonce is used in validating the response to the ID Token request.
|
||||
// We re-use the Authn request state as the nonce.
|
||||
//
|
||||
// See http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest of the OIDC spec.
|
||||
http.Redirect(w, r, p.oauth2Config().AuthCodeURL(state, oidc.Nonce(state)), http.StatusFound)
|
||||
}
|
||||
365
enterprise/cmd/frontend/auth/openidconnect/middleware_test.go
Normal file
365
enterprise/cmd/frontend/auth/openidconnect/middleware_test.go
Normal file
@ -0,0 +1,365 @@
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"os"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
"github.com/sourcegraph/enterprise/pkg/license"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/actor"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/session"
|
||||
)
|
||||
|
||||
// providerJSON is the JSON structure the OIDC provider returns at its discovery endpoing
|
||||
type providerJSON struct {
|
||||
Issuer string `json:"issuer"`
|
||||
AuthURL string `json:"authorization_endpoint"`
|
||||
TokenURL string `json:"token_endpoint"`
|
||||
JWKSURL string `json:"jwks_uri"`
|
||||
UserInfoURL string `json:"userinfo_endpoint"`
|
||||
}
|
||||
|
||||
var (
|
||||
testOIDCUser = "bob-test-user"
|
||||
testClientID = "aaaaaaaaaaaaaa"
|
||||
)
|
||||
|
||||
// new OIDCIDServer returns a new running mock OIDC ID Provider service. It is the caller's
|
||||
// responsibility to call Close().
|
||||
func newOIDCIDServer(t *testing.T, code string, oidcProvider *schema.OpenIDConnectAuthProvider) (server *httptest.Server, emailPtr *string) {
|
||||
idBearerToken := "test_id_token_f4bdefbd77f"
|
||||
s := http.NewServeMux()
|
||||
|
||||
s.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
json.NewEncoder(w).Encode(providerJSON{
|
||||
Issuer: oidcProvider.Issuer,
|
||||
AuthURL: oidcProvider.Issuer + "/oauth2/v1/authorize",
|
||||
TokenURL: oidcProvider.Issuer + "/oauth2/v1/token",
|
||||
UserInfoURL: oidcProvider.Issuer + "/oauth2/v1/userinfo",
|
||||
})
|
||||
})
|
||||
s.HandleFunc("/oauth2/v1/token", func(w http.ResponseWriter, r *http.Request) {
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "unexpected", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
b, _ := ioutil.ReadAll(r.Body)
|
||||
values, _ := url.ParseQuery(string(b))
|
||||
|
||||
if values.Get("code") != code {
|
||||
t.Errorf("got code %q, want %q", values.Get("code"), code)
|
||||
}
|
||||
if got, want := values.Get("grant_type"), "authorization_code"; got != want {
|
||||
t.Errorf("got grant_type %v, want %v", got, want)
|
||||
}
|
||||
redirectURI, _ := url.QueryUnescape(values.Get("redirect_uri"))
|
||||
if want := "http://example.com/.auth/callback"; redirectURI != want {
|
||||
t.Errorf("got redirect_uri %v, want %v", redirectURI, want)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(fmt.Sprintf(`{
|
||||
"access_token": "aaaaa",
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 3600,
|
||||
"scope": "openid",
|
||||
"id_token": %q
|
||||
}`, idBearerToken)))
|
||||
})
|
||||
email := "bob@example.com"
|
||||
s.HandleFunc("/oauth2/v1/userinfo", func(w http.ResponseWriter, r *http.Request) {
|
||||
authzHeader := r.Header.Get("Authorization")
|
||||
authzParts := strings.Split(authzHeader, " ")
|
||||
if len(authzParts) != 2 {
|
||||
t.Fatalf("Expected 2 parts to authz header, instead got %d: %q", len(authzParts), authzHeader)
|
||||
}
|
||||
if authzParts[0] != "Bearer" {
|
||||
t.Fatalf("No bearer token found in authz header %q", authzHeader)
|
||||
}
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.Write([]byte(fmt.Sprintf(`{
|
||||
"sub": %q,
|
||||
"profile": "This is a profile",
|
||||
"email": "`+email+`",
|
||||
"email_verified": true,
|
||||
"picture": "https://example.com/picture.png"
|
||||
}`, testOIDCUser)))
|
||||
})
|
||||
|
||||
srv := httptest.NewServer(s)
|
||||
|
||||
auth.SetMockCreateOrUpdateUser(func(u db.NewUser, a db.ExternalAccountSpec) (userID int32, err error) {
|
||||
if a.ServiceType == "openidconnect" && a.ServiceID == oidcProvider.Issuer && a.ClientID == testClientID && a.AccountID == testOIDCUser {
|
||||
return 123, nil
|
||||
}
|
||||
return 0, fmt.Errorf("account %v not found in mock", a)
|
||||
})
|
||||
|
||||
return srv, &email
|
||||
}
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
cleanup := session.ResetMockSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
|
||||
return &license.Info{Tags: licensing.EnterpriseTags}, nil
|
||||
}
|
||||
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
|
||||
|
||||
tempdir, err := ioutil.TempDir("", "sourcegraph-oidc-test")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tempdir)
|
||||
|
||||
mockGetProviderValue = &provider{
|
||||
config: schema.OpenIDConnectAuthProvider{
|
||||
ClientID: testClientID,
|
||||
ClientSecret: "aaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
RequireEmailDomain: "example.com",
|
||||
},
|
||||
}
|
||||
defer func() { mockGetProviderValue = nil }()
|
||||
auth.SetMockProviders([]auth.Provider{mockGetProviderValue})
|
||||
defer func() { auth.SetMockProviders(nil) }()
|
||||
|
||||
oidcIDServer, emailPtr := newOIDCIDServer(t, "THECODE", &mockGetProviderValue.config)
|
||||
defer oidcIDServer.Close()
|
||||
defer func() { auth.SetMockCreateOrUpdateUser(nil) }()
|
||||
mockGetProviderValue.config.Issuer = oidcIDServer.URL
|
||||
|
||||
if err := mockGetProviderValue.Refresh(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
validState := (&authnState{CSRFToken: "THE_CSRF_TOKEN", Redirect: "/redirect", ProviderID: mockGetProviderValue.ConfigID().ID}).Encode()
|
||||
mockVerifyIDToken = func(rawIDToken string) *oidc.IDToken {
|
||||
if rawIDToken != "test_id_token_f4bdefbd77f" {
|
||||
t.Fatalf("unexpected raw ID token: %s", rawIDToken)
|
||||
}
|
||||
return &oidc.IDToken{
|
||||
Issuer: oidcIDServer.URL,
|
||||
Subject: testOIDCUser,
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
Nonce: validState, // we re-use the state param as the nonce
|
||||
}
|
||||
}
|
||||
|
||||
const mockUserID = 123
|
||||
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
authedHandler := http.NewServeMux()
|
||||
authedHandler.Handle("/.api/", Middleware.API(h))
|
||||
authedHandler.Handle("/", Middleware.App(h))
|
||||
|
||||
doRequest := func(method, urlStr, body string, cookies []*http.Cookie, authed bool) *http.Response {
|
||||
req := httptest.NewRequest(method, urlStr, bytes.NewBufferString(body))
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
if authed {
|
||||
req = req.WithContext(actor.WithActor(context.Background(), &actor.Actor{UID: mockUserID}))
|
||||
}
|
||||
respRecorder := httptest.NewRecorder()
|
||||
authedHandler.ServeHTTP(respRecorder, req)
|
||||
return respRecorder.Result()
|
||||
}
|
||||
|
||||
state := func(t *testing.T, urlStr string) (state authnState) {
|
||||
u, _ := url.Parse(urlStr)
|
||||
if err := state.Decode(u.Query().Get("nonce")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return state
|
||||
}
|
||||
t.Run("unauthenticated homepage visit -> oidc auth flow", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/", "", nil, false)
|
||||
if want := http.StatusFound; resp.StatusCode != want {
|
||||
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
if got, want := resp.Header.Get("Location"), "/oauth2/v1/authorize?"; !strings.Contains(got, want) {
|
||||
t.Errorf("got redirect URL %v, want contains %v", got, want)
|
||||
}
|
||||
if state, want := state(t, resp.Header.Get("Location")), "/"; state.Redirect != want {
|
||||
t.Errorf("got redirect destination %q, want %q", state.Redirect, want)
|
||||
}
|
||||
})
|
||||
t.Run("unauthenticated subpage visit -> oidc auth flow", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/page", "", nil, false)
|
||||
if want := http.StatusFound; resp.StatusCode != want {
|
||||
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
if got, want := resp.Header.Get("Location"), "/oauth2/v1/authorize?"; !strings.Contains(got, want) {
|
||||
t.Errorf("got redirect URL %v, want contains %v", got, want)
|
||||
}
|
||||
if state, want := state(t, resp.Header.Get("Location")), "/page"; state.Redirect != want {
|
||||
t.Errorf("got redirect destination %q, want %q", state.Redirect, want)
|
||||
}
|
||||
})
|
||||
t.Run("unauthenticated non-existent page visit -> oidc auth flow", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/nonexistent", "", nil, false)
|
||||
if want := http.StatusFound; resp.StatusCode != want {
|
||||
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
if got, want := resp.Header.Get("Location"), "/oauth2/v1/authorize?"; !strings.Contains(got, want) {
|
||||
t.Errorf("got redirect URL %v, want contains %v", got, want)
|
||||
}
|
||||
if state, want := state(t, resp.Header.Get("Location")), "/nonexistent"; state.Redirect != want {
|
||||
t.Errorf("got redirect destination %q, want %q", state.Redirect, want)
|
||||
}
|
||||
})
|
||||
t.Run("unauthenticated API request -> pass through", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/.api/foo", "", nil, false)
|
||||
if want := http.StatusOK; resp.StatusCode != want {
|
||||
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
})
|
||||
t.Run("login -> oidc auth flow", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/.auth/openidconnect/login?p="+mockGetProviderValue.ConfigID().ID, "", nil, false)
|
||||
if want := http.StatusFound; resp.StatusCode != want {
|
||||
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
locHeader := resp.Header.Get("Location")
|
||||
if !strings.HasPrefix(locHeader, mockGetProviderValue.config.Issuer+"/") {
|
||||
t.Error("did not redirect to OIDC Provider")
|
||||
}
|
||||
idpLoginURL, err := url.Parse(locHeader)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, want := idpLoginURL.Query().Get("client_id"), mockGetProviderValue.config.ClientID; got != want {
|
||||
t.Errorf("got client id %q, want %q", got, want)
|
||||
}
|
||||
if got, want := idpLoginURL.Query().Get("redirect_uri"), "http://example.com/.auth/callback"; got != want {
|
||||
t.Errorf("got redirect_uri %v, want %v", got, want)
|
||||
}
|
||||
if got, want := idpLoginURL.Query().Get("response_type"), "code"; got != want {
|
||||
t.Errorf("got response_type %v, want %v", got, want)
|
||||
}
|
||||
if got, want := idpLoginURL.Query().Get("scope"), "openid profile email"; got != want {
|
||||
t.Errorf("got scope %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("OIDC callback without CSRF token -> error", func(t *testing.T) {
|
||||
invalidState := (&authnState{CSRFToken: "bad", ProviderID: mockGetProviderValue.ConfigID().ID}).Encode()
|
||||
resp := doRequest("GET", "http://example.com/.auth/callback?code=THECODE&state="+url.PathEscape(invalidState), "", nil, false)
|
||||
if want := http.StatusBadRequest; resp.StatusCode != want {
|
||||
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
})
|
||||
t.Run("OIDC callback with CSRF token -> set auth cookies", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/.auth/callback?code=THECODE&state="+url.PathEscape(validState), "", []*http.Cookie{{Name: stateCookieName, Value: validState}}, false)
|
||||
if want := http.StatusFound; resp.StatusCode != want {
|
||||
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
if got, want := resp.Header.Get("Location"), "/redirect"; got != want {
|
||||
t.Errorf("got redirect URL %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
*emailPtr = "bob@invalid.com" // doesn't match requiredEmailDomain
|
||||
t.Run("OIDC callback with bad email domain -> error", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/.auth/callback?code=THECODE&state="+url.PathEscape(validState), "", []*http.Cookie{{Name: stateCookieName, Value: validState}}, false)
|
||||
if want := http.StatusUnauthorized; resp.StatusCode != want {
|
||||
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
})
|
||||
t.Run("authenticated app request", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/", "", nil, true)
|
||||
if want := http.StatusOK; resp.StatusCode != want {
|
||||
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
})
|
||||
t.Run("authenticated API request", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/.api/foo", "", nil, true)
|
||||
if want := http.StatusOK; resp.StatusCode != want {
|
||||
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestMiddleware_NoOpenRedirect(t *testing.T) {
|
||||
cleanup := session.ResetMockSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
|
||||
return &license.Info{Tags: licensing.EnterpriseTags}, nil
|
||||
}
|
||||
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
|
||||
|
||||
tempdir, err := ioutil.TempDir("", "sourcegraph-oidc-test-no-open-redirect")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer os.RemoveAll(tempdir)
|
||||
|
||||
mockGetProviderValue = &provider{
|
||||
config: schema.OpenIDConnectAuthProvider{
|
||||
ClientID: testClientID,
|
||||
ClientSecret: "aaaaaaaaaaaaaaaaaaaaaaaaa",
|
||||
},
|
||||
}
|
||||
defer func() { mockGetProviderValue = nil }()
|
||||
auth.SetMockProviders([]auth.Provider{mockGetProviderValue})
|
||||
defer func() { auth.SetMockProviders(nil) }()
|
||||
|
||||
oidcIDServer, _ := newOIDCIDServer(t, "THECODE", &mockGetProviderValue.config)
|
||||
defer oidcIDServer.Close()
|
||||
defer func() { auth.SetMockCreateOrUpdateUser(nil) }()
|
||||
mockGetProviderValue.config.Issuer = oidcIDServer.URL
|
||||
|
||||
if err := mockGetProviderValue.Refresh(context.Background()); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
state := (&authnState{CSRFToken: "THE_CSRF_TOKEN", Redirect: "http://evil.com", ProviderID: mockGetProviderValue.ConfigID().ID}).Encode()
|
||||
mockVerifyIDToken = func(rawIDToken string) *oidc.IDToken {
|
||||
if rawIDToken != "test_id_token_f4bdefbd77f" {
|
||||
t.Fatalf("unexpected raw ID token: %s", rawIDToken)
|
||||
}
|
||||
return &oidc.IDToken{
|
||||
Issuer: oidcIDServer.URL,
|
||||
Subject: testOIDCUser,
|
||||
Expiry: time.Now().Add(time.Hour),
|
||||
Nonce: state, // we re-use the state param as the nonce
|
||||
}
|
||||
}
|
||||
|
||||
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
|
||||
authedHandler := Middleware.App(h)
|
||||
|
||||
doRequest := func(method, urlStr, body string, cookies []*http.Cookie) *http.Response {
|
||||
req := httptest.NewRequest(method, urlStr, bytes.NewBufferString(body))
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
respRecorder := httptest.NewRecorder()
|
||||
authedHandler.ServeHTTP(respRecorder, req)
|
||||
return respRecorder.Result()
|
||||
}
|
||||
|
||||
t.Run("OIDC callback with CSRF token -> set auth cookies", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/.auth/callback?code=THECODE&state="+url.PathEscape(state), "", []*http.Cookie{{Name: stateCookieName, Value: state}})
|
||||
if want := http.StatusFound; resp.StatusCode != want {
|
||||
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
if got, want := resp.Header.Get("Location"), "/"; got != want {
|
||||
t.Errorf("got redirect URL %v, want %v", got, want)
|
||||
} // Redirect to "/", NOT "http://evil.com"
|
||||
})
|
||||
}
|
||||
158
enterprise/cmd/frontend/auth/openidconnect/provider.go
Normal file
158
enterprise/cmd/frontend/auth/openidconnect/provider.go
Normal file
@ -0,0 +1,158 @@
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/globals"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
"golang.org/x/net/context/ctxhttp"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const providerType = "openidconnect"
|
||||
|
||||
type provider struct {
|
||||
config schema.OpenIDConnectAuthProvider
|
||||
|
||||
mu sync.Mutex
|
||||
oidc *oidcProvider
|
||||
refreshErr error
|
||||
}
|
||||
|
||||
// ConfigID implements auth.Provider.
|
||||
func (p *provider) ConfigID() auth.ProviderConfigID {
|
||||
return auth.ProviderConfigID{
|
||||
Type: providerType,
|
||||
ID: providerConfigID(&p.config),
|
||||
}
|
||||
}
|
||||
|
||||
// Config implements auth.Provider.
|
||||
func (p *provider) Config() schema.AuthProviders {
|
||||
return schema.AuthProviders{Openidconnect: &p.config}
|
||||
}
|
||||
|
||||
// Refresh implements auth.Provider.
|
||||
func (p *provider) Refresh(ctx context.Context) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.oidc, p.refreshErr = newProvider(ctx, p.config.Issuer)
|
||||
return p.refreshErr
|
||||
}
|
||||
|
||||
func (p *provider) getCachedInfoAndError() (*auth.ProviderInfo, error) {
|
||||
info := auth.ProviderInfo{
|
||||
ServiceID: p.config.Issuer,
|
||||
ClientID: p.config.ClientID,
|
||||
DisplayName: p.config.DisplayName,
|
||||
AuthenticationURL: (&url.URL{
|
||||
Path: path.Join(authPrefix, "login"),
|
||||
RawQuery: (url.Values{"pc": []string{providerConfigID(&p.config)}}).Encode(),
|
||||
}).String(),
|
||||
}
|
||||
if info.DisplayName == "" {
|
||||
info.DisplayName = "OpenID Connect"
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
err := p.refreshErr
|
||||
if err != nil {
|
||||
err = errors.WithMessage(err, "failed to initialize OpenID Connect auth provider")
|
||||
} else if p.oidc == nil {
|
||||
err = errors.New("OpenID Connect auth provider is not yet initialized")
|
||||
}
|
||||
return &info, err
|
||||
}
|
||||
|
||||
// CachedInfo implements auth.Provider.
|
||||
func (p *provider) CachedInfo() *auth.ProviderInfo {
|
||||
info, _ := p.getCachedInfoAndError()
|
||||
return info
|
||||
}
|
||||
|
||||
func (p *provider) oauth2Config() *oauth2.Config {
|
||||
return &oauth2.Config{
|
||||
ClientID: p.config.ClientID,
|
||||
ClientSecret: p.config.ClientSecret,
|
||||
|
||||
// It would be nice if this was "/.auth/openidconnect/callback" not "/.auth/callback", but
|
||||
// many instances have the "/.auth/callback" value hardcoded in their external auth
|
||||
// provider, so we can't change it easily
|
||||
RedirectURL: globals.AppURL().ResolveReference(&url.URL{Path: path.Join(auth.AuthURLPrefix, "callback")}).String(),
|
||||
|
||||
Endpoint: p.oidc.Endpoint(),
|
||||
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
|
||||
}
|
||||
}
|
||||
|
||||
// oidcProvider is an OpenID Connect oidcProvider with additional claims parsed from the service oidcProvider
|
||||
// discovery response (beyond what github.com/coreos/go-oidc parses by default).
|
||||
type oidcProvider struct {
|
||||
oidc.Provider
|
||||
providerExtraClaims
|
||||
}
|
||||
|
||||
type providerExtraClaims struct {
|
||||
// EndSessionEndpoint is the URL of the OP's endpoint that logs the user out of the OP (provided
|
||||
// in the "end_session_endpoint" field of the OP's service discovery response). See
|
||||
// https://openid.net/specs/openid-connect-session-1_0.html#OPMetadata.
|
||||
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
|
||||
|
||||
// RevocationEndpoint is the URL of the OP's revocation endpoint (provided in the
|
||||
// "revocation_endpoint" field of the OP's service discovery response). See
|
||||
// https://openid.net/specs/openid-heart-openid-connect-1_0.html#rfc.section.3.5 and
|
||||
// https://tools.ietf.org/html/rfc7009.
|
||||
RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
|
||||
}
|
||||
|
||||
var mockNewProvider func(issuerURL string) (*oidcProvider, error)
|
||||
|
||||
func newProvider(ctx context.Context, issuerURL string) (*oidcProvider, error) {
|
||||
if mockNewProvider != nil {
|
||||
return mockNewProvider(issuerURL)
|
||||
}
|
||||
|
||||
bp, err := oidc.NewProvider(context.Background(), issuerURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &oidcProvider{Provider: *bp}
|
||||
if err := bp.Claims(&p.providerExtraClaims); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// revokeToken implements Token Revocation. See https://tools.ietf.org/html/rfc7009.
|
||||
func revokeToken(ctx context.Context, p *provider, accessToken, tokenType string) error {
|
||||
postData := url.Values{}
|
||||
postData.Set("token", accessToken)
|
||||
if tokenType != "" {
|
||||
postData.Set("token_type_hint", tokenType)
|
||||
}
|
||||
req, err := http.NewRequest(p.oidc.RevocationEndpoint, "application/x-www-form-urlencoded", strings.NewReader(postData.Encode()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
req.SetBasicAuth(p.config.ClientID, p.config.ClientSecret)
|
||||
resp, err := ctxhttp.Do(ctx, nil, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return fmt.Errorf("non-200 HTTP response from token revocation endpoint %s: HTTP %d", p.oidc.RevocationEndpoint, resp.StatusCode)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
56
enterprise/cmd/frontend/auth/openidconnect/session.go
Normal file
56
enterprise/cmd/frontend/auth/openidconnect/session.go
Normal file
@ -0,0 +1,56 @@
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/session"
|
||||
)
|
||||
|
||||
const sessionKey = "oidc@0"
|
||||
|
||||
type sessionData struct {
|
||||
ID auth.ProviderConfigID
|
||||
|
||||
// Store only the oauth2.Token fields we need, to avoid hitting the ~4096-byte session data
|
||||
// limit.
|
||||
AccessToken string
|
||||
TokenType string
|
||||
}
|
||||
|
||||
// SignOut clears OpenID Connect-related data from the session. If possible, it revokes the token
|
||||
// from the OP. If there is an end-session endpoint, it returns that for the caller to present to
|
||||
// the user.
|
||||
func SignOut(w http.ResponseWriter, r *http.Request) (endSessionEndpoint string, err error) {
|
||||
defer func() {
|
||||
if err != nil {
|
||||
_ = session.SetData(w, r, sessionKey, nil) // clear the bad data
|
||||
}
|
||||
}()
|
||||
|
||||
var data *sessionData
|
||||
if err := session.GetData(r, sessionKey, &data); err != nil {
|
||||
return "", errors.WithMessage(err, "reading OpenID Connect session data")
|
||||
}
|
||||
if err := session.SetData(w, r, sessionKey, nil); err != nil {
|
||||
return "", errors.WithMessage(err, "clearing OpenID Connect session data")
|
||||
}
|
||||
if data != nil {
|
||||
p := getProvider(data.ID.ID)
|
||||
if p == nil {
|
||||
return "", fmt.Errorf("unable to revoke token or end session for OpenID Connect because no provider %q exists", data.ID)
|
||||
}
|
||||
|
||||
endSessionEndpoint = p.oidc.EndSessionEndpoint
|
||||
|
||||
if p.oidc.RevocationEndpoint != "" {
|
||||
if err := revokeToken(r.Context(), p, data.AccessToken, data.TokenType); err != nil {
|
||||
return endSessionEndpoint, errors.WithMessage(err, "revoking OpenID Connect token")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return endSessionEndpoint, nil
|
||||
}
|
||||
73
enterprise/cmd/frontend/auth/openidconnect/user.go
Normal file
73
enterprise/cmd/frontend/auth/openidconnect/user.go
Normal file
@ -0,0 +1,73 @@
|
||||
package openidconnect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
oidc "github.com/coreos/go-oidc"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/actor"
|
||||
)
|
||||
|
||||
// getOrCreateUser gets or creates a user account based on the OpenID Connect token. It returns the
|
||||
// authenticated actor if successful; otherwise it returns an friendly error message (safeErrMsg)
|
||||
// that is safe to display to users, and a non-nil err with lower-level error details.
|
||||
func getOrCreateUser(ctx context.Context, p *provider, idToken *oidc.IDToken, userInfo *oidc.UserInfo, claims *userClaims) (_ *actor.Actor, safeErrMsg string, err error) {
|
||||
if userInfo.Email == "" {
|
||||
return nil, "Only users with an email address may authenticate to Sourcegraph.", errors.New("no email address in claims")
|
||||
}
|
||||
if unverifiedEmail := claims.EmailVerified != nil && !*claims.EmailVerified; unverifiedEmail {
|
||||
// If the OP explicitly reports `"email_verified": false`, then reject the authentication
|
||||
// attempt. If undefined or true, then it will be allowed.
|
||||
return nil, fmt.Sprintf("Only users with verified email addresses may authenticate to Sourcegraph. The email address %q is not verified on the external authentication provider.", userInfo.Email), fmt.Errorf("refusing unverified user email address %q", userInfo.Email)
|
||||
}
|
||||
|
||||
pi, err := p.getCachedInfoAndError()
|
||||
if err != nil {
|
||||
return nil, "", err
|
||||
}
|
||||
|
||||
login := claims.PreferredUsername
|
||||
if login == "" {
|
||||
login = userInfo.Email
|
||||
}
|
||||
email := userInfo.Email
|
||||
var displayName = claims.GivenName
|
||||
if displayName == "" {
|
||||
if claims.Name == "" {
|
||||
displayName = claims.Name
|
||||
} else {
|
||||
displayName = login
|
||||
}
|
||||
}
|
||||
login, err = auth.NormalizeUsername(login)
|
||||
if err != nil {
|
||||
return nil, fmt.Sprintf("Error normalizing the username %q. See https://about.sourcegraph.com/docs/config/authentication#username-normalization.", login), err
|
||||
}
|
||||
|
||||
var data db.ExternalAccountData
|
||||
auth.SetExternalAccountData(&data.AccountData, struct {
|
||||
IDToken *oidc.IDToken `json:"idToken"`
|
||||
UserInfo *oidc.UserInfo `json:"userInfo"`
|
||||
UserClaims *userClaims `json:"userClaims"`
|
||||
}{IDToken: idToken, UserInfo: userInfo, UserClaims: claims})
|
||||
|
||||
userID, safeErrMsg, err := auth.CreateOrUpdateUser(ctx, db.NewUser{
|
||||
Username: login,
|
||||
Email: email,
|
||||
EmailIsVerified: email != "", // TODO(sqs): https://github.com/sourcegraph/sourcegraph/issues/10118
|
||||
DisplayName: displayName,
|
||||
AvatarURL: claims.Picture,
|
||||
}, db.ExternalAccountSpec{
|
||||
ServiceType: providerType,
|
||||
ServiceID: pi.ServiceID,
|
||||
ClientID: pi.ClientID,
|
||||
AccountID: idToken.Subject,
|
||||
}, data)
|
||||
if err != nil {
|
||||
return nil, safeErrMsg, err
|
||||
}
|
||||
return actor.FromUser(userID), "", nil
|
||||
}
|
||||
150
enterprise/cmd/frontend/auth/saml/config.go
Normal file
150
enterprise/cmd/frontend/auth/saml/config.go
Normal file
@ -0,0 +1,150 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"path"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/env"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
var mockGetProviderValue *provider
|
||||
|
||||
// getProvider looks up the registered saml auth provider with the given ID.
|
||||
func getProvider(pcID string) *provider {
|
||||
if mockGetProviderValue != nil {
|
||||
return mockGetProviderValue
|
||||
}
|
||||
|
||||
p, _ := auth.GetProviderByConfigID(auth.ProviderConfigID{Type: providerType, ID: pcID}).(*provider)
|
||||
if p != nil {
|
||||
return p
|
||||
}
|
||||
|
||||
// Special case: if there is only a single SAML auth provider, return it regardless of the pcID.
|
||||
for _, ap := range auth.Providers() {
|
||||
if ap.Config().Saml != nil {
|
||||
if p != nil {
|
||||
return nil // multiple SAML providers, can't use this special case
|
||||
}
|
||||
p = ap.(*provider)
|
||||
}
|
||||
}
|
||||
|
||||
return p
|
||||
}
|
||||
|
||||
func handleGetProvider(ctx context.Context, w http.ResponseWriter, pcID string) (p *provider, handled bool) {
|
||||
handled = true // safer default
|
||||
|
||||
// License check.
|
||||
if !licensing.IsFeatureEnabledLenient(licensing.FeatureExternalAuthProvider) {
|
||||
licensing.WriteSubscriptionErrorResponseForFeature(w, "SAML user authentication (SSO)")
|
||||
return nil, true
|
||||
}
|
||||
|
||||
p = getProvider(pcID)
|
||||
if p == nil {
|
||||
log15.Error("No SAML auth provider found with ID.", "id", pcID)
|
||||
http.Error(w, "Misconfigured SAML auth provider.", http.StatusInternalServerError)
|
||||
return nil, true
|
||||
}
|
||||
if err := p.Refresh(ctx); err != nil {
|
||||
log15.Error("Error refreshing SAML auth provider.", "id", p.ConfigID(), "error", err)
|
||||
http.Error(w, "Unexpected error refreshing SAML authentication provider.", http.StatusInternalServerError)
|
||||
return nil, true
|
||||
}
|
||||
return p, false
|
||||
}
|
||||
|
||||
func init() {
|
||||
conf.ContributeValidator(validateConfig)
|
||||
}
|
||||
|
||||
func validateConfig(c schema.SiteConfiguration) (problems []string) {
|
||||
var loggedNeedsAppURL bool
|
||||
for _, p := range c.AuthProviders {
|
||||
if p.Saml != nil && c.AppURL == "" && !loggedNeedsAppURL {
|
||||
problems = append(problems, `saml auth provider requires appURL to be set to the external URL of your site (example: https://sourcegraph.example.com)`)
|
||||
loggedNeedsAppURL = true
|
||||
}
|
||||
}
|
||||
|
||||
seen := map[schema.SAMLAuthProvider]int{}
|
||||
for i, p := range c.AuthProviders {
|
||||
if p.Saml != nil {
|
||||
if j, ok := seen[*p.Saml]; ok {
|
||||
problems = append(problems, fmt.Sprintf("SAML auth provider at index %d is duplicate of index %d, ignoring", i, j))
|
||||
} else {
|
||||
seen[*p.Saml] = i
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return problems
|
||||
}
|
||||
|
||||
func withConfigDefaults(pc *schema.SAMLAuthProvider) *schema.SAMLAuthProvider {
|
||||
if pc.ServiceProviderIssuer == "" {
|
||||
appURL := conf.Get().AppURL
|
||||
if appURL == "" {
|
||||
// An empty issuer will be detected as an error later.
|
||||
return pc
|
||||
}
|
||||
|
||||
// Derive default issuer from appURL.
|
||||
tmp := *pc
|
||||
tmp.ServiceProviderIssuer = strings.TrimSuffix(conf.Get().AppURL, "/") + path.Join(authPrefix, "metadata")
|
||||
return &tmp
|
||||
}
|
||||
return pc
|
||||
}
|
||||
|
||||
func getNameIDFormat(pc *schema.SAMLAuthProvider) string {
|
||||
// Persistent is best because users will reuse their user_external_accounts row instead of (as
|
||||
// with transient) creating a new one each time they authenticate.
|
||||
const defaultNameIDFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"
|
||||
if pc.NameIDFormat != "" {
|
||||
return pc.NameIDFormat
|
||||
}
|
||||
return defaultNameIDFormat
|
||||
}
|
||||
|
||||
// providerConfigID produces a semi-stable identifier for a saml auth provider config object. It is
|
||||
// used to distinguish between multiple auth providers of the same type when in multi-step auth
|
||||
// flows. Its value is never persisted, and it must be deterministic.
|
||||
//
|
||||
// If there is only a single saml auth provider, it returns the empty string because that satisfies
|
||||
// the requirements above.
|
||||
func providerConfigID(pc *schema.SAMLAuthProvider, multiple bool) string {
|
||||
if !multiple {
|
||||
return ""
|
||||
}
|
||||
data, err := json.Marshal(pc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
b := sha256.Sum256(data)
|
||||
return base64.RawURLEncoding.EncodeToString(b[:16])
|
||||
}
|
||||
|
||||
var traceLogEnabled, _ = strconv.ParseBool(env.Get("INSECURE_SAML_LOG_TRACES", "false", "Log all SAML requests and responses. Only use during testing because the log messages will contain sensitive data."))
|
||||
|
||||
func traceLog(description, body string) {
|
||||
if traceLogEnabled {
|
||||
const n = 40
|
||||
log.Printf("%s SAML trace: %s\n%s\n%s", strings.Repeat("=", n), description, body, strings.Repeat("=", n+len(description)+1))
|
||||
}
|
||||
}
|
||||
40
enterprise/cmd/frontend/auth/saml/config_test.go
Normal file
40
enterprise/cmd/frontend/auth/saml/config_test.go
Normal file
@ -0,0 +1,40 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
func TestValidateCustom(t *testing.T) {
|
||||
tests := map[string]struct {
|
||||
input schema.SiteConfiguration
|
||||
wantProblems []string
|
||||
}{
|
||||
"duplicates": {
|
||||
input: schema.SiteConfiguration{
|
||||
AppURL: "x",
|
||||
AuthProviders: []schema.AuthProviders{
|
||||
{Saml: &schema.SAMLAuthProvider{Type: "saml", IdentityProviderMetadataURL: "x"}},
|
||||
{Saml: &schema.SAMLAuthProvider{Type: "saml", IdentityProviderMetadataURL: "x"}},
|
||||
},
|
||||
},
|
||||
wantProblems: []string{"SAML auth provider at index 1 is duplicate of index 0"},
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
conf.TestValidator(t, test.input, validateConfig, test.wantProblems)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestProviderConfigID(t *testing.T) {
|
||||
p := schema.SAMLAuthProvider{ServiceProviderIssuer: "x"}
|
||||
id1 := providerConfigID(&p, true)
|
||||
id2 := providerConfigID(&p, true)
|
||||
if id1 != id2 {
|
||||
t.Errorf("id1 (%q) != id2 (%q)", id1, id2)
|
||||
}
|
||||
}
|
||||
84
enterprise/cmd/frontend/auth/saml/config_watch.go
Normal file
84
enterprise/cmd/frontend/auth/saml/config_watch.go
Normal file
@ -0,0 +1,84 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
// Start trying to populate the cache of SAML IdP metadata immediately upon server startup and site
|
||||
// config changes so users don't incur the wait on the first auth flow request.
|
||||
func init() {
|
||||
providersOfType := func(ps []schema.AuthProviders) []*schema.SAMLAuthProvider {
|
||||
var pcs []*schema.SAMLAuthProvider
|
||||
for _, p := range ps {
|
||||
if p.Saml != nil {
|
||||
pcs = append(pcs, withConfigDefaults(p.Saml))
|
||||
}
|
||||
}
|
||||
return pcs
|
||||
}
|
||||
|
||||
var (
|
||||
init = true
|
||||
|
||||
mu sync.Mutex
|
||||
cur []*schema.SAMLAuthProvider
|
||||
reg = map[schema.SAMLAuthProvider]auth.Provider{}
|
||||
)
|
||||
conf.Watch(func() {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
// Only react when the config changes.
|
||||
new := providersOfType(conf.Get().AuthProviders)
|
||||
diff := diffProviderConfig(cur, new)
|
||||
if len(diff) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if !init {
|
||||
log15.Info("Reloading changed SAML authentication provider configuration.")
|
||||
}
|
||||
multiple := len(new) >= 2
|
||||
updates := make(map[auth.Provider]bool, len(diff))
|
||||
for pc, op := range diff {
|
||||
if old, ok := reg[pc]; ok {
|
||||
delete(reg, pc)
|
||||
updates[old] = false
|
||||
}
|
||||
if op {
|
||||
new := &provider{config: pc, multiple: multiple}
|
||||
reg[pc] = new
|
||||
updates[new] = true
|
||||
go func(p *provider) {
|
||||
if err := p.Refresh(context.Background()); err != nil {
|
||||
log15.Error("Error prefetching SAML service provider metadata.", "error", err)
|
||||
}
|
||||
}(new)
|
||||
}
|
||||
}
|
||||
auth.UpdateProviders(updates)
|
||||
cur = new
|
||||
})
|
||||
init = false
|
||||
}
|
||||
|
||||
func diffProviderConfig(old, new []*schema.SAMLAuthProvider) map[schema.SAMLAuthProvider]bool {
|
||||
diff := map[schema.SAMLAuthProvider]bool{}
|
||||
for _, oldPC := range old {
|
||||
diff[*oldPC] = false
|
||||
}
|
||||
for _, newPC := range new {
|
||||
if _, ok := diff[*newPC]; ok {
|
||||
delete(diff, *newPC)
|
||||
} else {
|
||||
diff[*newPC] = true
|
||||
}
|
||||
}
|
||||
return diff
|
||||
}
|
||||
46
enterprise/cmd/frontend/auth/saml/config_watch_test.go
Normal file
46
enterprise/cmd/frontend/auth/saml/config_watch_test.go
Normal file
@ -0,0 +1,46 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
func TestDiffProviderConfig(t *testing.T) {
|
||||
var (
|
||||
pc0 = &schema.SAMLAuthProvider{ServiceProviderIssuer: "0"}
|
||||
pc0c = &schema.SAMLAuthProvider{ServiceProviderIssuer: "0", ServiceProviderPrivateKey: "x"}
|
||||
pc1 = &schema.SAMLAuthProvider{ServiceProviderIssuer: "1"}
|
||||
)
|
||||
|
||||
tests := map[string]struct {
|
||||
old, new []*schema.SAMLAuthProvider
|
||||
want map[schema.SAMLAuthProvider]bool
|
||||
}{
|
||||
"empty": {want: map[schema.SAMLAuthProvider]bool{}},
|
||||
"added": {
|
||||
old: nil,
|
||||
new: []*schema.SAMLAuthProvider{pc0, pc1},
|
||||
want: map[schema.SAMLAuthProvider]bool{*pc0: true, *pc1: true},
|
||||
},
|
||||
"changed": {
|
||||
old: []*schema.SAMLAuthProvider{pc0, pc1},
|
||||
new: []*schema.SAMLAuthProvider{pc0c, pc1},
|
||||
want: map[schema.SAMLAuthProvider]bool{*pc0: false, *pc0c: true},
|
||||
},
|
||||
"removed": {
|
||||
old: []*schema.SAMLAuthProvider{pc0, pc1},
|
||||
new: []*schema.SAMLAuthProvider{pc1},
|
||||
want: map[schema.SAMLAuthProvider]bool{*pc0: false},
|
||||
},
|
||||
}
|
||||
for name, test := range tests {
|
||||
t.Run(name, func(t *testing.T) {
|
||||
diff := diffProviderConfig(test.old, test.new)
|
||||
if !reflect.DeepEqual(diff, test.want) {
|
||||
t.Errorf("got != want\n got %+v\nwant %+v", diff, test.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
3
enterprise/cmd/frontend/auth/saml/doc.go
Normal file
3
enterprise/cmd/frontend/auth/saml/doc.go
Normal file
@ -0,0 +1,3 @@
|
||||
// Package SAML provides HTTP middleware that provides the necessary endpoints for a SAML Service
|
||||
// Provider (SP) to complete the SAML authentication flow to authenticate to the frontend.
|
||||
package saml
|
||||
262
enterprise/cmd/frontend/auth/saml/middleware.go
Normal file
262
enterprise/cmd/frontend/auth/saml/middleware.go
Normal file
@ -0,0 +1,262 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/session"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/actor"
|
||||
)
|
||||
|
||||
// All SAML endpoints are under this path prefix.
|
||||
const authPrefix = auth.AuthURLPrefix + "/saml"
|
||||
|
||||
// Middleware is middleware for SAML authentication, adding endpoints under the auth path prefix to
|
||||
// enable the login flow an requiring login for all other endpoints.
|
||||
//
|
||||
// 🚨 SECURITY
|
||||
var Middleware = &auth.Middleware{
|
||||
API: func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHandler(w, r, next, true)
|
||||
})
|
||||
},
|
||||
App: func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
authHandler(w, r, next, false)
|
||||
})
|
||||
},
|
||||
}
|
||||
|
||||
// authHandler is the new SAML HTTP auth handler.
|
||||
//
|
||||
// It uses github.com/russelhaering/gosaml2 and (unlike authHandler1) makes it possible to support
|
||||
// multiple auth providers with SAML and expose more SAML functionality.
|
||||
func authHandler(w http.ResponseWriter, r *http.Request, next http.Handler, isAPIRequest bool) {
|
||||
// Delegate to SAML ACS and metadata endpoint handlers.
|
||||
if !isAPIRequest && strings.HasPrefix(r.URL.Path, auth.AuthURLPrefix+"/saml/") {
|
||||
samlSPHandler(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// If the actor is authenticated and not performing a SAML operation, then proceed to next.
|
||||
if actor.FromContext(r.Context()).IsAuthenticated() {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
// If there is only one auth provider configured, the single auth provider is SAML, and it's an
|
||||
// app request, redirect to signin immediately. The user wouldn't be able to do anything else
|
||||
// anyway; there's no point in showing them a signin screen with just a single signin option.
|
||||
if ps := auth.Providers(); len(ps) == 1 && ps[0].Config().Saml != nil && !isAPIRequest {
|
||||
p, handled := handleGetProvider(r.Context(), w, ps[0].ConfigID().ID)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
redirectToAuthURL(w, r, p, auth.SafeRedirectURL(r.URL.String()))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
func samlSPHandler(w http.ResponseWriter, r *http.Request) {
|
||||
requestPath := strings.TrimPrefix(r.URL.Path, authPrefix)
|
||||
|
||||
// Handle GET endpoints.
|
||||
if r.Method == "GET" {
|
||||
// All of these endpoints expect the provider ID in the URL query.
|
||||
p, handled := handleGetProvider(r.Context(), w, r.URL.Query().Get("pc"))
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
|
||||
switch requestPath {
|
||||
case "/metadata":
|
||||
metadata, err := p.samlSP.Metadata()
|
||||
if err != nil {
|
||||
log15.Error("Error generating SAML service provider metadata.", "err", err)
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
buf, err := xml.MarshalIndent(metadata, "", " ")
|
||||
if err != nil {
|
||||
log15.Error("Error encoding SAML service provider metadata.", "err", err)
|
||||
http.Error(w, "", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
traceLog(fmt.Sprintf("Service Provider metadata: %s", p.ConfigID().ID), string(buf))
|
||||
w.Header().Set("Content-Type", "application/samlmetadata+xml; charset=utf-8")
|
||||
w.Write(buf)
|
||||
return
|
||||
|
||||
case "/login":
|
||||
// It is safe to use r.Referer() because the redirect-to URL will be checked later,
|
||||
// before the client is actually instructed to navigate there.
|
||||
redirectToAuthURL(w, r, p, r.Referer())
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if r.Method != "POST" {
|
||||
http.Error(w, "", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
if err := r.ParseForm(); err != nil {
|
||||
http.Error(w, "", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
// The remaining endpoints all expect the provider ID in the POST data's RelayState.
|
||||
traceLog("SAML RelayState", r.FormValue("RelayState"))
|
||||
var relayState relayState
|
||||
relayState.decode(r.FormValue("RelayState"))
|
||||
|
||||
p, handled := handleGetProvider(r.Context(), w, relayState.ProviderID)
|
||||
if handled {
|
||||
return
|
||||
}
|
||||
|
||||
// Handle POST endpoints.
|
||||
switch requestPath {
|
||||
case "/acs":
|
||||
info, err := readAuthnResponse(p, r.FormValue("SAMLResponse"))
|
||||
if err != nil {
|
||||
log15.Error("Error validating SAML assertions. Set the env var INSECURE_SAML_LOG_TRACES=1 to log all SAML requests and responses.", "err", err)
|
||||
http.Error(w, "Error validating SAML assertions. Try signing in again. If the problem persists, a site admin must check the configuration.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
actor, safeErrMsg, err := getOrCreateUser(r.Context(), info)
|
||||
if err != nil {
|
||||
log15.Error("Error looking up SAML-authenticated user.", "err", err, "userErr", safeErrMsg)
|
||||
http.Error(w, safeErrMsg, http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
var exp time.Duration
|
||||
// 🚨 SECURITY: TODO(sqs): We *should* uncomment the line below to make our own sessions
|
||||
// only last for as long as the IdP said the authn grant is active for. Unfortunately,
|
||||
// until we support refreshing SAML authn in the background
|
||||
// (https://github.com/sourcegraph/sourcegraph/issues/11340), this provides a bad user
|
||||
// experience because users need to re-authenticate via SAML every minute or so
|
||||
// (assuming their SAML IdP, like many, has a 1-minute access token validity period).
|
||||
//
|
||||
// if info.SessionNotOnOrAfter != nil {
|
||||
// exp = time.Until(*info.SessionNotOnOrAfter)
|
||||
// }
|
||||
if err := session.SetActor(w, r, actor, exp); err != nil {
|
||||
log15.Error("Error setting SAML-authenticated actor in session.", "err", err)
|
||||
http.Error(w, "Error starting SAML-authenticated session. Try signing in again.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 🚨 SECURITY: Call auth.SafeRedirectURL to avoid an open-redirect vuln.
|
||||
http.Redirect(w, r, auth.SafeRedirectURL(relayState.ReturnToURL), http.StatusFound)
|
||||
|
||||
case "/logout":
|
||||
encodedResp := r.FormValue("SAMLResponse")
|
||||
|
||||
{
|
||||
if raw, err := base64.StdEncoding.DecodeString(encodedResp); err == nil {
|
||||
traceLog(fmt.Sprintf("LogoutResponse: %s", p.ConfigID().ID), string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
// TODO(sqs): Fully validate the LogoutResponse here (i.e., also validate that the document
|
||||
// is a valid LogoutResponse). It is possible that this request is being spoofed, but it
|
||||
// doesn't let an attacker do very much (just log a user out and redirect).
|
||||
//
|
||||
// 🚨 SECURITY: If this logout handler starts to do anything more advanced, it probably must
|
||||
// validate the LogoutResponse to avoid being vulnerable to spoofing.
|
||||
_, err := p.samlSP.ValidateEncodedResponse(encodedResp)
|
||||
if err != nil && !strings.HasPrefix(err.Error(), "unable to unmarshal response:") {
|
||||
log15.Error("Error validating SAML logout response.", "err", err)
|
||||
http.Error(w, "Error validating SAML logout response.", http.StatusForbidden)
|
||||
return
|
||||
}
|
||||
|
||||
// If this is an SP-initiated logout, then the actor has already been cleared from the
|
||||
// session (but there's no harm in clearing it again). If it's an IdP-initiated logout,
|
||||
// then it hasn't, and we must clear it here.
|
||||
if err := session.SetActor(w, r, nil, 0); err != nil {
|
||||
log15.Error("Error clearing actor from session in SAML logout handler.", "err", err)
|
||||
http.Error(w, "Error signing out of SAML-authenticated session.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, "/", http.StatusFound)
|
||||
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
}
|
||||
|
||||
func redirectToAuthURL(w http.ResponseWriter, r *http.Request, p *provider, returnToURL string) {
|
||||
authURL, err := buildAuthURLRedirect(p, relayState{
|
||||
ProviderID: p.ConfigID().ID,
|
||||
ReturnToURL: auth.SafeRedirectURL(returnToURL),
|
||||
})
|
||||
if err != nil {
|
||||
log15.Error("Failed to build SAML auth URL.", "err", err)
|
||||
http.Error(w, "Unexpected error in SAML authentication provider.", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
http.Redirect(w, r, authURL, http.StatusFound)
|
||||
}
|
||||
|
||||
func buildAuthURLRedirect(p *provider, relayState relayState) (string, error) {
|
||||
doc, err := p.samlSP.BuildAuthRequestDocument()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
{
|
||||
if data, err := doc.WriteToString(); err == nil {
|
||||
traceLog(fmt.Sprintf("AuthnRequest: %s", p.ConfigID().ID), data)
|
||||
}
|
||||
}
|
||||
return p.samlSP.BuildAuthURLRedirect(relayState.encode(), doc)
|
||||
}
|
||||
|
||||
// relayState represents the decoded RelayState value in both the IdP-initiated and SP-initiated
|
||||
// login flows.
|
||||
//
|
||||
// SAML overloads the term "RelayState".
|
||||
// * In the SP-initiated login flow, it is an opaque value originated from the SP and reflected
|
||||
// back in the AuthnResponse. The Sourcegraph SP uses the base64-encoded JSON of this struct as
|
||||
// the RelayState.
|
||||
// * In the IdP-initiated login flow, the RelayState can be any arbitrary hint, but in practice
|
||||
// is the desired post-login redirect URL in plain text.
|
||||
type relayState struct {
|
||||
ProviderID string `json:"k"`
|
||||
ReturnToURL string `json:"r"`
|
||||
}
|
||||
|
||||
// encode returns the base64-encoded JSON representation of the relay state.
|
||||
func (s *relayState) encode() string {
|
||||
b, _ := json.Marshal(s)
|
||||
return base64.StdEncoding.EncodeToString(b)
|
||||
}
|
||||
|
||||
// Decode decodes the base64-encoded JSON representation of the relay state into the receiver.
|
||||
func (s *relayState) decode(encoded string) {
|
||||
if strings.HasPrefix(encoded, "http://") || strings.HasPrefix(encoded, "https://") || encoded == "" {
|
||||
s.ProviderID, s.ReturnToURL = "", encoded
|
||||
return
|
||||
}
|
||||
|
||||
if b, err := base64.StdEncoding.DecodeString(encoded); err == nil {
|
||||
if err := json.Unmarshal(b, s); err == nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.ProviderID, s.ReturnToURL = "", ""
|
||||
}
|
||||
409
enterprise/cmd/frontend/auth/saml/middleware_test.go
Normal file
409
enterprise/cmd/frontend/auth/saml/middleware_test.go
Normal file
@ -0,0 +1,409 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"compress/flate"
|
||||
"context"
|
||||
"crypto"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/pem"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/beevik/etree"
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
"github.com/sourcegraph/enterprise/pkg/license"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/session"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/actor"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
|
||||
"github.com/crewjam/saml/samlidp"
|
||||
)
|
||||
|
||||
const (
|
||||
testSAMLSPCert = `-----BEGIN CERTIFICATE-----
|
||||
MIIC+zCCAeOgAwIBAgIQFkK4RCQNFkAFzj8dJHnXJjANBgkqhkiG9w0BAQsFADAS
|
||||
MRAwDgYDVQQKEwdBY21lIENvMB4XDTE4MTAyMDA4NTQ1NVoXDTI4MTAxNzA4NTQ1
|
||||
NVowEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC
|
||||
AQoCggEBAKMD2RmTnAI+1+s+hiakkdbOXHoEwRoG45yeCV5z8A7TnZtF238kReBN
|
||||
JSOUTvgrvg5WbfG8ULSariepAI45BH3yYoNOBXe0biVsCB+0h6szeV1+N6y9wj0j
|
||||
ns/AOOV6ec/GbUZufF+XeJmVX/kRoOthUCEWhCGn/ZCa9VNcr2u/EhCZhvk6JcY9
|
||||
p/gu2YYJepihYpkrzzHwlC+ye+AfPX0/LiZQLGM8ciiziXden8DqEhskkg5HqnPl
|
||||
hwscqI6qlYIcUFw5QB3xA738N4/92Uj7Jstf05ESFDf6zbUTn/hSsLXivNHI0G4P
|
||||
4gsVy5Y5pygrw3b3FuodJbuVtLU9cwMCAwEAAaNNMEswDgYDVR0PAQH/BAQDAgWg
|
||||
MBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAwFgYDVR0RBA8wDYIL
|
||||
ZXhhbXBsZS5jb20wDQYJKoZIhvcNAQELBQADggEBADQ/UgbXlW7zPwWswJSlbgph
|
||||
yjepD5dJ/My1ByIM2GSSYlvnLGq9tSOwUWZ0fZY/G8WOowNSBQlUVPT7tS11j7Ce
|
||||
BdrImHWpDleZYyagk08vaU059LJFI4EM8Vzn570h3dxdXkSoIGjqNAfywat591k3
|
||||
K8llPk2wrQ8fv3KA7tNNmJW+Ee1DHIAui53aFe3sHmp7JN1tE6HlqrSLIDymSd28
|
||||
tOfJ1Y9kOvUF7DY8pkSVDukO9wsy0X568hfJOz4PQe/1LHJ1YxlomTCkyVV8xtW7
|
||||
hbnEyvPo2yr/SHbk4Fz1yXP20cBm8vO2DmKlI0kaKGQw1Rybl8NQw+OPdb/V6pM=
|
||||
-----END CERTIFICATE-----`
|
||||
|
||||
testSAMLSPKey = `-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEAowPZGZOcAj7X6z6GJqSR1s5cegTBGgbjnJ4JXnPwDtOdm0Xb
|
||||
fyRF4E0lI5RO+Cu+DlZt8bxQtJquJ6kAjjkEffJig04Fd7RuJWwIH7SHqzN5XX43
|
||||
rL3CPSOez8A45Xp5z8ZtRm58X5d4mZVf+RGg62FQIRaEIaf9kJr1U1yva78SEJmG
|
||||
+Tolxj2n+C7Zhgl6mKFimSvPMfCUL7J74B89fT8uJlAsYzxyKLOJd16fwOoSGySS
|
||||
Dkeqc+WHCxyojqqVghxQXDlAHfEDvfw3j/3ZSPsmy1/TkRIUN/rNtROf+FKwteK8
|
||||
0cjQbg/iCxXLljmnKCvDdvcW6h0lu5W0tT1zAwIDAQABAoIBAC/t4LYhbVxHp+p1
|
||||
zrGr72lN8Wi63x/M6L1SxgRsaCej1pIhvwCp5JWneQT2BSX4jn/er6LEsKH5XL0y
|
||||
doRahVSWoJpkpTzl4wDDu7u+s6kFkGiJxMrYXDTntTj2FoR6Nzh86gIsWAsvGPln
|
||||
LvmnUj4CtbGU0jKnFumedgUVmko+QmalghYrkc7dReprGJ6EDWNLGb0ASG9/R7iQ
|
||||
NOKu17nK9W3yCWJc48SR8y9HWUEUqtKbsshJ6PewNNttsSC3JjeGuiH5fRmXLi6L
|
||||
wXr2l1AAPGRWbI8djrm7DFLa1s8pfJKkTV0YUDHFNBXny0h+oUGwC+KCQNsfE3t8
|
||||
GbKqdKkCgYEA0A3dFKuxzm9QbZRcmGwZbHTNfWb7EMlknNq9wqZLgbTK77P1bhXW
|
||||
l0YP7HuNZnKToMt3UrM5tYFzk28a73p4Ur8+va23lbtwwBeZ5qsZ/vAfI9b2GTj5
|
||||
AchOI5Xtwf8eWT3OIdHbe4W5hkyb/siPdkRJ1zXfDn8XN0+XIB7NeT0CgYEAyJTn
|
||||
xjtxMq3Jab5tycZ7ZVC9Y5tpd6phb6KLdLNGNcNKSPvC2EichgLH3Awckpe92HvD
|
||||
wujlQnlKod42/aPxVcOOnwqN2PMfzPXK91pcOt9QyWFzRLFtvkFB9Dd4vz5uhdWF
|
||||
CpSBDN87PpFEOuJJApy78e7hfZ5YpxWGK7N24T8CgYEAxtT4+9A6VTc8ffzToTdt
|
||||
9KCL4dSRDDHr3ZuOzn9umb7WUs6BN3vXYSqr/Sz2rXnCbGEG4Bo4hKX6dmQwMb2x
|
||||
UCNFKrDiSk6gKnRjuHa8mU+R8wZ0mxY/otxzEL8wQb42msLeRKPyRdI+w4JjctLp
|
||||
h/UrPGlXitsarNl7bE8Dv2ECgYBXOgog9rCfbVvtlFaCLMJ0qMvziR4wX/PHbFRh
|
||||
B6U8tBSV8IYnMEyBKqxnUQ0L4tk4T3ouRMGOStjd05jubGEC/uwC1cAh3Hiz1R/S
|
||||
uYTqRTsImExcTxx+ZDqeTZFA+ZFuuhAFLdeBFYLaDqoxQT6m2CoTZ+K/kiDTaFTU
|
||||
pFLKWQKBgQCNCeMVMkNwJtMeR/vKUEOPZLKFZihPrORi8F5v0f+qrcg3bKDKd1KI
|
||||
6kocCulZR3uvEFHUAoNyMNwCZs6YyIK9zEN3/Pb46ThnJNNMXv0CG+W8df/Vbnd0
|
||||
REijBJT1tjS4dXBULokRuI2640iWll8KX1KQDzqo3l++JRGqoP57Sg==
|
||||
-----END RSA PRIVATE KEY-----`
|
||||
)
|
||||
|
||||
var (
|
||||
idpCert = func() *x509.Certificate {
|
||||
b, _ := pem.Decode([]byte(`-----BEGIN CERTIFICATE-----
|
||||
MIIC/DCCAeSgAwIBAgIRANzQVAHz24KrifhcM9kaVcYwDQYJKoZIhvcNAQELBQAw
|
||||
EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xODEwMjAwODU1NDNaFw0yODEwMTcwODU1
|
||||
NDNaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
|
||||
ggEKAoIBAQCYO2FXgBvsrHyDzFFjUxNuF5OtPYOw+yjGnMR7r1CU93ZaSFN0z9Ux
|
||||
Qv6rHAmoGwLp+dJjYZ2g/km6/TnONVGjSDh8TxCEH+cP5kIyRN4L3MPW9tsZ36Fd
|
||||
5yqNbVCrxp3gJKAHmcUHYONzQ6WxxOCEBkvVknysstG8hXhbOcElXrSIyRVPQuEu
|
||||
TBQSAJahFbQYCKFU93rlO142hgPJDkHibz8PhLgEU7v3Eo23JrOSKNUysXnp5hLT
|
||||
RhOyQyWNpXA91wwsTwETOD2KlDKDHIcpKEdMSWhRQya6S2z49RVKYpZmATW4Qq2l
|
||||
hDWWuyG7/IaheuyPum+BF0FFmDKfQY83AgMBAAGjTTBLMA4GA1UdDwEB/wQEAwIF
|
||||
oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBYGA1UdEQQPMA2C
|
||||
C2V4YW1wbGUuY29tMA0GCSqGSIb3DQEBCwUAA4IBAQBvdi9Gz1qADI0F/oC/7fh/
|
||||
TjevChAO3XsuGz53yqeM0z49yHJooCV3jzgBrjX82DAhk/nQ0kF+YZummoPG1fqf
|
||||
rycmsq0yD7Gy/do8sW7XSvpQpkbiBb9yg7rP1eVEfn+vDv0ZS0F3mqbfXl1v7FQ0
|
||||
PwPtE49OaO7rb3FbLPBtEXocvGjvga8SRmuT3/5oCI46AKldENL6+CKEEWWUIuW8
|
||||
HeKPQIRYzcqi1dy88nRk44DkCyNxe/h7X/MGt4Mu9HjDH7lDJs0038sdghsX74ET
|
||||
gN6hZwBR6U2UoJHDytj/+KtSL/XGZSwTgOFyFyMZROcqUPWRwl7Zk3dOKy+3T2Bi
|
||||
-----END CERTIFICATE-----
|
||||
`))
|
||||
c, _ := x509.ParseCertificate(b.Bytes)
|
||||
return c
|
||||
}()
|
||||
|
||||
idpKey = func() crypto.PrivateKey {
|
||||
b, _ := pem.Decode([]byte(`-----BEGIN RSA PRIVATE KEY-----
|
||||
MIIEpAIBAAKCAQEAmDthV4Ab7Kx8g8xRY1MTbheTrT2DsPsoxpzEe69QlPd2WkhT
|
||||
dM/VMUL+qxwJqBsC6fnSY2GdoP5Juv05zjVRo0g4fE8QhB/nD+ZCMkTeC9zD1vbb
|
||||
Gd+hXecqjW1Qq8ad4CSgB5nFB2Djc0OlscTghAZL1ZJ8rLLRvIV4WznBJV60iMkV
|
||||
T0LhLkwUEgCWoRW0GAihVPd65TteNoYDyQ5B4m8/D4S4BFO79xKNtyazkijVMrF5
|
||||
6eYS00YTskMljaVwPdcMLE8BEzg9ipQygxyHKShHTEloUUMmukts+PUVSmKWZgE1
|
||||
uEKtpYQ1lrshu/yGoXrsj7pvgRdBRZgyn0GPNwIDAQABAoIBAQCXS7zM5+fY6vy9
|
||||
SJ1C59gRvKDqto5hoNy/uCKXAoBF7UPVKri3Ca/Ky9irWqxGRMI6pC1y1BuDW/cP
|
||||
Pojq5pcCfs6UzUeO6N4OMTxtFYDRrVF+Hc1YA6gu2YazFIfukPFrSTs7Epp9YM/t
|
||||
SLgu24p/7HoGAxah1P8aLFSX5eiOJ+8t8byYOrKLp3Rn67lC9Y+9LX4X6GHlBMDc
|
||||
WHYupi3ZA7Q59dXQCJHFNG/hk17AMtB8lFra9rUid8teX8ZJKJQ26hU2O0UMujjM
|
||||
mFlCdmvc97lJ4LhjrWHv/9yacf90bViHIkL52Yux1jNt/jl3/7CyBwHbau4b0qoZ
|
||||
QkM4WIihAoGBAMlzsUeJxBCbUyMd/6TiLn60SDn0HMf4ZVdGqqxkhopVmsvRTn+P
|
||||
wu9YHWFPwXwVL3wdtuBjsEA9nMKWWMQKbQUZhm1Y+AQIVpVNQqesgyLctVoIUBNY
|
||||
fglvKrs8JuRuwMpE2P/3lXMsxtV9AyCpxxXhya8KqJa2jcMB/Lr+lx+fAoGBAMFz
|
||||
16yHU+Zo6OOvy7T2xh67btwOrS2kIzhGO6gcK7qabkGGeaKLShnMYEmFGp4EaHTf
|
||||
OVie+SU0OWF/e5bgFWC+fm6jWyhO0xPRbvs2P+l2KtnT2UBT9IgjhrVUIzp+Vn7t
|
||||
cjfb32m7km1kZZ48ySP9cH/4/xnT6XEC33PoNwlpAoGAG1t+w7xNyAOP8sDsKrQc
|
||||
pFBPTq98CRwOhx+tpeOw8bBWaT9vbZtUWbSZqNFv8S3fWPegEjD3ioHTfAl23Iid
|
||||
7Ydd3hOq+sE3IOdxGdwvotheOG/QkBAAbb+PCgZNMdBolg9reLdisFVwWyWy+wiT
|
||||
ZMFY5lCIPI9mCQmIDMzuMPkCgYBFJKJxh+z07YpP1wV4KLunQFbfUF+VcJUmB/RK
|
||||
ocb/azL9OJNBBYf2sJW5sVlSIUE0hJR6mFd0dLYNowMJag46Bdwqrzhlr8bBzplc
|
||||
MIenahTmxlFgLKG6Bvie1vPAdGd19mhcjrnLkL9FWhz38cHymyMammSTVqqZOe2j
|
||||
/9usAQKBgQCT//j6XflAr20gb+mNcoJqVxRTFtSsZa23kJnZ3Sfs3R8XXu5NcZEZ
|
||||
ODI9ZpZ9tg8oK9EB5vqMFTTnyjpar7F2jqFWtUmNju/rGlrQCZx0we+EAW/R2hFP
|
||||
YGYu4Z+SyXTsv/Ys5VGWuuCJO32RuRBeC4eJCmpyH0mqPhIBZmV4Jw==
|
||||
-----END RSA PRIVATE KEY-----
|
||||
`))
|
||||
k, _ := x509.ParsePKCS1PrivateKey(b.Bytes)
|
||||
return k
|
||||
}()
|
||||
)
|
||||
|
||||
// newSAMLIDPServer returns a new running SAML IDP server. It is the caller's
|
||||
// responsibility to call Close().
|
||||
func newSAMLIDPServer(t *testing.T) (*httptest.Server, *samlidp.Server) {
|
||||
h := http.NewServeMux()
|
||||
srv := httptest.NewServer(h)
|
||||
|
||||
srvURL, err := url.Parse(srv.URL)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
idpServer, err := samlidp.New(samlidp.Options{
|
||||
URL: *srvURL,
|
||||
Key: idpKey,
|
||||
Certificate: idpCert,
|
||||
Store: &samlidp.MemoryStore{},
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
h.Handle("/", idpServer)
|
||||
|
||||
return srv, idpServer
|
||||
}
|
||||
|
||||
func TestMiddleware(t *testing.T) {
|
||||
idpHTTPServer, idpServer := newSAMLIDPServer(t)
|
||||
defer idpHTTPServer.Close()
|
||||
|
||||
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
|
||||
return &license.Info{Tags: licensing.EnterpriseTags}, nil
|
||||
}
|
||||
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
|
||||
|
||||
conf.Mock(&schema.SiteConfiguration{
|
||||
AppURL: "http://example.com",
|
||||
ExperimentalFeatures: &schema.ExperimentalFeatures{},
|
||||
})
|
||||
defer conf.Mock(nil)
|
||||
|
||||
config := withConfigDefaults(&schema.SAMLAuthProvider{
|
||||
Type: "saml",
|
||||
IdentityProviderMetadataURL: idpServer.IDP.MetadataURL.String(),
|
||||
ServiceProviderCertificate: testSAMLSPCert,
|
||||
ServiceProviderPrivateKey: testSAMLSPKey,
|
||||
})
|
||||
|
||||
mockGetProviderValue = &provider{config: *config}
|
||||
defer func() { mockGetProviderValue = nil }()
|
||||
auth.SetMockProviders([]auth.Provider{mockGetProviderValue})
|
||||
defer func() { auth.SetMockProviders(nil) }()
|
||||
|
||||
cleanup := session.ResetMockSessionStore(t)
|
||||
defer cleanup()
|
||||
|
||||
providerID := providerConfigID(&mockGetProviderValue.config, true)
|
||||
|
||||
// Mock user
|
||||
mockedExternalID := "testuser_id"
|
||||
const mockedUserID = 123
|
||||
auth.SetMockCreateOrUpdateUser(func(u db.NewUser, a db.ExternalAccountSpec) (userID int32, err error) {
|
||||
if a.ServiceType == "saml" && a.ServiceID == idpServer.IDP.MetadataURL.String() && a.ClientID == "http://example.com/.auth/saml/metadata" && a.AccountID == mockedExternalID {
|
||||
return mockedUserID, nil
|
||||
}
|
||||
return 0, fmt.Errorf("account %v not found in mock", a)
|
||||
})
|
||||
defer func() { auth.SetMockCreateOrUpdateUser(nil) }()
|
||||
|
||||
// Set up the test handler.
|
||||
authedHandler := http.NewServeMux()
|
||||
authedHandler.Handle("/.api/", Middleware.API(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
if uid := actor.FromContext(r.Context()).UID; uid != mockedUserID && uid != 0 {
|
||||
t.Errorf("got actor UID %d, want %d", uid, mockedUserID)
|
||||
}
|
||||
})))
|
||||
authedHandler.Handle("/", Middleware.App(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
switch r.URL.Path {
|
||||
case "/":
|
||||
w.Write([]byte("This is the home"))
|
||||
case "/page":
|
||||
w.Write([]byte("This is a page"))
|
||||
case "/require-authn":
|
||||
actr := actor.FromContext(r.Context())
|
||||
if actr.UID == 0 {
|
||||
t.Errorf("in authn expected-endpoint, no actor was set; expected actor with UID %d", mockedUserID)
|
||||
} else if actr.UID != mockedUserID {
|
||||
t.Errorf("in authn expected-endpoint, actor with incorrect UID was set; %d != %d", actr.UID, mockedUserID)
|
||||
}
|
||||
w.Write([]byte("Authenticated"))
|
||||
default:
|
||||
http.Error(w, "", http.StatusNotFound)
|
||||
}
|
||||
})))
|
||||
|
||||
// doRequest simulates a request to our authed handler (i.e., the SAML Service Provider).
|
||||
//
|
||||
// authed sets an authed actor in the request context to simulate an authenticated request.
|
||||
doRequest := func(method, urlStr, body string, cookies []*http.Cookie, authed bool, form url.Values) *http.Response {
|
||||
req := httptest.NewRequest(method, urlStr, bytes.NewBufferString(body))
|
||||
for _, cookie := range cookies {
|
||||
req.AddCookie(cookie)
|
||||
}
|
||||
if form != nil {
|
||||
req.PostForm = form
|
||||
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
|
||||
}
|
||||
if authed {
|
||||
req = req.WithContext(actor.WithActor(context.Background(), &actor.Actor{UID: mockedUserID}))
|
||||
}
|
||||
respRecorder := httptest.NewRecorder()
|
||||
authedHandler.ServeHTTP(respRecorder, req)
|
||||
return respRecorder.Result()
|
||||
}
|
||||
|
||||
var (
|
||||
authnRequest saml.AuthnRequest
|
||||
authnCookies []*http.Cookie
|
||||
authnRequestURL string
|
||||
)
|
||||
t.Run("unauthenticated homepage visit -> IDP SSO URL", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/", "", nil, false, nil)
|
||||
if want := http.StatusFound; resp.StatusCode != want {
|
||||
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
locURL, err := url.Parse(resp.Header.Get("Location"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !strings.HasPrefix(locURL.String(), idpServer.IDP.SSOURL.String()) {
|
||||
t.Error("wrong redirect URL")
|
||||
}
|
||||
|
||||
// save cookies and Authn request
|
||||
authnCookies = unexpiredCookies(resp)
|
||||
authnRequestURL = locURL.String()
|
||||
deflatedSAMLRequest, err := base64.StdEncoding.DecodeString(locURL.Query().Get("SAMLRequest"))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := xml.NewDecoder(flate.NewReader(bytes.NewBuffer(deflatedSAMLRequest))).Decode(&authnRequest); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
})
|
||||
t.Run("unauthenticated API visit -> pass through", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/.api/foo", "", nil, false, nil)
|
||||
if got, want := resp.StatusCode, http.StatusOK; got != want {
|
||||
t.Errorf("wrong response code: got %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
var (
|
||||
loggedInCookies []*http.Cookie
|
||||
)
|
||||
t.Run("get SP metadata and register SP with IDP", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/.auth/saml/metadata?pc="+providerID, "", nil, false, nil)
|
||||
service := samlidp.Service{}
|
||||
if err := xml.NewDecoder(resp.Body).Decode(&service.Metadata); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
serviceMetadataBytes, err := xml.Marshal(service.Metadata)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
req, err := http.NewRequest("PUT", idpHTTPServer.URL+"/services/id", bytes.NewBuffer(serviceMetadataBytes))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
resp, err = http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("could not register SP with IDP, error: %s, resp: %v", err, resp)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if want := http.StatusNoContent; resp.StatusCode != want {
|
||||
t.Errorf("got HTTP %d, want %d", resp.StatusCode, want)
|
||||
}
|
||||
})
|
||||
t.Run("get SAML assertion from IDP and post the assertion to the SP ACS URL", func(t *testing.T) {
|
||||
authnReq, err := http.NewRequest("GET", authnRequestURL, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
idpAuthnReq, err := saml.NewIdpAuthnRequest(&idpServer.IDP, authnReq)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := idpAuthnReq.Validate(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
session := saml.Session{
|
||||
ID: "session-id",
|
||||
CreateTime: time.Now(),
|
||||
ExpireTime: time.Now().Add(24 * time.Hour),
|
||||
Index: "index",
|
||||
|
||||
NameID: "testuser_id",
|
||||
UserName: "testuser_username",
|
||||
UserEmail: "testuser@email.com",
|
||||
}
|
||||
if err := (saml.DefaultAssertionMaker{}).MakeAssertion(idpAuthnReq, &session); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := idpAuthnReq.MakeResponse(); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
doc := etree.NewDocument()
|
||||
doc.SetRoot(idpAuthnReq.ResponseEl)
|
||||
responseBuf, err := doc.WriteToBytes()
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
samlResponse := base64.StdEncoding.EncodeToString(responseBuf)
|
||||
reqParams := url.Values{}
|
||||
reqParams.Set("SAMLResponse", samlResponse)
|
||||
reqParams.Set("RelayState", idpAuthnReq.RelayState)
|
||||
resp := doRequest("POST", "http://example.com/.auth/saml/acs", "", authnCookies, false, reqParams)
|
||||
if want := http.StatusFound; resp.StatusCode != want {
|
||||
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
if got, want1, want2 := resp.Header.Get("Location"), "http://example.com/", "/"; got != want1 && got != want2 {
|
||||
t.Errorf("got redirect location %v, want %v or %v", got, want1, want2)
|
||||
}
|
||||
|
||||
// save the cookies from the login response
|
||||
loggedInCookies = unexpiredCookies(resp)
|
||||
})
|
||||
t.Run("authenticated request to home page", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/", "", loggedInCookies, true, nil)
|
||||
respBody, _ := ioutil.ReadAll(resp.Body)
|
||||
if want := http.StatusOK; resp.StatusCode != want {
|
||||
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
if got, want := string(respBody), "This is the home"; got != want {
|
||||
t.Errorf("got response body %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("authenticated request to sub page", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/page", "", loggedInCookies, true, nil)
|
||||
respBody, _ := ioutil.ReadAll(resp.Body)
|
||||
if want := http.StatusOK; resp.StatusCode != want {
|
||||
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
if got, want := string(respBody), "This is a page"; got != want {
|
||||
t.Errorf("got response body %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("verify actor gets set in request context", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/require-authn", "", loggedInCookies, true, nil)
|
||||
if want := http.StatusOK; resp.StatusCode != want {
|
||||
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
|
||||
}
|
||||
})
|
||||
t.Run("verify actor gets set in API request context", func(t *testing.T) {
|
||||
resp := doRequest("GET", "http://example.com/.api/foo", "", loggedInCookies, true, nil)
|
||||
if got, want := resp.StatusCode, http.StatusOK; got != want {
|
||||
t.Errorf("wrong status code: got %v, want %v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
// unexpiredCookies returns the list of unexpired cookies set by the response
|
||||
func unexpiredCookies(resp *http.Response) (cookies []*http.Cookie) {
|
||||
for _, cookie := range resp.Cookies() {
|
||||
if cookie.RawExpires == "" || cookie.Expires.After(time.Now()) {
|
||||
cookies = append(cookies, cookie)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
281
enterprise/cmd/frontend/auth/saml/provider.go
Normal file
281
enterprise/cmd/frontend/auth/saml/provider.go
Normal file
@ -0,0 +1,281 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/xml"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/beevik/etree"
|
||||
"github.com/pkg/errors"
|
||||
saml2 "github.com/russellhaering/gosaml2"
|
||||
"github.com/russellhaering/gosaml2/types"
|
||||
dsig "github.com/russellhaering/goxmldsig"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
"golang.org/x/net/context/ctxhttp"
|
||||
)
|
||||
|
||||
const providerType = "saml"
|
||||
|
||||
type provider struct {
|
||||
config schema.SAMLAuthProvider
|
||||
multiple bool // whether there are multiple SAML auth providers
|
||||
|
||||
mu sync.Mutex
|
||||
samlSP *saml2.SAMLServiceProvider
|
||||
refreshErr error
|
||||
}
|
||||
|
||||
// ConfigID implements auth.Provider.
|
||||
func (p *provider) ConfigID() auth.ProviderConfigID {
|
||||
return auth.ProviderConfigID{
|
||||
Type: providerType,
|
||||
ID: providerConfigID(&p.config, p.multiple),
|
||||
}
|
||||
}
|
||||
|
||||
// Config implements auth.Provider.
|
||||
func (p *provider) Config() schema.AuthProviders {
|
||||
return schema.AuthProviders{Saml: &p.config}
|
||||
}
|
||||
|
||||
// Refresh implements auth.Provider.
|
||||
func (p *provider) Refresh(ctx context.Context) error {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.samlSP, p.refreshErr = getServiceProvider(ctx, &p.config)
|
||||
return p.refreshErr
|
||||
}
|
||||
|
||||
func providerIDQuery(pc *schema.SAMLAuthProvider, multiple bool) url.Values {
|
||||
if multiple {
|
||||
return url.Values{"pc": []string{providerConfigID(pc, multiple)}}
|
||||
}
|
||||
return url.Values{}
|
||||
}
|
||||
|
||||
func (p *provider) getCachedInfoAndError() (*auth.ProviderInfo, error) {
|
||||
info := auth.ProviderInfo{
|
||||
DisplayName: p.config.DisplayName,
|
||||
AuthenticationURL: (&url.URL{
|
||||
Path: path.Join(auth.AuthURLPrefix, "saml", "login"),
|
||||
RawQuery: providerIDQuery(&p.config, p.multiple).Encode(),
|
||||
}).String(),
|
||||
}
|
||||
if info.DisplayName == "" {
|
||||
info.DisplayName = "SAML"
|
||||
}
|
||||
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
err := p.refreshErr
|
||||
if err != nil {
|
||||
err = errors.WithMessage(err, "failed to initialize SAML Service Provider")
|
||||
} else if p.samlSP == nil {
|
||||
err = errors.New("SAML Service Provider is not yet initialized")
|
||||
}
|
||||
if p.samlSP != nil {
|
||||
info.ServiceID = p.samlSP.IdentityProviderIssuer
|
||||
info.ClientID = p.samlSP.ServiceProviderIssuer
|
||||
}
|
||||
return &info, err
|
||||
}
|
||||
|
||||
// CachedInfo implements auth.Provider.
|
||||
func (p *provider) CachedInfo() *auth.ProviderInfo {
|
||||
info, _ := p.getCachedInfoAndError()
|
||||
return info
|
||||
}
|
||||
|
||||
func getServiceProvider(ctx context.Context, pc *schema.SAMLAuthProvider) (*saml2.SAMLServiceProvider, error) {
|
||||
c, err := readProviderConfig(pc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
idpMetadata, err := readIdentityProviderMetadata(ctx, c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
{
|
||||
if c.identityProviderMetadataURL != nil {
|
||||
traceLog(fmt.Sprintf("Identity Provider metadata: %s", c.identityProviderMetadataURL), string(idpMetadata))
|
||||
}
|
||||
}
|
||||
|
||||
metadata, err := unmarshalEntityDescriptor(idpMetadata)
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "parsing SAML Identity Provider metadata")
|
||||
}
|
||||
|
||||
sp := saml2.SAMLServiceProvider{
|
||||
IdentityProviderSSOURL: metadata.IDPSSODescriptor.SingleSignOnServices[0].Location,
|
||||
IdentityProviderIssuer: metadata.EntityID,
|
||||
NameIdFormat: getNameIDFormat(pc),
|
||||
SkipSignatureValidation: pc.InsecureSkipAssertionSignatureValidation,
|
||||
ValidateEncryptionCert: true,
|
||||
AllowMissingAttributes: true,
|
||||
}
|
||||
|
||||
idpCertStore := &dsig.MemoryX509CertificateStore{Roots: []*x509.Certificate{}}
|
||||
for _, kd := range metadata.IDPSSODescriptor.KeyDescriptors {
|
||||
for i, xcert := range kd.KeyInfo.X509Data.X509Certificates {
|
||||
if xcert.Data == "" {
|
||||
return nil, fmt.Errorf("SAML Identity Provider metadata certificate %d is empty", i)
|
||||
}
|
||||
certData, err := base64.StdEncoding.DecodeString(xcert.Data)
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, fmt.Sprintf("decoding SAML Identity Provider metadata certificate %d", i))
|
||||
}
|
||||
idpCert, err := x509.ParseCertificate(certData)
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, fmt.Sprintf("parsing SAML Identity Provider metadata certificate %d X.509 data", i))
|
||||
}
|
||||
idpCertStore.Roots = append(idpCertStore.Roots, idpCert)
|
||||
}
|
||||
}
|
||||
sp.IDPCertificateStore = idpCertStore
|
||||
|
||||
// The SP's signing and encryption keys.
|
||||
if c.keyPair != nil {
|
||||
sp.SPKeyStore = dsig.TLSCertKeyStore(*c.keyPair)
|
||||
sp.SignAuthnRequests = pc.SignRequests == nil || *pc.SignRequests
|
||||
} else {
|
||||
// If the SP private key isn't specified, then the IdP must not care to validate.
|
||||
if pc.SignRequests != nil && *pc.SignRequests {
|
||||
return nil, errors.New("signRequests is true for SAML Service Provider but no private key and cert are given")
|
||||
}
|
||||
}
|
||||
|
||||
// pc.Issuer's default of ${appURL}/.auth/saml/metadata already applied (in withConfigDefaults).
|
||||
sp.ServiceProviderIssuer = pc.ServiceProviderIssuer
|
||||
if pc.ServiceProviderIssuer == "" {
|
||||
return nil, errors.New("invalid SAML Service Provider configuration: issuer is empty (and default issuer could not be derived from empty appURL)")
|
||||
}
|
||||
appURL, err := url.Parse(conf.Get().AppURL)
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "parsing app URL for SAML Service Provider")
|
||||
}
|
||||
sp.AssertionConsumerServiceURL = appURL.ResolveReference(&url.URL{Path: path.Join(authPrefix, "acs")}).String()
|
||||
sp.AudienceURI = sp.ServiceProviderIssuer
|
||||
|
||||
return &sp, nil
|
||||
}
|
||||
|
||||
// entitiesDescriptor represents the SAML EntitiesDescriptor object.
|
||||
type entitiesDescriptor struct {
|
||||
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntitiesDescriptor"`
|
||||
ID *string `xml:",attr,omitempty"`
|
||||
ValidUntil *time.Time `xml:"validUntil,attr,omitempty"`
|
||||
CacheDuration *time.Duration `xml:"cacheDuration,attr,omitempty"`
|
||||
Name *string `xml:",attr,omitempty"`
|
||||
Signature *etree.Element
|
||||
EntitiesDescriptors []entitiesDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntitiesDescriptor"`
|
||||
EntityDescriptors []types.EntityDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"`
|
||||
}
|
||||
|
||||
// unmarshalEntityDescriptor unmarshals from an XML root <EntityDescriptor> or <EntitiesDescriptor>
|
||||
// element. If the latter, it returns the first <EntityDescriptor> child that has an
|
||||
// IDPSSODescriptor.
|
||||
//
|
||||
// Taken from github.com/crewjam/saml.
|
||||
func unmarshalEntityDescriptor(data []byte) (*types.EntityDescriptor, error) {
|
||||
var entity *types.EntityDescriptor
|
||||
if err := xml.Unmarshal(data, &entity); err != nil {
|
||||
// This comparison is ugly, but it is how the error is generated in encoding/xml.
|
||||
if err.Error() != "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
|
||||
return nil, err
|
||||
}
|
||||
var entities *entitiesDescriptor
|
||||
if err := xml.Unmarshal(data, &entities); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i, e := range entities.EntityDescriptors {
|
||||
if e.IDPSSODescriptor != nil {
|
||||
entity = &entities.EntityDescriptors[i]
|
||||
break
|
||||
}
|
||||
}
|
||||
if entity == nil {
|
||||
return nil, errors.New("no entity found with IDPSSODescriptor")
|
||||
}
|
||||
}
|
||||
return entity, nil
|
||||
}
|
||||
|
||||
type providerConfig struct {
|
||||
keyPair *tls.Certificate
|
||||
|
||||
// Exactly 1 of these is set:
|
||||
identityProviderMetadataURL *url.URL
|
||||
identityProviderMetadata []byte
|
||||
}
|
||||
|
||||
func readProviderConfig(pc *schema.SAMLAuthProvider) (*providerConfig, error) {
|
||||
var c providerConfig
|
||||
|
||||
if pc.ServiceProviderCertificate != "" && pc.ServiceProviderPrivateKey != "" {
|
||||
keyPair, err := tls.X509KeyPair([]byte(pc.ServiceProviderCertificate), []byte(pc.ServiceProviderPrivateKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
c.keyPair = &keyPair
|
||||
}
|
||||
|
||||
// Allow specifying either URL to SAML Identity Provider metadata XML file, or the XML
|
||||
// file contents directly.
|
||||
switch {
|
||||
case pc.IdentityProviderMetadataURL != "" && pc.IdentityProviderMetadata != "":
|
||||
return nil, errors.New("invalid SAML configuration: set either identityProviderMetadataURL or identityProviderMetadata, not both")
|
||||
|
||||
case pc.IdentityProviderMetadataURL != "":
|
||||
var err error
|
||||
c.identityProviderMetadataURL, err = url.Parse(pc.IdentityProviderMetadataURL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "parsing SAML Identity Provider metadata URL")
|
||||
}
|
||||
|
||||
case pc.IdentityProviderMetadata != "":
|
||||
c.identityProviderMetadata = []byte(pc.IdentityProviderMetadata)
|
||||
|
||||
default:
|
||||
return nil, errors.New("invalid SAML configuration: must provide the SAML metadata, using either identityProviderMetadataURL (URL where XML file is available) or identityProviderMetadata (XML file contents)")
|
||||
}
|
||||
|
||||
return &c, nil
|
||||
}
|
||||
|
||||
func readIdentityProviderMetadata(ctx context.Context, c *providerConfig) ([]byte, error) {
|
||||
if c.identityProviderMetadata != nil {
|
||||
return []byte(c.identityProviderMetadata), nil
|
||||
}
|
||||
|
||||
resp, err := ctxhttp.Get(ctx, nil, c.identityProviderMetadataURL.String())
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "fetching SAML Identity Provider metadata")
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("non-200 HTTP response for SAML Identity Provider metadata URL: %s", c.identityProviderMetadataURL)
|
||||
}
|
||||
|
||||
data, err := ioutil.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "reading SAML Identity Provider metadata")
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
1
enterprise/cmd/frontend/auth/saml/provider_test.go
Normal file
1
enterprise/cmd/frontend/auth/saml/provider_test.go
Normal file
@ -0,0 +1 @@
|
||||
package saml
|
||||
83
enterprise/cmd/frontend/auth/saml/session.go
Normal file
83
enterprise/cmd/frontend/auth/saml/session.go
Normal file
@ -0,0 +1,83 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/beevik/etree"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
// SignOut returns the URL where the user can initiate a logout from the SAML IdentityProvider, if
|
||||
// it has a SingleLogoutService.
|
||||
func SignOut(w http.ResponseWriter, r *http.Request) (logoutURL string, err error) {
|
||||
// TODO(sqs): Only supports a single SAML auth provider.
|
||||
pc, multiple := getFirstProviderConfig()
|
||||
if pc == nil {
|
||||
return "", nil
|
||||
}
|
||||
p := getProvider(providerConfigID(pc, multiple))
|
||||
if p == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
doc, err := newLogoutRequest(p)
|
||||
if err != nil {
|
||||
return "", errors.WithMessage(err, "creating SAML LogoutRequest")
|
||||
}
|
||||
{
|
||||
if data, err := doc.WriteToString(); err == nil {
|
||||
traceLog(fmt.Sprintf("LogoutRequest: %s", p.ConfigID().ID), string(data))
|
||||
}
|
||||
}
|
||||
return p.samlSP.BuildAuthURLRedirect("/", doc)
|
||||
}
|
||||
|
||||
// getFirstProviderConfig returns the SAML auth provider config. At most 1 can be specified in site
|
||||
// config; if there is more than 1, it returns multiple == true (which the caller should handle by
|
||||
// returning an error and refusing to proceed with auth).
|
||||
func getFirstProviderConfig() (pc *schema.SAMLAuthProvider, multiple bool) {
|
||||
for _, p := range conf.Get().AuthProviders {
|
||||
if p.Saml != nil {
|
||||
if pc != nil {
|
||||
return pc, true // multiple SAML auth providers
|
||||
}
|
||||
pc = withConfigDefaults(p.Saml)
|
||||
}
|
||||
}
|
||||
return pc, false
|
||||
}
|
||||
|
||||
func newLogoutRequest(p *provider) (*etree.Document, error) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
if p.samlSP == nil {
|
||||
return nil, errors.New("unable to create SAML LogoutRequest because provider is not yet initialized")
|
||||
}
|
||||
|
||||
// Start with the doc for AuthnRequest and change a few things to make it into a LogoutRequest
|
||||
// doc. This saves us from needing to duplicate a bunch of code.
|
||||
doc, err := p.samlSP.BuildAuthRequestDocumentNoSig()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
root := doc.Root()
|
||||
root.Tag = "LogoutRequest"
|
||||
// TODO(sqs): This assumes SSO URL == SLO URL (i.e., the same endpoint is used for signin and
|
||||
// logout). To fix this, use `root.SelectAttr("Destination").Value = "..."`.
|
||||
if t := root.FindElement("//samlp:NameIDPolicy"); t != nil {
|
||||
root.RemoveChild(t)
|
||||
}
|
||||
|
||||
if p.samlSP.SignAuthnRequests {
|
||||
signed, err := p.samlSP.SignAuthnRequest(root)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
doc.SetRoot(signed)
|
||||
}
|
||||
return doc, nil
|
||||
}
|
||||
124
enterprise/cmd/frontend/auth/saml/user.go
Normal file
124
enterprise/cmd/frontend/auth/saml/user.go
Normal file
@ -0,0 +1,124 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
saml2 "github.com/russellhaering/gosaml2"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/actor"
|
||||
)
|
||||
|
||||
type authnResponseInfo struct {
|
||||
spec db.ExternalAccountSpec
|
||||
email, displayName string
|
||||
unnormalizedUsername string
|
||||
accountData interface{}
|
||||
}
|
||||
|
||||
func readAuthnResponse(p *provider, encodedResp string) (*authnResponseInfo, error) {
|
||||
{
|
||||
if raw, err := base64.StdEncoding.DecodeString(encodedResp); err == nil {
|
||||
traceLog(fmt.Sprintf("AuthnResponse: %s", p.ConfigID().ID), string(raw))
|
||||
}
|
||||
}
|
||||
|
||||
assertions, err := p.samlSP.RetrieveAssertionInfo(encodedResp)
|
||||
if err != nil {
|
||||
return nil, errors.WithMessage(err, "reading AuthnResponse assertions")
|
||||
}
|
||||
if wi := assertions.WarningInfo; wi.InvalidTime || wi.NotInAudience {
|
||||
return nil, fmt.Errorf("invalid SAML AuthnResponse: %+v", wi)
|
||||
}
|
||||
|
||||
pi, err := p.getCachedInfoAndError()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
firstNonempty := func(ss ...string) string {
|
||||
for _, s := range ss {
|
||||
if s := strings.TrimSpace(s); s != "" {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
attr := samlAssertionValues(assertions.Values)
|
||||
email := firstNonempty(attr.Get("email"), attr.Get("emailaddress"))
|
||||
if email == "" && mightBeEmail(assertions.NameID) {
|
||||
email = assertions.NameID
|
||||
}
|
||||
if pn := attr.Get("eduPersonPrincipalName"); email == "" && mightBeEmail(pn) {
|
||||
email = pn
|
||||
}
|
||||
info := authnResponseInfo{
|
||||
spec: db.ExternalAccountSpec{
|
||||
ServiceType: providerType,
|
||||
ServiceID: pi.ServiceID,
|
||||
ClientID: pi.ClientID,
|
||||
AccountID: assertions.NameID,
|
||||
},
|
||||
email: email,
|
||||
unnormalizedUsername: firstNonempty(attr.Get("login"), attr.Get("uid"), email),
|
||||
displayName: firstNonempty(attr.Get("displayName"), attr.Get("givenName")+" "+attr.Get("surname")),
|
||||
accountData: assertions,
|
||||
}
|
||||
if assertions.NameID == "" {
|
||||
return nil, errors.New("the SAML response did not contain a valid NameID")
|
||||
}
|
||||
if info.email == "" {
|
||||
return nil, errors.New("the SAML response did not contain an email attribute")
|
||||
}
|
||||
if info.unnormalizedUsername == "" {
|
||||
return nil, errors.New("the SAML response did not contain a username attribute")
|
||||
}
|
||||
return &info, nil
|
||||
}
|
||||
|
||||
// getOrCreateUser gets or creates a user account based on the SAML claims. It returns the
|
||||
// authenticated actor if successful; otherwise it returns an friendly error message (safeErrMsg)
|
||||
// that is safe to display to users, and a non-nil err with lower-level error details.
|
||||
func getOrCreateUser(ctx context.Context, info *authnResponseInfo) (_ *actor.Actor, safeErrMsg string, err error) {
|
||||
var data db.ExternalAccountData
|
||||
auth.SetExternalAccountData(&data.AccountData, info.accountData)
|
||||
|
||||
username, err := auth.NormalizeUsername(info.unnormalizedUsername)
|
||||
if err != nil {
|
||||
return nil, fmt.Sprintf("Error normalizing the username %q. See https://about.sourcegraph.com/docs/config/authentication#username-normalization.", info.unnormalizedUsername), err
|
||||
}
|
||||
|
||||
userID, safeErrMsg, err := auth.CreateOrUpdateUser(ctx, db.NewUser{
|
||||
Username: username,
|
||||
Email: info.email,
|
||||
EmailIsVerified: info.email != "", // TODO(sqs): https://github.com/sourcegraph/sourcegraph/issues/10118
|
||||
DisplayName: info.displayName,
|
||||
// SAML has no standard way of providing an avatar URL.
|
||||
},
|
||||
info.spec,
|
||||
data,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, safeErrMsg, err
|
||||
}
|
||||
return actor.FromUser(userID), "", nil
|
||||
}
|
||||
|
||||
func mightBeEmail(s string) bool {
|
||||
return strings.Count(s, "@") == 1
|
||||
}
|
||||
|
||||
type samlAssertionValues saml2.Values
|
||||
|
||||
func (v samlAssertionValues) Get(key string) string {
|
||||
for _, a := range v {
|
||||
if a.Name == key || a.FriendlyName == key {
|
||||
return a.Values[0].Value
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
59
enterprise/cmd/frontend/auth/saml/user_test.go
Normal file
59
enterprise/cmd/frontend/auth/saml/user_test.go
Normal file
File diff suppressed because one or more lines are too long
7
enterprise/cmd/frontend/db/db_test.go
Normal file
7
enterprise/cmd/frontend/db/db_test.go
Normal file
@ -0,0 +1,7 @@
|
||||
package db
|
||||
|
||||
import dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
|
||||
|
||||
func init() {
|
||||
dbtesting.DBNameSuffix = "enterprisedb"
|
||||
}
|
||||
527
enterprise/cmd/frontend/db/global_deps.go
Normal file
527
enterprise/cmd/frontend/db/global_deps.go
Normal file
@ -0,0 +1,527 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
opentracing "github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
otlog "github.com/opentracing/opentracing-go/log"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/types"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/api"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/inventory"
|
||||
"github.com/sourcegraph/sourcegraph/xlang"
|
||||
"github.com/sourcegraph/sourcegraph/xlang/lspext"
|
||||
)
|
||||
|
||||
// globalDeps provides access to the `global_dep` table. Each row in
|
||||
// the table represents a dependency relationship from a repository to
|
||||
// a package-manager-level package.
|
||||
//
|
||||
// * The language column is the programming language in which the
|
||||
// dependency occurs (the language of the repository and the package
|
||||
// manager package)
|
||||
// * The dep_data column contains JSON describing the package manager package.
|
||||
// Typically, this includes a name and version field.
|
||||
// * The repo_id column identifies the repository.
|
||||
// * The hints column contains JSON that contains additional hints that can
|
||||
// be used to optimized requests related to the dependency (e.g., which
|
||||
// directory in a repository contains the dependency).
|
||||
//
|
||||
// For a detailed overview of the schema, see schema.txt.
|
||||
type globalDeps struct{}
|
||||
|
||||
func (g *globalDeps) TotalRefs(ctx context.Context, repo *types.Repo, langs []*inventory.Lang) (int, error) {
|
||||
var sum int
|
||||
for _, lang := range langs {
|
||||
switch lang.Name {
|
||||
case inventory.LangGo:
|
||||
for _, expandedSources := range repoURIToGoPathPrefixes(repo.URI) {
|
||||
refs, err := g.doTotalRefsGo(ctx, expandedSources)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "doTotalRefsGo")
|
||||
}
|
||||
sum += refs
|
||||
}
|
||||
case inventory.LangJava:
|
||||
refs, err := g.doTotalRefs(ctx, repo.ID, "java")
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "doTotalRefs")
|
||||
}
|
||||
sum += refs
|
||||
}
|
||||
}
|
||||
return sum, nil
|
||||
}
|
||||
|
||||
// ListTotalRefs is like TotalRefs, except it returns a list of repo IDs
|
||||
// instead of just the length of that list. Obviously, this is less efficient
|
||||
// if you just need the count, however.
|
||||
func (g *globalDeps) ListTotalRefs(ctx context.Context, repo *types.Repo, langs []*inventory.Lang) ([]api.RepoID, error) {
|
||||
var repos []api.RepoID
|
||||
for _, lang := range langs {
|
||||
switch lang.Name {
|
||||
case inventory.LangGo:
|
||||
for _, expandedSources := range repoURIToGoPathPrefixes(repo.URI) {
|
||||
refs, err := g.doListTotalRefsGo(ctx, expandedSources)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "doListTotalRefsGo")
|
||||
}
|
||||
repos = append(repos, refs...)
|
||||
}
|
||||
case inventory.LangJava:
|
||||
refs, err := g.doListTotalRefs(ctx, repo.ID, "java")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "doListTotalRefs")
|
||||
}
|
||||
repos = append(repos, refs...)
|
||||
}
|
||||
}
|
||||
return repos, nil
|
||||
}
|
||||
|
||||
// repoURIToGoPathPrefixes translates a repository URI like
|
||||
// github.com/kubernetes/kubernetes into its _prefix_ matching Go import paths
|
||||
// (e.g. k8s.io/kubernetes). In the case of the standard library,
|
||||
// github.com/golang/go returns all of the Go stdlib package paths. If the
|
||||
// repository URI is not special cased, []string{repoURI} is simply returned.
|
||||
//
|
||||
// TODO(slimsag): In the future, when the pkgs index includes Go repositories,
|
||||
// use that instead of this manual mapping hack.
|
||||
func repoURIToGoPathPrefixes(repoURI api.RepoURI) []string {
|
||||
manualMapping := map[api.RepoURI][]string{
|
||||
// stdlib hack: by returning an empty string (NOT no strings) we end up
|
||||
// with an SQL query like `AND dep_data->>'package' LIKE '%';` which
|
||||
// matches all Go repositories effectively. We do this for the stdlib
|
||||
// because all Go repositories will import the stdlib anyway.
|
||||
"github.com/golang/go": {""},
|
||||
|
||||
// google.golang.org
|
||||
"github.com/grpc/grpc-go": {"google.golang.org/grpc"},
|
||||
"github.com/google/google-api-go-client": {"google.golang.org/api"},
|
||||
"github.com/golang/appengine": {"google.golang.org/appengine"},
|
||||
|
||||
// go4.org
|
||||
"github.com/camlistore/go4": {"go4.org"},
|
||||
|
||||
// At special request of a user, since we don't support custom import
|
||||
// paths generically here yet. See https://github.com/sourcegraph/sourcegraph/issues/12488
|
||||
"github.com/goadesign/goa": {"github.com/goadesign/goa", "goa.design/goa"},
|
||||
}
|
||||
if v, ok := manualMapping[repoURI]; ok {
|
||||
return v
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(string(repoURI), "github.com/azul3d"): // azul3d.org
|
||||
split := strings.Split(string(repoURI), "/")
|
||||
if len(split) >= 3 {
|
||||
return []string{"azul3d.org/" + split[2]}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(string(repoURI), "github.com/dskinner"): // dasa.cc
|
||||
split := strings.Split(string(repoURI), "/")
|
||||
if len(split) >= 3 {
|
||||
return []string{"dasa.cc/" + split[2]}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(string(repoURI), "github.com/kubernetes"): // k8s.io
|
||||
split := strings.Split(string(repoURI), "/")
|
||||
if len(split) >= 3 {
|
||||
return []string{"k8s.io/" + split[2]}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(string(repoURI), "github.com/uber-go"): // go.uber.org
|
||||
split := strings.Split(string(repoURI), "/")
|
||||
if len(split) >= 3 {
|
||||
// Some repos use non-canonical import paths.
|
||||
return []string{
|
||||
string(repoURI),
|
||||
"go.uber.org/" + split[2],
|
||||
}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(string(repoURI), "github.com/dominikh"): // honnef.co
|
||||
split := strings.Split(string(repoURI), "/")
|
||||
if len(split) >= 3 {
|
||||
return []string{"honnef.co/" + strings.Replace(split[2], "-", "/", -1)}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(string(repoURI), "github.com/golang") && repoURI != "github.com/golang/go": // golang.org/x
|
||||
split := strings.Split(string(repoURI), "/")
|
||||
if len(split) >= 3 {
|
||||
return []string{"golang.org/x/" + split[2]}
|
||||
}
|
||||
|
||||
case strings.HasPrefix(string(repoURI), "github.com"): // gopkg.in
|
||||
split := strings.Split(string(repoURI), "/")
|
||||
if len(split) >= 3 && strings.HasPrefix(split[1], "go-") {
|
||||
// Four possibilities
|
||||
return []string{
|
||||
string(repoURI), // github.com/go-foo/foo
|
||||
"gopkg.in/" + strings.TrimPrefix(split[1], "go-"), // gopkg.in/foo
|
||||
"labix.org/v1/" + strings.TrimPrefix(split[1], "go-"), // labix.org/v1/foo
|
||||
"labix.org/v2/" + strings.TrimPrefix(split[1], "go-"), // labix.org/v2/foo
|
||||
}
|
||||
} else if len(split) >= 3 {
|
||||
// Two possibilities
|
||||
return []string{
|
||||
string(repoURI), // github.com/foo/bar
|
||||
"gopkg.in/" + split[1] + "/" + split[2], // gopkg.in/foo/bar
|
||||
}
|
||||
}
|
||||
}
|
||||
return []string{string(repoURI)}
|
||||
}
|
||||
|
||||
// doTotalRefs is the generic implementation of total references, using the `pkgs` table.
|
||||
func (g *globalDeps) doTotalRefs(ctx context.Context, repo api.RepoID, lang string) (sum int, err error) {
|
||||
// Get packages contained in the repo
|
||||
packages, err := (&pkgs{}).ListPackages(ctx, &api.ListPackagesOp{Lang: lang, Limit: 500, RepoID: repo})
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "ListPackages")
|
||||
}
|
||||
if len(packages) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
// Find number of repos that depend on that set of packages
|
||||
var args []interface{}
|
||||
arg := func(a interface{}) string {
|
||||
args = append(args, a)
|
||||
return fmt.Sprintf("$%d", len(args))
|
||||
}
|
||||
var pkgClauses []string
|
||||
for _, pkg := range packages {
|
||||
pkgID, ok := xlang.PackageIdentifier(pkg.Pkg, lang)
|
||||
if !ok {
|
||||
return 0, errors.Wrap(err, "PackageIdentifier")
|
||||
}
|
||||
containmentQuery, err := json.Marshal(pkgID)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "Marshal")
|
||||
}
|
||||
pkgClauses = append(pkgClauses, `dep_data @> `+arg(string(containmentQuery)))
|
||||
}
|
||||
whereSQL := `(language=` + arg(lang) + `) AND ((` + strings.Join(pkgClauses, ") OR (") + `))`
|
||||
sql := `SELECT count(distinct(repo_id))
|
||||
FROM global_dep
|
||||
WHERE ` + whereSQL
|
||||
rows, err := dbconn.Global.QueryContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
var count int
|
||||
for rows.Next() {
|
||||
if err := rows.Scan(&count); err != nil {
|
||||
return 0, errors.Wrap(err, "Scan")
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// doListTotalRefs is the generic implementation of list total references,
|
||||
// using the `pkgs` table.
|
||||
func (g *globalDeps) doListTotalRefs(ctx context.Context, repo api.RepoID, lang string) ([]api.RepoID, error) {
|
||||
// Get packages contained in the repo
|
||||
packages, err := (&pkgs{}).ListPackages(ctx, &api.ListPackagesOp{Lang: lang, Limit: 500, RepoID: repo})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "ListPackages")
|
||||
}
|
||||
if len(packages) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Find all repos that depend on that set of packages
|
||||
var args []interface{}
|
||||
arg := func(a interface{}) string {
|
||||
args = append(args, a)
|
||||
return fmt.Sprintf("$%d", len(args))
|
||||
}
|
||||
var pkgClauses []string
|
||||
for _, pkg := range packages {
|
||||
pkgID, ok := xlang.PackageIdentifier(pkg.Pkg, lang)
|
||||
if !ok {
|
||||
return nil, errors.Wrap(err, "PackageIdentifier")
|
||||
}
|
||||
containmentQuery, err := json.Marshal(pkgID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Marshal")
|
||||
}
|
||||
pkgClauses = append(pkgClauses, `dep_data @> `+arg(string(containmentQuery)))
|
||||
}
|
||||
whereSQL := `(language=` + arg(lang) + `) AND ((` + strings.Join(pkgClauses, ") OR (") + `))`
|
||||
sql := `SELECT distinct(repo_id)
|
||||
FROM global_dep
|
||||
WHERE ` + whereSQL
|
||||
rows, err := dbconn.Global.QueryContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
var repos []api.RepoID
|
||||
for rows.Next() {
|
||||
var repo api.RepoID
|
||||
if err := rows.Scan(&repo); err != nil {
|
||||
return nil, errors.Wrap(err, "Scan")
|
||||
}
|
||||
repos = append(repos, repo)
|
||||
}
|
||||
return repos, nil
|
||||
}
|
||||
|
||||
// doTotalRefsGo is the Go-specific implementation of total references, since we can extract package metadata directly
|
||||
// from Go repository URLs, without going through the `pkgs` table.
|
||||
func (g *globalDeps) doTotalRefsGo(ctx context.Context, source string) (int, error) {
|
||||
// Because global_dep only stores Go package paths, not repository URIs, we
|
||||
// use a simple heuristic here by using `LIKE <repo>%`. This will work for
|
||||
// GitHub package paths (e.g. `github.com/a/b%` matches `github.com/a/b/c`)
|
||||
// but not custom import paths etc.
|
||||
rows, err := dbconn.Global.QueryContext(ctx, `SELECT COUNT(DISTINCT repo_id)
|
||||
FROM global_dep
|
||||
WHERE language='go'
|
||||
AND dep_data->>'depth' = '0'
|
||||
AND ( -- in C locale, this is equivalent to matching "$1/*", but matches much faster
|
||||
(dep_data->>'package' COLLATE "C" < $1 || '0' COLLATE "C" AND dep_data->>'package' COLLATE "C" > $1 || '/' COLLATE "C")
|
||||
OR (dep_data->>'package' COLLATE "C" = $1)
|
||||
);
|
||||
`, source)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
var count int
|
||||
for rows.Next() {
|
||||
err := rows.Scan(&count)
|
||||
if err != nil {
|
||||
return 0, errors.Wrap(err, "Scan")
|
||||
}
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// doListTotalRefsGo is the Go-specific implementation of list total
|
||||
// references, since we can extract package metadata directly from Go
|
||||
// repository URLs, without going through the `pkgs` table.
|
||||
func (g *globalDeps) doListTotalRefsGo(ctx context.Context, source string) ([]api.RepoID, error) {
|
||||
// Because global_dep only stores Go package paths, not repository URIs, we
|
||||
// use a simple heuristic here by using `LIKE <repo>%`. This will work for
|
||||
// GitHub package paths (e.g. `github.com/a/b%` matches `github.com/a/b/c`)
|
||||
// but not custom import paths etc.
|
||||
rows, err := dbconn.Global.QueryContext(ctx, `SELECT DISTINCT repo_id
|
||||
FROM global_dep
|
||||
WHERE language='go'
|
||||
AND dep_data->>'depth' = '0'
|
||||
AND dep_data->>'package' LIKE $1;
|
||||
`, source+"%")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Query")
|
||||
}
|
||||
defer rows.Close()
|
||||
var repos []api.RepoID
|
||||
for rows.Next() {
|
||||
var repo api.RepoID
|
||||
err := rows.Scan(&repo)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "Scan")
|
||||
}
|
||||
repos = append(repos, repo)
|
||||
}
|
||||
return repos, nil
|
||||
}
|
||||
|
||||
func (g *globalDeps) UpdateIndexForLanguage(ctx context.Context, language string, repo api.RepoID, deps []lspext.DependencyReference) (err error) {
|
||||
err = db.Transaction(ctx, dbconn.Global, func(tx *sql.Tx) error {
|
||||
// Update the table.
|
||||
err = g.update(ctx, tx, language, deps, repo)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "update global_dep")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "executing transaction")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *globalDeps) Dependencies(ctx context.Context, op db.DependenciesOptions) (refs []*api.DependencyReference, err error) {
|
||||
if db.Mocks.GlobalDeps.Dependencies != nil {
|
||||
return db.Mocks.GlobalDeps.Dependencies(ctx, op)
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "db.Dependencies")
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ext.Error.Set(span, true)
|
||||
span.SetTag("err", err.Error())
|
||||
}
|
||||
span.Finish()
|
||||
}()
|
||||
span.SetTag("Language", op.Language)
|
||||
span.SetTag("DepData", op.DepData)
|
||||
|
||||
var args []interface{}
|
||||
arg := func(a interface{}) string {
|
||||
args = append(args, a)
|
||||
return fmt.Sprintf("$%d", len(args))
|
||||
}
|
||||
|
||||
var whereConds []string
|
||||
|
||||
if op.Language != "" {
|
||||
whereConds = append(whereConds, `gd.language=`+arg(op.Language))
|
||||
}
|
||||
|
||||
if op.DepData != nil {
|
||||
containmentQuery, err := json.Marshal(op.DepData)
|
||||
if err != nil {
|
||||
return nil, errors.New("marshaling op.DepData query")
|
||||
}
|
||||
whereConds = append(whereConds, `dep_data @> `+arg(string(containmentQuery)))
|
||||
}
|
||||
if op.Repo != 0 {
|
||||
whereConds = append(whereConds, `repo_id = `+arg(op.Repo))
|
||||
}
|
||||
|
||||
selectSQL := `SELECT gd.language, dep_data, repo_id, hints`
|
||||
fromSQL := `FROM global_dep AS gd INNER JOIN repo AS r ON gd.repo_id=r.id`
|
||||
whereSQL := ""
|
||||
if len(whereConds) > 0 {
|
||||
whereSQL = `WHERE ` + strings.Join(whereConds, " AND ")
|
||||
}
|
||||
limitSQL := ""
|
||||
if op.Limit != 0 {
|
||||
limitSQL = `LIMIT ` + arg(op.Limit)
|
||||
}
|
||||
sql := fmt.Sprintf("%s %s %s %s", selectSQL, fromSQL, whereSQL, limitSQL)
|
||||
|
||||
rows, err := dbconn.Global.QueryContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
language, depData, hints string
|
||||
repo api.RepoID
|
||||
)
|
||||
if err := rows.Scan(&language, &depData, &repo, &hints); err != nil {
|
||||
return nil, errors.Wrap(err, "Scan")
|
||||
}
|
||||
r := &api.DependencyReference{
|
||||
RepoID: repo,
|
||||
Language: language,
|
||||
}
|
||||
if err := json.Unmarshal([]byte(depData), &r.DepData); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshaling xdependencies metadata from sql scan")
|
||||
}
|
||||
if err := json.Unmarshal([]byte(hints), &r.Hints); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshaling xdependencies hints from sql scan")
|
||||
}
|
||||
refs = append(refs, r)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, errors.Wrap(err, "rows error")
|
||||
}
|
||||
return refs, nil
|
||||
}
|
||||
|
||||
// updateGlobalDep updates the global_dep table.
|
||||
func (g *globalDeps) update(ctx context.Context, tx *sql.Tx, language string, deps []lspext.DependencyReference, indexRepo api.RepoID) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "updateGlobalDep "+language)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ext.Error.Set(span, true)
|
||||
span.SetTag("err", err.Error())
|
||||
}
|
||||
span.Finish()
|
||||
}()
|
||||
span.SetTag("deps", len(deps))
|
||||
|
||||
// First, create a temporary table.
|
||||
_, err = tx.ExecContext(ctx, `CREATE TEMPORARY TABLE new_global_dep (
|
||||
language text NOT NULL,
|
||||
dep_data jsonb NOT NULL,
|
||||
repo_id integer NOT NULL,
|
||||
hints jsonb
|
||||
) ON COMMIT DROP;`)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create temp table")
|
||||
}
|
||||
span.LogFields(otlog.String("event", "created temp table"))
|
||||
|
||||
// Copy the new deps into the temporary table.
|
||||
copy, err := tx.Prepare(pq.CopyIn("new_global_dep",
|
||||
"language",
|
||||
"dep_data",
|
||||
"repo_id",
|
||||
"hints",
|
||||
))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "prepare copy in")
|
||||
}
|
||||
defer copy.Close()
|
||||
span.LogFields(otlog.String("event", "prepared copy in"))
|
||||
|
||||
for _, r := range deps {
|
||||
data, err := json.Marshal(r.Attributes)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "marshaling xdependency metadata to JSON")
|
||||
}
|
||||
hintsData, err := json.Marshal(r.Hints)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "marshaling xdependency hints to JSON")
|
||||
}
|
||||
|
||||
if _, err := copy.Exec(
|
||||
language, // language
|
||||
string(data), // dep_data
|
||||
indexRepo, // repo_id
|
||||
string(hintsData), // hints
|
||||
); err != nil {
|
||||
return errors.Wrap(err, "executing ref copy")
|
||||
}
|
||||
}
|
||||
span.LogFields(otlog.String("event", "executed all dep copy"))
|
||||
if _, err := copy.Exec(); err != nil {
|
||||
return errors.Wrap(err, "executing copy")
|
||||
}
|
||||
span.LogFields(otlog.String("event", "executed copy"))
|
||||
|
||||
if _, err := tx.ExecContext(ctx, `DELETE FROM global_dep WHERE language=$1 AND repo_id=$2`, language, indexRepo); err != nil {
|
||||
return errors.Wrap(err, "executing table deletion")
|
||||
}
|
||||
span.LogFields(otlog.String("event", "executed table deletion"))
|
||||
|
||||
// Insert from temporary table into the real table.
|
||||
_, err = tx.ExecContext(ctx, `INSERT INTO global_dep(
|
||||
language,
|
||||
dep_data,
|
||||
repo_id,
|
||||
hints
|
||||
) SELECT d.language,
|
||||
d.dep_data,
|
||||
d.repo_id,
|
||||
d.hints
|
||||
FROM new_global_dep d;`)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "executing final insertion from temp table")
|
||||
}
|
||||
span.LogFields(otlog.String("event", "executed final insertion from temp table"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (g *globalDeps) Delete(ctx context.Context, repo api.RepoID) error {
|
||||
_, err := dbconn.Global.ExecContext(ctx, `DELETE FROM global_dep WHERE repo_id=$1`, repo)
|
||||
return err
|
||||
}
|
||||
345
enterprise/cmd/frontend/db/global_deps_test.go
Normal file
345
enterprise/cmd/frontend/db/global_deps_test.go
Normal file
@ -0,0 +1,345 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/api"
|
||||
"github.com/sourcegraph/sourcegraph/xlang/lspext"
|
||||
)
|
||||
|
||||
func TestGlobalDeps_TotalRefsExpansion(t *testing.T) {
|
||||
tests := map[api.RepoURI][]string{
|
||||
// azul3d.org
|
||||
"github.com/azul3d/engine": {"azul3d.org/engine"},
|
||||
|
||||
// dasa.cc
|
||||
"github.com/dskinner/ztext": {"dasa.cc/ztext"},
|
||||
|
||||
// k8s.io
|
||||
"github.com/kubernetes/kubernetes": {"k8s.io/kubernetes"},
|
||||
"github.com/kubernetes/apimachinery": {"k8s.io/apimachinery"},
|
||||
"github.com/kubernetes/client-go": {"k8s.io/client-go"},
|
||||
"github.com/kubernetes/heapster": {"k8s.io/heapster"},
|
||||
|
||||
// golang.org/x
|
||||
"github.com/golang/net": {"golang.org/x/net"},
|
||||
"github.com/golang/tools": {"golang.org/x/tools"},
|
||||
"github.com/golang/oauth2": {"golang.org/x/oauth2"},
|
||||
"github.com/golang/crypto": {"golang.org/x/crypto"},
|
||||
"github.com/golang/sys": {"golang.org/x/sys"},
|
||||
"github.com/golang/text": {"golang.org/x/text"},
|
||||
"github.com/golang/image": {"golang.org/x/image"},
|
||||
"github.com/golang/mobile": {"golang.org/x/mobile"},
|
||||
|
||||
// google.golang.org
|
||||
"github.com/grpc/grpc-go": {"google.golang.org/grpc"},
|
||||
"github.com/google/google-api-go-client": {"google.golang.org/api"},
|
||||
"github.com/golang/appengine": {"google.golang.org/appengine"},
|
||||
|
||||
// go.uber.org
|
||||
"github.com/uber-go/yarpc": {"github.com/uber-go/yarpc", "go.uber.org/yarpc"},
|
||||
"github.com/uber-go/thriftrw": {"github.com/uber-go/thriftrw", "go.uber.org/thriftrw"},
|
||||
"github.com/uber-go/zap": {"github.com/uber-go/zap", "go.uber.org/zap"},
|
||||
"github.com/uber-go/atomic": {"github.com/uber-go/atomic", "go.uber.org/atomic"},
|
||||
"github.com/uber-go/fx": {"github.com/uber-go/fx", "go.uber.org/fx"},
|
||||
|
||||
// go4.org
|
||||
"github.com/camlistore/go4": {"go4.org"},
|
||||
|
||||
// honnef.co
|
||||
"github.com/dominikh/go-staticcheck": {"honnef.co/go/staticcheck"},
|
||||
"github.com/dominikh/go-js-dom": {"honnef.co/go/js/dom"},
|
||||
"github.com/dominikh/go-ssa": {"honnef.co/go/ssa"},
|
||||
|
||||
// gopkg.in
|
||||
"github.com/go-mgo/mgo": {"github.com/go-mgo/mgo", "gopkg.in/mgo", "labix.org/v1/mgo", "labix.org/v2/mgo"},
|
||||
"github.com/go-yaml/yaml": {"github.com/go-yaml/yaml", "gopkg.in/yaml", "labix.org/v1/yaml", "labix.org/v2/yaml"},
|
||||
"github.com/fatih/set": {"github.com/fatih/set", "gopkg.in/fatih/set"},
|
||||
"github.com/juju/environschema": {"github.com/juju/environschema", "gopkg.in/juju/environschema"},
|
||||
}
|
||||
for input, want := range tests {
|
||||
got := repoURIToGoPathPrefixes(input)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Errorf("got %q want %q", got, want)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func TestGlobalDeps_update_delete(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
if err := db.Repos.Upsert(ctx, api.InsertRepoOp{URI: "myrepo", Description: "", Fork: false, Enabled: true}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rp, err := db.Repos.GetByURI(ctx, "myrepo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
repo := rp.ID
|
||||
|
||||
inputRefs := []lspext.DependencyReference{{
|
||||
Attributes: map[string]interface{}{"name": "dep1", "vendor": true},
|
||||
}}
|
||||
if err := GlobalDeps.UpdateIndexForLanguage(ctx, "go", repo, inputRefs); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Log("update")
|
||||
wantRefs := []*api.DependencyReference{{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "dep1", "vendor": true},
|
||||
RepoID: repo,
|
||||
}}
|
||||
gotRefs, err := GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "dep1"},
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sort.Sort(sortDepRefs(wantRefs))
|
||||
sort.Sort(sortDepRefs(gotRefs))
|
||||
if !reflect.DeepEqual(gotRefs, wantRefs) {
|
||||
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
|
||||
}
|
||||
|
||||
t.Log("delete other")
|
||||
if err := GlobalDeps.Delete(ctx, 345345345); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gotRefs, err = GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "dep1"},
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sort.Sort(sortDepRefs(wantRefs))
|
||||
sort.Sort(sortDepRefs(gotRefs))
|
||||
if !reflect.DeepEqual(gotRefs, wantRefs) {
|
||||
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
|
||||
}
|
||||
|
||||
t.Log("delete")
|
||||
if err := GlobalDeps.Delete(ctx, repo); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gotRefs, err = GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "dep1"},
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(gotRefs) > 0 {
|
||||
t.Errorf("expected no matching refs, got %+v", gotRefs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalDeps_RefreshIndex(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
if err := db.Repos.Upsert(ctx, api.InsertRepoOp{URI: "myrepo", Description: "", Fork: false, Enabled: true}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
repo, err := db.Repos.GetByURI(ctx, "myrepo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := GlobalDeps.UpdateIndexForLanguage(ctx, "go", repo.ID, []lspext.DependencyReference{{
|
||||
Attributes: map[string]interface{}{
|
||||
"name": "github.com/gorilla/dep",
|
||||
"vendor": true,
|
||||
},
|
||||
}}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
wantRefs := []*api.DependencyReference{{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "github.com/gorilla/dep", "vendor": true},
|
||||
RepoID: repo.ID,
|
||||
}}
|
||||
gotRefs, err := GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "github.com/gorilla/dep"},
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sort.Sort(sortDepRefs(wantRefs))
|
||||
sort.Sort(sortDepRefs(gotRefs))
|
||||
if !reflect.DeepEqual(gotRefs, wantRefs) {
|
||||
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGlobalDeps_Dependencies(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
repos := make([]api.RepoID, 5)
|
||||
for i := 0; i < 5; i++ {
|
||||
uri := api.RepoURI(fmt.Sprintf("myrepo-%d", i))
|
||||
if err := db.Repos.Upsert(ctx, api.InsertRepoOp{URI: uri, Description: "", Fork: false, Enabled: true}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rp, err := db.Repos.GetByURI(ctx, uri)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
repos[i] = rp.ID
|
||||
}
|
||||
|
||||
inputRefs := map[api.RepoID][]lspext.DependencyReference{
|
||||
repos[0]: {{Attributes: map[string]interface{}{"name": "github.com/gorilla/dep2", "vendor": true}}},
|
||||
repos[1]: {{Attributes: map[string]interface{}{"name": "github.com/gorilla/dep3", "vendor": true}}},
|
||||
repos[2]: {{Attributes: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true}}},
|
||||
repos[3]: {{Attributes: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true}}},
|
||||
repos[4]: {{Attributes: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true}}},
|
||||
}
|
||||
for rp, deps := range inputRefs {
|
||||
err := GlobalDeps.UpdateIndexForLanguage(ctx, "go", rp, deps)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
{ // Test case 1
|
||||
wantRefs := []*api.DependencyReference{{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "github.com/gorilla/dep2", "vendor": true},
|
||||
RepoID: repos[0],
|
||||
}}
|
||||
gotRefs, err := GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "github.com/gorilla/dep2"},
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sort.Sort(sortDepRefs(wantRefs))
|
||||
sort.Sort(sortDepRefs(gotRefs))
|
||||
if !reflect.DeepEqual(gotRefs, wantRefs) {
|
||||
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
|
||||
}
|
||||
}
|
||||
{ // Test case 2
|
||||
wantRefs := []*api.DependencyReference{{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "github.com/gorilla/dep3", "vendor": true},
|
||||
RepoID: repos[1],
|
||||
}}
|
||||
gotRefs, err := GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "github.com/gorilla/dep3"},
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sort.Sort(sortDepRefs(wantRefs))
|
||||
sort.Sort(sortDepRefs(gotRefs))
|
||||
if !reflect.DeepEqual(gotRefs, wantRefs) {
|
||||
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
|
||||
}
|
||||
}
|
||||
{ // Test case 3
|
||||
wantRefs := []*api.DependencyReference{{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true},
|
||||
RepoID: repos[2],
|
||||
}, {
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true},
|
||||
RepoID: repos[3],
|
||||
},
|
||||
{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true},
|
||||
RepoID: repos[4],
|
||||
},
|
||||
}
|
||||
gotRefs, err := GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
|
||||
Language: "go",
|
||||
DepData: map[string]interface{}{"name": "github.com/gorilla/dep4"},
|
||||
Limit: 20,
|
||||
})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
sort.Sort(sortDepRefs(wantRefs))
|
||||
sort.Sort(sortDepRefs(gotRefs))
|
||||
if !reflect.DeepEqual(gotRefs, wantRefs) {
|
||||
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type sortDepRefs []*api.DependencyReference
|
||||
|
||||
func (s sortDepRefs) Len() int { return len(s) }
|
||||
|
||||
func (s sortDepRefs) Swap(a, b int) { s[a], s[b] = s[b], s[a] }
|
||||
|
||||
func (s sortDepRefs) Less(a, b int) bool {
|
||||
if s[a].RepoID != s[b].RepoID {
|
||||
return s[a].RepoID < s[b].RepoID
|
||||
}
|
||||
if !reflect.DeepEqual(s[a].DepData, s[b].DepData) {
|
||||
return stringMapLess(s[a].DepData, s[b].DepData)
|
||||
}
|
||||
return stringMapLess(s[a].Hints, s[b].Hints)
|
||||
}
|
||||
|
||||
func stringMapLess(a, b map[string]interface{}) bool {
|
||||
if len(a) != len(b) {
|
||||
return len(a) < len(b)
|
||||
}
|
||||
ak := make([]string, 0, len(a))
|
||||
for k := range a {
|
||||
ak = append(ak, k)
|
||||
}
|
||||
bk := make([]string, 0, len(b))
|
||||
for k := range b {
|
||||
bk = append(bk, k)
|
||||
}
|
||||
sort.Strings(ak)
|
||||
sort.Strings(bk)
|
||||
for i := range ak {
|
||||
if ak[i] != bk[i] {
|
||||
return ak[i] < bk[i]
|
||||
}
|
||||
// This does not consistentlbk order the output, but in the
|
||||
// cases we use this it will since it is just a simple value
|
||||
// like bool or string
|
||||
av, _ := json.Marshal(a[ak[i]])
|
||||
bv, _ := json.Marshal(b[bk[i]])
|
||||
if bytes.Equal(av, bv) {
|
||||
return string(av) < string(bv)
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
212
enterprise/cmd/frontend/db/pkgs.go
Normal file
212
enterprise/cmd/frontend/db/pkgs.go
Normal file
@ -0,0 +1,212 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/lib/pq"
|
||||
opentracing "github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
otlog "github.com/opentracing/opentracing-go/log"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/api"
|
||||
"github.com/sourcegraph/sourcegraph/xlang/lspext"
|
||||
)
|
||||
|
||||
// pkgs provides access to the `pkgs` table.
|
||||
//
|
||||
// For a detailed overview of the schema, see schema.txt.
|
||||
type pkgs struct{}
|
||||
|
||||
func (p *pkgs) UpdateIndexForLanguage(ctx context.Context, language string, repo api.RepoID, pks []lspext.PackageInformation) (err error) {
|
||||
err = db.Transaction(ctx, dbconn.Global, func(tx *sql.Tx) error {
|
||||
// Update the pkgs table.
|
||||
err = p.update(ctx, tx, repo, language, pks)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "pkgs.update")
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "executing transaction")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbQueryer interface {
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
}
|
||||
|
||||
func (p *pkgs) update(ctx context.Context, tx *sql.Tx, indexRepo api.RepoID, language string, pks []lspext.PackageInformation) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "pkgs.update "+language)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ext.Error.Set(span, true)
|
||||
span.SetTag("err", err.Error())
|
||||
}
|
||||
span.Finish()
|
||||
}()
|
||||
span.SetTag("pkgs", len(pks))
|
||||
|
||||
// First, create a temporary table.
|
||||
_, err = tx.ExecContext(ctx, `CREATE TEMPORARY TABLE new_pkgs (
|
||||
pkg jsonb NOT NULL,
|
||||
language text NOT NULL,
|
||||
repo_id integer NOT NULL
|
||||
) ON COMMIT DROP;`)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "create temp table")
|
||||
}
|
||||
span.LogFields(otlog.String("event", "created temp table"))
|
||||
|
||||
// Copy the new deps into the temporary table.
|
||||
copy, err := tx.Prepare(pq.CopyIn("new_pkgs",
|
||||
"repo_id",
|
||||
"language",
|
||||
"pkg",
|
||||
))
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "prepare copy in")
|
||||
}
|
||||
defer copy.Close()
|
||||
span.LogFields(otlog.String("event", "prepared copy in"))
|
||||
|
||||
for _, r := range pks {
|
||||
pkgData, err := json.Marshal(r.Package)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "marshaling package metadata to JSON")
|
||||
}
|
||||
|
||||
if _, err := copy.Exec(
|
||||
indexRepo, // repo_id
|
||||
language, // language
|
||||
string(pkgData), // pkg
|
||||
); err != nil {
|
||||
return errors.Wrap(err, "executing pkg copy")
|
||||
}
|
||||
}
|
||||
span.LogFields(otlog.String("event", "executed all pkg copy"))
|
||||
if _, err := copy.Exec(); err != nil {
|
||||
return errors.Wrap(err, "executing copy")
|
||||
}
|
||||
span.LogFields(otlog.String("event", "executed copy"))
|
||||
|
||||
if _, err := tx.ExecContext(ctx, `DELETE FROM pkgs WHERE language=$1 AND repo_id=$2`, language, indexRepo); err != nil {
|
||||
return errors.Wrap(err, "executing pkgs deletion")
|
||||
}
|
||||
span.LogFields(otlog.String("event", "executed pkgs deletion"))
|
||||
|
||||
// Insert from temporary table into the real table.
|
||||
_, err = tx.ExecContext(ctx, `INSERT INTO pkgs(
|
||||
repo_id,
|
||||
language,
|
||||
pkg
|
||||
)
|
||||
SELECT p.repo_id,
|
||||
p.language,
|
||||
p.pkg
|
||||
FROM new_pkgs p;
|
||||
`)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "executing final insertion from temp table")
|
||||
}
|
||||
span.LogFields(otlog.String("event", "executed final insertion from temp table"))
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *pkgs) ListPackages(ctx context.Context, op *api.ListPackagesOp) (pks []*api.PackageInfo, err error) {
|
||||
if db.Mocks.Pkgs.ListPackages != nil {
|
||||
return db.Mocks.Pkgs.ListPackages(ctx, op)
|
||||
}
|
||||
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "pkgs.ListPackages")
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ext.Error.Set(span, true)
|
||||
span.SetTag("err", err.Error())
|
||||
}
|
||||
span.Finish()
|
||||
}()
|
||||
span.SetTag("Lang", op.Lang)
|
||||
span.SetTag("PkgQuery", op.PkgQuery)
|
||||
|
||||
var args []interface{}
|
||||
arg := func(a interface{}) string {
|
||||
args = append(args, a)
|
||||
return fmt.Sprintf("$%d", len(args))
|
||||
}
|
||||
|
||||
var whereClauses []string
|
||||
if op.PkgQuery != nil {
|
||||
containmentQuery, err := json.Marshal(op.PkgQuery)
|
||||
if err != nil {
|
||||
return nil, errors.New("marshaling op.PkgQuery")
|
||||
}
|
||||
whereClauses = append(whereClauses, `pkgs.pkg @> `+arg(string(containmentQuery)))
|
||||
}
|
||||
if op.RepoID != 0 {
|
||||
whereClauses = append(whereClauses, `repo_id=`+arg(op.RepoID))
|
||||
}
|
||||
if op.Lang != "" {
|
||||
whereClauses = append(whereClauses, `pkgs.language=`+arg(op.Lang))
|
||||
}
|
||||
if len(whereClauses) == 0 {
|
||||
return nil, fmt.Errorf("no filtering options specified, must specify at least one")
|
||||
}
|
||||
whereSQL := "(" + strings.Join(whereClauses, ") AND (") + ")"
|
||||
sql := `
|
||||
SELECT pkgs.*
|
||||
FROM pkgs INNER JOIN repo ON pkgs.repo_id=repo.id
|
||||
WHERE ` + whereSQL + `
|
||||
ORDER BY repo.created_at ASC NULLS LAST, pkgs.repo_id ASC
|
||||
LIMIT ` + arg(op.Limit)
|
||||
rows, err := dbconn.Global.QueryContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var rawPkgs []*api.PackageInfo
|
||||
for rows.Next() {
|
||||
var (
|
||||
pkg, lang string
|
||||
repo api.RepoID
|
||||
)
|
||||
if err := rows.Scan(&repo, &lang, &pkg); err != nil {
|
||||
return nil, errors.Wrap(err, "Scan")
|
||||
}
|
||||
r := api.PackageInfo{
|
||||
RepoID: repo,
|
||||
Lang: lang,
|
||||
// NOTE: Dependency info (in api.PackageInfo's Dependencies field) is not set
|
||||
// here because it is stored separately in the global_dep table in a way that
|
||||
// is slow and difficult to get in this code path. Currently callers that use
|
||||
// DB-persisted package info do not need the dependency info, so this is
|
||||
// acceptable.
|
||||
}
|
||||
if err := json.Unmarshal([]byte(pkg), &r.Pkg); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshaling xdependencies metadata from sql scan")
|
||||
}
|
||||
rawPkgs = append(rawPkgs, &r)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, errors.Wrap(err, "rows error")
|
||||
}
|
||||
|
||||
return rawPkgs, nil
|
||||
}
|
||||
|
||||
func (p *pkgs) Delete(ctx context.Context, repo api.RepoID) error {
|
||||
if db.Mocks.Pkgs.Delete != nil {
|
||||
return db.Mocks.Pkgs.Delete(ctx, repo)
|
||||
}
|
||||
|
||||
_, err := dbconn.Global.ExecContext(ctx, `DELETE FROM pkgs WHERE repo_id=$1`, repo)
|
||||
return err
|
||||
}
|
||||
303
enterprise/cmd/frontend/db/pkgs_test.go
Normal file
303
enterprise/cmd/frontend/db/pkgs_test.go
Normal file
@ -0,0 +1,303 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
|
||||
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/api"
|
||||
"github.com/sourcegraph/sourcegraph/xlang/lspext"
|
||||
)
|
||||
|
||||
func TestPkgs_update_delete(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
if err := db.Repos.Upsert(ctx, api.InsertRepoOp{URI: "myrepo", Description: "", Fork: false, Enabled: true}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rp, err := db.Repos.GetByURI(ctx, "myrepo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pks := []lspext.PackageInformation{{
|
||||
Package: map[string]interface{}{"name": "pkg"},
|
||||
Dependencies: []lspext.DependencyReference{{
|
||||
Attributes: map[string]interface{}{"name": "dep1"},
|
||||
}},
|
||||
}}
|
||||
|
||||
t.Log("update")
|
||||
if err := db.Transaction(ctx, dbconn.Global, func(tx *sql.Tx) error {
|
||||
if err := Pkgs.update(ctx, tx, rp.ID, "go", pks); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
expPkgs := []*api.PackageInfo{{
|
||||
RepoID: rp.ID,
|
||||
Lang: "go",
|
||||
Pkg: map[string]interface{}{"name": "pkg"},
|
||||
}}
|
||||
gotPkgs, err := Pkgs.getAll(ctx, dbconn.Global)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotPkgs, expPkgs) {
|
||||
t.Errorf("got %+v, expected %+v", gotPkgs, expPkgs)
|
||||
}
|
||||
|
||||
t.Log("delete nothing")
|
||||
if err := Pkgs.Delete(ctx, 0); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gotPkgs, err = Pkgs.getAll(ctx, dbconn.Global)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotPkgs, expPkgs) {
|
||||
t.Errorf("got %+v, expected %+v", gotPkgs, expPkgs)
|
||||
}
|
||||
|
||||
t.Log("delete")
|
||||
if err := Pkgs.Delete(ctx, 1); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
gotPkgs, err = Pkgs.getAll(ctx, dbconn.Global)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if len(gotPkgs) > 0 {
|
||||
t.Errorf("expected all pkgs corresponding to repo %d deleted, but got %+v", rp.ID, gotPkgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPkgs_RefreshIndex(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
if err := db.Repos.Upsert(ctx, api.InsertRepoOp{URI: "myrepo", Description: "", Fork: false, Enabled: true}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
rp, err := db.Repos.GetByURI(ctx, "myrepo")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if err := Pkgs.UpdateIndexForLanguage(ctx, "typescript", rp.ID, []lspext.PackageInformation{
|
||||
{
|
||||
Package: map[string]interface{}{
|
||||
"name": "tspkg",
|
||||
"version": "2.2.2",
|
||||
},
|
||||
Dependencies: []lspext.DependencyReference{},
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
expPkgs := []*api.PackageInfo{{
|
||||
RepoID: rp.ID,
|
||||
Lang: "typescript",
|
||||
Pkg: map[string]interface{}{
|
||||
"name": "tspkg",
|
||||
"version": "2.2.2",
|
||||
},
|
||||
}}
|
||||
gotPkgs, err := Pkgs.getAll(ctx, dbconn.Global)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotPkgs, expPkgs) {
|
||||
t.Errorf("got %+v, expected %+v", gotPkgs, expPkgs)
|
||||
}
|
||||
}
|
||||
|
||||
func TestPkgs_ListPackages(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
repoToPkgs := map[api.RepoID][]lspext.PackageInformation{
|
||||
1: {{
|
||||
Package: map[string]interface{}{"name": "pkg1", "version": "1.1.1"},
|
||||
Dependencies: []lspext.DependencyReference{{
|
||||
Attributes: map[string]interface{}{"name": "pkg1-dep", "version": "1.1.2"},
|
||||
}},
|
||||
}},
|
||||
2: {{
|
||||
Package: map[string]interface{}{"name": "pkg2", "version": "2.2.1"},
|
||||
Dependencies: []lspext.DependencyReference{{
|
||||
Attributes: map[string]interface{}{"name": "pkg2-dep", "version": "2.2.2"},
|
||||
}},
|
||||
}},
|
||||
3: {{Package: map[string]interface{}{"name": "pkg3", "version": "3.3.1"}}},
|
||||
4: {{Package: map[string]interface{}{"name": "pkg3", "version": "3.3.1"}}},
|
||||
5: {{Package: map[string]interface{}{"name": "pkg3", "version": "3.3.1"}}},
|
||||
}
|
||||
|
||||
createdAt := time.Now()
|
||||
for repo, pks := range repoToPkgs {
|
||||
if err := db.Transaction(ctx, dbconn.Global, func(tx *sql.Tx) error {
|
||||
if _, err := tx.ExecContext(ctx, `INSERT INTO repo(id, uri, created_at) VALUES ($1, $2, $3)`, repo, strconv.Itoa(int(repo)), createdAt); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := Pkgs.update(ctx, tx, repo, "go", pks); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
{ // Test case 1
|
||||
expPkgInfo := []*api.PackageInfo{{
|
||||
RepoID: 1,
|
||||
Lang: "go",
|
||||
Pkg: map[string]interface{}{"name": "pkg1", "version": "1.1.1"},
|
||||
}}
|
||||
op := &api.ListPackagesOp{
|
||||
Lang: "go",
|
||||
PkgQuery: map[string]interface{}{"name": "pkg1"},
|
||||
Limit: 10,
|
||||
}
|
||||
gotPkgInfo, err := Pkgs.ListPackages(ctx, op)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotPkgInfo, expPkgInfo) {
|
||||
t.Errorf("got %+v, expected %+v", gotPkgInfo, expPkgInfo)
|
||||
}
|
||||
}
|
||||
{ // Test case 2
|
||||
expPkgInfo := []*api.PackageInfo{{
|
||||
RepoID: 1,
|
||||
Lang: "go",
|
||||
Pkg: map[string]interface{}{"name": "pkg1", "version": "1.1.1"},
|
||||
}}
|
||||
op := &api.ListPackagesOp{
|
||||
Lang: "go",
|
||||
PkgQuery: map[string]interface{}{"name": "pkg1", "version": "1.1.1"},
|
||||
Limit: 10,
|
||||
}
|
||||
gotPkgInfo, err := Pkgs.ListPackages(ctx, op)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotPkgInfo, expPkgInfo) {
|
||||
t.Errorf("got %+v, expected %+v", gotPkgInfo, expPkgInfo)
|
||||
}
|
||||
}
|
||||
{ // Test case 3
|
||||
var expPkgInfo []*api.PackageInfo
|
||||
op := &api.ListPackagesOp{
|
||||
Lang: "go",
|
||||
PkgQuery: map[string]interface{}{"name": "pkg1", "version": "2"},
|
||||
Limit: 10,
|
||||
}
|
||||
gotPkgInfo, err := Pkgs.ListPackages(ctx, op)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotPkgInfo, expPkgInfo) {
|
||||
t.Errorf("got %+v, expected %+v", gotPkgInfo, expPkgInfo)
|
||||
}
|
||||
}
|
||||
{ // Test case 4
|
||||
expPkgInfo := []*api.PackageInfo{{
|
||||
RepoID: 3,
|
||||
Lang: "go",
|
||||
Pkg: map[string]interface{}{"name": "pkg3", "version": "3.3.1"},
|
||||
}, {
|
||||
RepoID: 4,
|
||||
Lang: "go",
|
||||
Pkg: map[string]interface{}{"name": "pkg3", "version": "3.3.1"},
|
||||
},
|
||||
{
|
||||
RepoID: 5,
|
||||
Lang: "go",
|
||||
Pkg: map[string]interface{}{"name": "pkg3", "version": "3.3.1"},
|
||||
},
|
||||
}
|
||||
op := &api.ListPackagesOp{
|
||||
Lang: "go",
|
||||
PkgQuery: map[string]interface{}{"name": "pkg3"},
|
||||
Limit: 10,
|
||||
}
|
||||
gotPkgInfo, err := Pkgs.ListPackages(ctx, op)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotPkgInfo, expPkgInfo) {
|
||||
t.Errorf("got %+v, expected %+v", gotPkgInfo, expPkgInfo)
|
||||
}
|
||||
}
|
||||
{ // Test case 5, filter by repo ID
|
||||
expPkgInfo := []*api.PackageInfo{{
|
||||
RepoID: 3,
|
||||
Lang: "go",
|
||||
Pkg: map[string]interface{}{"name": "pkg3", "version": "3.3.1"},
|
||||
}}
|
||||
op := &api.ListPackagesOp{
|
||||
Lang: "go",
|
||||
RepoID: 3,
|
||||
Limit: 10,
|
||||
}
|
||||
gotPkgInfo, err := Pkgs.ListPackages(ctx, op)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(gotPkgInfo, expPkgInfo) {
|
||||
t.Errorf("got %+v, expected %+v", gotPkgInfo, expPkgInfo)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (p *pkgs) getAll(ctx context.Context, db dbQueryer) (packages []*api.PackageInfo, err error) {
|
||||
rows, err := db.QueryContext(ctx, "SELECT * FROM pkgs ORDER BY language ASC")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "query")
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var (
|
||||
repo api.RepoID
|
||||
language string
|
||||
pkg string
|
||||
)
|
||||
if err := rows.Scan(&repo, &language, &pkg); err != nil {
|
||||
return nil, errors.Wrap(err, "Scan")
|
||||
}
|
||||
p := api.PackageInfo{
|
||||
RepoID: repo,
|
||||
Lang: language,
|
||||
}
|
||||
if err := json.Unmarshal([]byte(pkg), &p.Pkg); err != nil {
|
||||
return nil, errors.Wrap(err, "unmarshaling package metadata from SQL scan")
|
||||
}
|
||||
packages = append(packages, &p)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, errors.Wrap(err, "rows error")
|
||||
}
|
||||
return packages, nil
|
||||
}
|
||||
15
enterprise/cmd/frontend/db/register.go
Normal file
15
enterprise/cmd/frontend/db/register.go
Normal file
@ -0,0 +1,15 @@
|
||||
package db
|
||||
|
||||
import (
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
)
|
||||
|
||||
var (
|
||||
Pkgs = &pkgs{}
|
||||
GlobalDeps = &globalDeps{}
|
||||
)
|
||||
|
||||
func init() {
|
||||
db.Pkgs = Pkgs
|
||||
db.GlobalDeps = GlobalDeps
|
||||
}
|
||||
13
enterprise/cmd/frontend/internal/assets/assets_dev.go
Normal file
13
enterprise/cmd/frontend/internal/assets/assets_dev.go
Normal file
@ -0,0 +1,13 @@
|
||||
// +build !dist
|
||||
|
||||
package assets
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/assets"
|
||||
)
|
||||
|
||||
func init() {
|
||||
assets.Assets = http.Dir("./ui/assets")
|
||||
}
|
||||
9
enterprise/cmd/frontend/internal/assets/assets_dist.go
Normal file
9
enterprise/cmd/frontend/internal/assets/assets_dist.go
Normal file
@ -0,0 +1,9 @@
|
||||
// +build dist
|
||||
|
||||
package assets
|
||||
|
||||
import "github.com/sourcegraph/sourcegraph/cmd/frontend/assets"
|
||||
|
||||
func init() {
|
||||
assets.Assets = DistAssets
|
||||
}
|
||||
22
enterprise/cmd/frontend/internal/assets/assets_generate.go
Normal file
22
enterprise/cmd/frontend/internal/assets/assets_generate.go
Normal file
@ -0,0 +1,22 @@
|
||||
// +build generate
|
||||
|
||||
package main
|
||||
|
||||
import (
|
||||
"log"
|
||||
"net/http"
|
||||
|
||||
"github.com/shurcooL/vfsgen"
|
||||
)
|
||||
|
||||
func main() {
|
||||
dir := "../../../../ui/assets/"
|
||||
err := vfsgen.Generate(http.Dir(dir), vfsgen.Options{
|
||||
PackageName: "assets",
|
||||
BuildTags: "dist",
|
||||
VariableName: "DistAssets",
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalln(err)
|
||||
}
|
||||
}
|
||||
3
enterprise/cmd/frontend/internal/assets/doc.go
Normal file
3
enterprise/cmd/frontend/internal/assets/doc.go
Normal file
@ -0,0 +1,3 @@
|
||||
// Package assets contains static assets for the enterprise front-end Web app. It should be imported
|
||||
// for side-effects.
|
||||
package assets
|
||||
3
enterprise/cmd/frontend/internal/assets/gen.go
Normal file
3
enterprise/cmd/frontend/internal/assets/gen.go
Normal file
@ -0,0 +1,3 @@
|
||||
//go:generate go run assets_generate.go
|
||||
|
||||
package assets
|
||||
165
enterprise/cmd/frontend/internal/dotcom/billing/customers.go
Normal file
165
enterprise/cmd/frontend/internal/dotcom/billing/customers.go
Normal file
@ -0,0 +1,165 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
multierror "github.com/hashicorp/go-multierror"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
"github.com/stripe/stripe-go/customer"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
// GetOrAssignUserCustomerID returns the billing customer ID associated with the user. If no billing
|
||||
// customer ID exists for the user, a new one is created and saved on the user's DB record.
|
||||
func GetOrAssignUserCustomerID(ctx context.Context, userID int32) (_ string, err error) {
|
||||
// Wrap this operation in a transaction so we never have stored 2 auto-created billing customer
|
||||
// IDs for the same user.
|
||||
tx, err := dbconn.Global.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
defer func() {
|
||||
if err != nil {
|
||||
rollErr := tx.Rollback()
|
||||
if rollErr != nil {
|
||||
err = multierror.Append(err, rollErr)
|
||||
}
|
||||
return
|
||||
}
|
||||
err = tx.Commit()
|
||||
}()
|
||||
|
||||
custID, err := dbBilling{}.getUserBillingCustomerID(ctx, tx, userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if custID == nil {
|
||||
// There is no billing customer ID for this user yet, so we must make one. This is not racy
|
||||
// w.r.t. the DB because we are still in a DB transaction. It is still possible for a race
|
||||
// condition to result in 2 billing customers being created, but only one of them would ever
|
||||
// be stored in our DB.
|
||||
newCustID, err := createCustomerID(ctx, userID)
|
||||
if err != nil {
|
||||
return "", errors.WithMessage(err, fmt.Sprintf("auto-creating customer ID for user ID %d", userID))
|
||||
}
|
||||
|
||||
// If we fail after here, then try to clean up the customer ID.
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) // don't wait too long
|
||||
defer cancel()
|
||||
if err := deleteCustomerID(ctx, newCustID); err != nil {
|
||||
log15.Error("During cleanup of failed auto-creation of billing customer ID for user, failed to delete billing customer ID.", "userID", userID, "newCustomerID", newCustID, "err", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
if err := (dbBilling{}).setUserBillingCustomerID(ctx, tx, userID, &newCustID); err != nil {
|
||||
return "", err
|
||||
}
|
||||
custID = &newCustID
|
||||
}
|
||||
return *custID, nil
|
||||
}
|
||||
|
||||
var (
|
||||
dummyCustomerMu sync.Mutex
|
||||
dummyCustomerID string
|
||||
)
|
||||
|
||||
// GetDummyCustomerID returns the customer ID for a dummy customer that must be used only for
|
||||
// pricing out invoices not associated with any particular customer. There is one dummy customer in
|
||||
// the billing system that is used for all such operations (because the billing system requires
|
||||
// providing a customer ID but we don't want to use any actual customer's ID).
|
||||
//
|
||||
// The first time this func is called, it looks up the ID of the existing dummy customer in the
|
||||
// billing system and returns that if one exists (to avoid creating many dummy customer records). If
|
||||
// the dummy customer doesn't exist yet, it is automatically created.
|
||||
func GetDummyCustomerID(ctx context.Context) (string, error) {
|
||||
dummyCustomerMu.Lock()
|
||||
defer dummyCustomerMu.Unlock()
|
||||
if dummyCustomerID == "" {
|
||||
// Look up dummy customer.
|
||||
const dummyCustomerEmail = "dummy-customer@example.com"
|
||||
listParams := &stripe.CustomerListParams{
|
||||
ListParams: stripe.ListParams{Context: ctx},
|
||||
}
|
||||
listParams.Filters.AddFilter("email", "", dummyCustomerEmail)
|
||||
listParams.Limit = stripe.Int64(1)
|
||||
customers := customer.List(listParams)
|
||||
if err := customers.Err(); err != nil {
|
||||
return "", err
|
||||
}
|
||||
if customers.Next() {
|
||||
dummyCustomerID = customers.Customer().ID
|
||||
} else {
|
||||
// No dummy customer exists yet, so create it. Future calls to GetDummyCustomerID will reuse the dummy customer.
|
||||
params := &stripe.CustomerParams{
|
||||
Params: stripe.Params{Context: ctx},
|
||||
Email: stripe.String(dummyCustomerEmail),
|
||||
Description: stripe.String("DUMMY (only used for generating quotes for unauthenticated viewers)"),
|
||||
}
|
||||
cust, err := customer.New(params)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
dummyCustomerID = cust.ID
|
||||
}
|
||||
}
|
||||
return dummyCustomerID, nil
|
||||
}
|
||||
|
||||
var mockCreateCustomerID func(userID int32) (string, error)
|
||||
|
||||
// createCustomerID creates a customer record on the billing system and returns the customer ID of
|
||||
// the new record.
|
||||
func createCustomerID(ctx context.Context, userID int32) (string, error) {
|
||||
if mockCreateCustomerID != nil {
|
||||
return mockCreateCustomerID(userID)
|
||||
}
|
||||
|
||||
user, err := graphqlbackend.UserByIDInt32(ctx, userID)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
custParams := &stripe.CustomerParams{
|
||||
Params: stripe.Params{Context: ctx},
|
||||
Description: stripe.String(fmt.Sprintf("%s (%d)", user.Username(), user.SourcegraphID())),
|
||||
}
|
||||
|
||||
// Use the user's first verified email (if any).
|
||||
emails, err := user.Emails(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
for _, email := range emails {
|
||||
if email.Verified() {
|
||||
custParams.Email = stripe.String(email.Email())
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// Create the billing customer.
|
||||
cust, err := customer.New(custParams)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return cust.ID, nil
|
||||
}
|
||||
|
||||
// deleteCustomerID deletes the customer record on the billing system.
|
||||
func deleteCustomerID(ctx context.Context, customerID string) error {
|
||||
// For simplicity of tests, just noop if the mockCreateCustomerID is set.
|
||||
if mockCreateCustomerID != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
_, err := customer.Del(customerID, &stripe.CustomerParams{Params: stripe.Params{Context: ctx}})
|
||||
return err
|
||||
}
|
||||
@ -0,0 +1,53 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
"github.com/stripe/stripe-go/customer"
|
||||
)
|
||||
|
||||
func init() {
|
||||
graphqlbackend.UserURLForSiteAdminBilling = func(ctx context.Context, userID int32) (*string, error) {
|
||||
// 🚨 SECURITY: Only site admins may view the billing URL, because it may contain sensitive
|
||||
// data or identifiers.
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
custID, err := dbBilling{}.getUserBillingCustomerID(ctx, nil, userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if custID != nil {
|
||||
u := CustomerURL(*custID)
|
||||
return &u, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (BillingResolver) SetUserBilling(ctx context.Context, args *graphqlbackend.SetUserBillingArgs) (*graphqlbackend.EmptyResponse, error) {
|
||||
// 🚨 SECURITY: Only site admins may set a user's billing info.
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userID, err := graphqlbackend.UnmarshalUserID(args.User)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the billing customer ID refers to a valid customer in the billing system.
|
||||
if args.BillingCustomerID != nil {
|
||||
if _, err := customer.Get(*args.BillingCustomerID, &stripe.CustomerParams{Params: stripe.Params{Context: ctx}}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := (dbBilling{}).setUserBillingCustomerID(ctx, nil, userID, args.BillingCustomerID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &graphqlbackend.EmptyResponse{}, nil
|
||||
}
|
||||
@ -0,0 +1,45 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
|
||||
)
|
||||
|
||||
func TestGetOrAssignUserCustomerID(t *testing.T) {
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
c := 0
|
||||
mockCreateCustomerID = func(userID int32) (string, error) {
|
||||
c++
|
||||
return fmt.Sprintf("cust%d", c), nil
|
||||
}
|
||||
defer func() { mockCreateCustomerID = nil }()
|
||||
|
||||
u, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
t.Run("assigns and retrieves", func(t *testing.T) {
|
||||
custID1, err := GetOrAssignUserCustomerID(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
custID2, err := GetOrAssignUserCustomerID(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if custID2 != custID1 {
|
||||
t.Errorf("got custID %q, want %q", custID2, custID2)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("fails on nonexistent users", func(t *testing.T) {
|
||||
if _, err := GetOrAssignUserCustomerID(ctx, 123 /* no such user */); err == nil {
|
||||
t.Fatal("err == nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
65
enterprise/cmd/frontend/internal/dotcom/billing/db.go
Normal file
65
enterprise/cmd/frontend/internal/dotcom/billing/db.go
Normal file
@ -0,0 +1,65 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/keegancsmith/sqlf"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
|
||||
)
|
||||
|
||||
// dbBilling provides billing-related database operations.
|
||||
type dbBilling struct{}
|
||||
|
||||
// getUserBillingCustomerID gets the billing customer ID (if any) for a user.
|
||||
//
|
||||
// If a transaction tx is provided, the query is executed using the transaction. If tx is nil, the
|
||||
// global DB handle is used.
|
||||
func (dbBilling) getUserBillingCustomerID(ctx context.Context, tx *sql.Tx, userID int32) (billingCustomerID *string, err error) {
|
||||
var dbh dbh
|
||||
if tx != nil {
|
||||
dbh = tx
|
||||
} else {
|
||||
dbh = dbconn.Global
|
||||
}
|
||||
|
||||
query := sqlf.Sprintf("SELECT billing_customer_id FROM users WHERE id=%d AND deleted_at IS NULL", userID)
|
||||
err = dbh.QueryRowContext(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...).Scan(&billingCustomerID)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, db.NewUserNotFoundError(userID)
|
||||
}
|
||||
return billingCustomerID, err
|
||||
}
|
||||
|
||||
// setUserBillingCustomerID sets or unsets the billing customer ID for a user.
|
||||
//
|
||||
// If a transaction tx is provided, the query is executed using the transaction. If tx is nil, the
|
||||
// global DB handle is used.
|
||||
func (dbBilling) setUserBillingCustomerID(ctx context.Context, tx *sql.Tx, userID int32, billingCustomerID *string) error {
|
||||
var dbh dbh
|
||||
if tx != nil {
|
||||
dbh = tx
|
||||
} else {
|
||||
dbh = dbconn.Global
|
||||
}
|
||||
|
||||
query := sqlf.Sprintf("UPDATE users SET billing_customer_id=%s WHERE id=%d AND deleted_at IS NULL", billingCustomerID, userID)
|
||||
res, err := dbh.ExecContext(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nrows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nrows == 0 {
|
||||
return db.NewUserNotFoundError(userID)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type dbh interface {
|
||||
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
|
||||
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
|
||||
}
|
||||
66
enterprise/cmd/frontend/internal/dotcom/billing/db_test.go
Normal file
66
enterprise/cmd/frontend/internal/dotcom/billing/db_test.go
Normal file
@ -0,0 +1,66 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/errcode"
|
||||
)
|
||||
|
||||
func init() {
|
||||
dbtesting.DBNameSuffix = "billing"
|
||||
}
|
||||
|
||||
func TestDBUsersBillingCustomerID(t *testing.T) {
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
t.Run("existing user", func(t *testing.T) {
|
||||
u, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if custID, err := (dbBilling{}).getUserBillingCustomerID(ctx, nil, u.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if custID != nil {
|
||||
t.Errorf("got %q, want nil", *custID)
|
||||
}
|
||||
|
||||
t.Run("set to non-nil", func(t *testing.T) {
|
||||
if err := (dbBilling{}).setUserBillingCustomerID(ctx, nil, u.ID, strptr("x")); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if custID, err := (dbBilling{}).getUserBillingCustomerID(ctx, nil, u.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if want := "x"; custID == nil || *custID != want {
|
||||
t.Errorf("got %v, want %q", custID, want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("set to nil", func(t *testing.T) {
|
||||
if err := (dbBilling{}).setUserBillingCustomerID(ctx, nil, u.ID, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if custID, err := (dbBilling{}).getUserBillingCustomerID(ctx, nil, u.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if custID != nil {
|
||||
t.Errorf("got %q, want nil", *custID)
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
t.Run("nonexistent user", func(t *testing.T) {
|
||||
if _, err := (dbBilling{}).getUserBillingCustomerID(ctx, nil, 123 /* doesn't exist */); !errcode.IsNotFound(err) {
|
||||
t.Errorf("got %v, want errcode.IsNotFound(err) == true", err)
|
||||
}
|
||||
if err := (dbBilling{}).setUserBillingCustomerID(ctx, nil, 123 /* doesn't exist */, strptr("x")); !errcode.IsNotFound(err) {
|
||||
t.Errorf("got %v, want errcode.IsNotFound(err) == true", err)
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
func strptr(s string) *string {
|
||||
return &s
|
||||
}
|
||||
2
enterprise/cmd/frontend/internal/dotcom/billing/doc.go
Normal file
2
enterprise/cmd/frontend/internal/dotcom/billing/doc.go
Normal file
@ -0,0 +1,2 @@
|
||||
// Package billing handles subscription billing on Sourcegraph.com (via Stripe).
|
||||
package billing
|
||||
@ -0,0 +1,97 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
)
|
||||
|
||||
// productSubscriptionEvent implements the GraphQL type ProductSubscriptionEvent.
|
||||
type productSubscriptionEvent struct {
|
||||
v *stripe.Event
|
||||
}
|
||||
|
||||
// ToProductSubscriptionEvent returns a resolver for the GraphQL type ProductSubscriptionEvent from
|
||||
// the given billing event.
|
||||
//
|
||||
// The okToShowUser return value reports whether the event should be shown to the user. It is false
|
||||
// for internal events (e.g., an invoice being marked uncollectible).
|
||||
func ToProductSubscriptionEvent(event *stripe.Event) (gqlEvent graphqlbackend.ProductSubscriptionEvent, okToShowUser bool) {
|
||||
_, _, okToShowUser = getProductSubscriptionEventInfo(event)
|
||||
return &productSubscriptionEvent{v: event}, okToShowUser
|
||||
}
|
||||
|
||||
// getProductSubscriptionEventInfo returns a nice title and description for the event. See
|
||||
// ToProductSubscriptionEvent for information about the okToShowUser return value.
|
||||
func getProductSubscriptionEventInfo(v *stripe.Event) (title, description string, okToShowUser bool) {
|
||||
switch v.Type {
|
||||
case "charge.succeeded":
|
||||
title = "Charge succeeded"
|
||||
okToShowUser = true
|
||||
|
||||
case "invoice.created":
|
||||
title = "Invoice created"
|
||||
okToShowUser = true
|
||||
case "invoice.payment_succeeded":
|
||||
title = "Invoice payment succeeded"
|
||||
description = fmt.Sprintf("An invoice payment of %s succeeded.", usdCentsToString(v.GetObjectValue("amount_paid")))
|
||||
okToShowUser = true
|
||||
case "invoice.payment_failed":
|
||||
title = "Invoice payment failed"
|
||||
description = fmt.Sprintf("An invoice payment of %s failed.", usdCentsToString(v.GetObjectValue("amount_paid")))
|
||||
okToShowUser = true
|
||||
case "invoice.sent":
|
||||
title = "Invoice email sent"
|
||||
okToShowUser = true
|
||||
case "invoice.updated":
|
||||
title = "Invoice updated"
|
||||
okToShowUser = true
|
||||
|
||||
default:
|
||||
title = v.Type
|
||||
}
|
||||
return title, description, okToShowUser
|
||||
}
|
||||
|
||||
func usdCentsToString(s string) string {
|
||||
// TODO(sqs): use a real currency lib
|
||||
usdCents, err := strconv.ParseFloat(s, 64)
|
||||
if err != nil {
|
||||
return "unknown amount"
|
||||
}
|
||||
return fmt.Sprintf("$%.2f", usdCents/100)
|
||||
}
|
||||
|
||||
func (r *productSubscriptionEvent) ID() string { return r.v.ID }
|
||||
|
||||
func (r *productSubscriptionEvent) Date() string {
|
||||
return time.Unix(r.v.Created, 0).Format(time.RFC3339)
|
||||
}
|
||||
|
||||
func (r *productSubscriptionEvent) Title() string {
|
||||
title, _, _ := getProductSubscriptionEventInfo(r.v)
|
||||
return title
|
||||
}
|
||||
|
||||
func (r *productSubscriptionEvent) Description() *string {
|
||||
_, description, _ := getProductSubscriptionEventInfo(r.v)
|
||||
if description == "" {
|
||||
return nil
|
||||
}
|
||||
return &description
|
||||
}
|
||||
|
||||
func (r *productSubscriptionEvent) URL() *string {
|
||||
var u string
|
||||
if strings.HasPrefix(r.v.Type, "invoice.") {
|
||||
u = r.v.GetObjectValue("hosted_invoice_url")
|
||||
}
|
||||
if u == "" {
|
||||
return nil
|
||||
}
|
||||
return &u
|
||||
}
|
||||
@ -0,0 +1,4 @@
|
||||
package billing
|
||||
|
||||
// BillingResolver implements the GraphQL Query and Mutation fields related to billing.
|
||||
type BillingResolver struct{}
|
||||
38
enterprise/cmd/frontend/internal/dotcom/billing/plans.go
Normal file
38
enterprise/cmd/frontend/internal/dotcom/billing/plans.go
Normal file
@ -0,0 +1,38 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
"github.com/sourcegraph/enterprise/pkg/license"
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
"github.com/stripe/stripe-go/plan"
|
||||
)
|
||||
|
||||
// InfoForProductPlan returns the license key tags and min quantity that should be used for the
|
||||
// given product plan.
|
||||
//
|
||||
// License key tags indicate which product variant (e.g., Enterprise vs. Enterprise Starter), so
|
||||
// they are stored on the billing system in the metadata of the product plans.
|
||||
func InfoForProductPlan(ctx context.Context, planID string) (licenseTags []string, minQuantity *int32, err error) {
|
||||
params := &stripe.PlanParams{Params: stripe.Params{Context: ctx}}
|
||||
params.AddExpand("product")
|
||||
plan, err := plan.Get(planID, params)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
var tags []string
|
||||
switch {
|
||||
case plan.Product.Metadata["licenseTags"] != "":
|
||||
tags = license.ParseTagsInput(plan.Product.Metadata["licenseTags"])
|
||||
case plan.Product.Name == "Enterprise Starter":
|
||||
tags = licensing.EnterpriseStarterTags
|
||||
case plan.Product.Name == "Enterprise":
|
||||
tags = licensing.EnterpriseTags
|
||||
default:
|
||||
return nil, nil, fmt.Errorf("unable to determine license tags for plan %q (nickname %q)", planID, plan.Nickname)
|
||||
}
|
||||
return tags, ProductPlanMinQuantity(plan), nil
|
||||
}
|
||||
119
enterprise/cmd/frontend/internal/dotcom/billing/plans_graphql.go
Normal file
119
enterprise/cmd/frontend/internal/dotcom/billing/plans_graphql.go
Normal file
@ -0,0 +1,119 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strconv"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
"github.com/stripe/stripe-go/plan"
|
||||
)
|
||||
|
||||
// productPlan implements the GraphQL type ProductPlan.
|
||||
type productPlan struct {
|
||||
billingPlanID string
|
||||
productPlanID string
|
||||
name string
|
||||
pricePerUserPerYear int32
|
||||
minQuantity *int32
|
||||
tiersMode string
|
||||
planTiers []graphqlbackend.PlanTier
|
||||
}
|
||||
|
||||
// planTier implements the GraphQL type PlanTier.
|
||||
type planTier struct {
|
||||
unitAmount int64
|
||||
upTo int64
|
||||
}
|
||||
|
||||
func (r *productPlan) ProductPlanID() string { return r.productPlanID }
|
||||
func (r *productPlan) BillingPlanID() string { return r.billingPlanID }
|
||||
func (r *productPlan) Name() string { return r.name }
|
||||
func (r *productPlan) NameWithBrand() string { return "Sourcegraph " + r.name }
|
||||
func (r *productPlan) PricePerUserPerYear() int32 { return r.pricePerUserPerYear }
|
||||
func (r *productPlan) MinQuantity() *int32 { return r.minQuantity }
|
||||
func (r *productPlan) TiersMode() string { return r.tiersMode }
|
||||
func (r *productPlan) PlanTiers() []graphqlbackend.PlanTier {
|
||||
if r.planTiers == nil {
|
||||
return nil
|
||||
}
|
||||
return r.planTiers
|
||||
}
|
||||
|
||||
func (r *planTier) UnitAmount() int32 { return int32(r.unitAmount) }
|
||||
func (r *planTier) UpTo() int32 { return int32(r.upTo) }
|
||||
|
||||
// ToProductPlan returns a resolver for the GraphQL type ProductPlan from the given billing plan.
|
||||
func ToProductPlan(plan *stripe.Plan) (graphqlbackend.ProductPlan, error) {
|
||||
// Sanity check.
|
||||
if plan.Product.Name == "" {
|
||||
return nil, fmt.Errorf("unexpected empty product name for plan %q", plan.ID)
|
||||
}
|
||||
if plan.Currency != stripe.CurrencyUSD {
|
||||
return nil, fmt.Errorf("unexpected currency %q for plan %q", plan.Currency, plan.ID)
|
||||
}
|
||||
if plan.Interval != stripe.PlanIntervalYear {
|
||||
return nil, fmt.Errorf("unexpected plan interval %q for plan %q", plan.Interval, plan.ID)
|
||||
}
|
||||
if plan.IntervalCount != 1 {
|
||||
return nil, fmt.Errorf("unexpected plan interval count %d for plan %q", plan.IntervalCount, plan.ID)
|
||||
}
|
||||
|
||||
var tiers []graphqlbackend.PlanTier
|
||||
for _, tier := range plan.Tiers {
|
||||
tiers = append(tiers, &planTier{
|
||||
unitAmount: tier.UnitAmount,
|
||||
upTo: tier.UpTo,
|
||||
})
|
||||
}
|
||||
|
||||
return &productPlan{
|
||||
productPlanID: plan.Product.ID,
|
||||
billingPlanID: plan.ID,
|
||||
name: plan.Product.Name,
|
||||
pricePerUserPerYear: int32(plan.Amount),
|
||||
minQuantity: ProductPlanMinQuantity(plan),
|
||||
planTiers: tiers,
|
||||
tiersMode: plan.TiersMode,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// ProductPlanMinQuantity returns the plan's product's minQuantity metadata value, or nil if there
|
||||
// is none.
|
||||
func ProductPlanMinQuantity(plan *stripe.Plan) *int32 {
|
||||
if v, err := strconv.Atoi(plan.Product.Metadata["minQuantity"]); err == nil {
|
||||
tmp := int32(v)
|
||||
return &tmp
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// ProductPlans implements the GraphQL field Query.dotcom.productPlans.
|
||||
func (BillingResolver) ProductPlans(ctx context.Context) ([]graphqlbackend.ProductPlan, error) {
|
||||
params := &stripe.PlanListParams{
|
||||
ListParams: stripe.ListParams{Context: ctx},
|
||||
Active: stripe.Bool(true),
|
||||
}
|
||||
params.AddExpand("data.product")
|
||||
plans := plan.List(params)
|
||||
var gqlPlans []graphqlbackend.ProductPlan
|
||||
for plans.Next() {
|
||||
gqlPlan, err := ToProductPlan(plans.Plan())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
gqlPlans = append(gqlPlans, gqlPlan)
|
||||
}
|
||||
if err := plans.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Sort cheapest first (a reasonable assumption).
|
||||
sort.Slice(gqlPlans, func(i, j int) bool {
|
||||
return gqlPlans[i].PricePerUserPerYear() < gqlPlans[j].PricePerUserPerYear()
|
||||
})
|
||||
|
||||
return gqlPlans, nil
|
||||
}
|
||||
55
enterprise/cmd/frontend/internal/dotcom/billing/stripe.go
Normal file
55
enterprise/cmd/frontend/internal/dotcom/billing/stripe.go
Normal file
@ -0,0 +1,55 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"log"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/app"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/env"
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
)
|
||||
|
||||
var (
|
||||
stripeSecretKey = env.Get("STRIPE_SECRET_KEY", "", "billing: Stripe API secret key")
|
||||
stripePublishableKey = env.Get("STRIPE_PUBLISHABLE_KEY", "", "billing: Stripe API publishable key")
|
||||
stripeWebhookSecret = env.Get("STRIPE_WEBHOOK_SECRET", "", "billing: Stripe webhook secret")
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Sanity-check the Stripe keys (to help prevent mistakes where they got switched and the secret
|
||||
// key is published).
|
||||
if stripeSecretKey != "" && !strings.HasPrefix(stripeSecretKey, "sk_") {
|
||||
log.Fatal(`Invalid STRIPE_SECRET_KEY (must begin with "sk_").`)
|
||||
}
|
||||
if stripePublishableKey != "" && !strings.HasPrefix(stripePublishableKey, "pk_") {
|
||||
log.Fatal(`Invalid STRIPE_PUBLISHABLE_KEY (must begin with "pk_").`)
|
||||
}
|
||||
if (stripeSecretKey != "") != (stripePublishableKey != "") {
|
||||
log.Fatalf("Either zero or both of STRIPE_SECRET_KEY (set=%v) and STRIPE_PUBLISHABLE_KEY (set=%v) must be set.", stripeSecretKey != "", stripePublishableKey != "")
|
||||
}
|
||||
|
||||
stripe.Key = stripeSecretKey
|
||||
app.SetBillingPublishableKey(stripePublishableKey)
|
||||
}
|
||||
|
||||
func isTest() bool {
|
||||
return strings.Contains(stripe.Key, "_test_")
|
||||
}
|
||||
|
||||
func baseURL() string {
|
||||
u := "https://dashboard.stripe.com"
|
||||
if isTest() {
|
||||
u += "/test"
|
||||
}
|
||||
return u
|
||||
}
|
||||
|
||||
// CustomerURL returns the URL to the customer with the given ID on the billing system.
|
||||
func CustomerURL(id string) string {
|
||||
return baseURL() + "/customers/" + id
|
||||
}
|
||||
|
||||
// SubscriptionURL returns the URL to the subscription with the given ID on the billing system.
|
||||
func SubscriptionURL(id string) string {
|
||||
return baseURL() + "/subscriptions/" + id
|
||||
}
|
||||
@ -0,0 +1,31 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
)
|
||||
|
||||
// ToSubscriptionItemsParams converts a value of GraphQL type ProductSubscriptionInput into a
|
||||
// subscription item parameter for the billing system.
|
||||
func ToSubscriptionItemsParams(input graphqlbackend.ProductSubscriptionInput) *stripe.SubscriptionItemsParams {
|
||||
return &stripe.SubscriptionItemsParams{
|
||||
Plan: stripe.String(input.BillingPlanID),
|
||||
Quantity: stripe.Int64(int64(input.UserCount)),
|
||||
}
|
||||
}
|
||||
|
||||
// GetSubscriptionItemIDToReplace returns the ID of the billing subscription item (used when
|
||||
// updating the subscription or previewing an invoice to do so). It also performs a good set of
|
||||
// sanity checks on the subscription that should be performed whenever the subscription is updated.
|
||||
func GetSubscriptionItemIDToReplace(billingSub *stripe.Subscription, billingCustomerID string) (string, error) {
|
||||
if billingSub.Customer.ID != billingCustomerID {
|
||||
return "", errors.New("product subscription's billing customer does not match the provided account parameter")
|
||||
}
|
||||
if len(billingSub.Items.Data) != 1 {
|
||||
return "", fmt.Errorf("product subscription has unexpected number of invoice items (got %d, want 1)", len(billingSub.Items.Data))
|
||||
}
|
||||
return billingSub.Items.Data[0].ID, nil
|
||||
}
|
||||
56
enterprise/cmd/frontend/internal/dotcom/billing/webhook.go
Normal file
56
enterprise/cmd/frontend/internal/dotcom/billing/webhook.go
Normal file
@ -0,0 +1,56 @@
|
||||
package billing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
"github.com/stripe/stripe-go/webhook"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
// handleWebhook handles HTTP requests containing webhook payloads about billing-related events from
|
||||
// the billing system.
|
||||
func handleWebhook(w http.ResponseWriter, r *http.Request) {
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusServiceUnavailable)
|
||||
return
|
||||
}
|
||||
|
||||
// Check the signature to verify the HTTP request came from the billing system.
|
||||
event, err := webhook.ConstructEvent(body, r.Header.Get("Stripe-Signature"), stripeWebhookSecret)
|
||||
if err != nil {
|
||||
// Parse out some of the event for logging.
|
||||
var event struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
}
|
||||
_ = json.Unmarshal(body, &event)
|
||||
log15.Error("Billing webhook received request with invalid signature.", "idUnverified", event.ID, "typeUnverified", event.Type, "err", err)
|
||||
http.Error(w, "billing event signature is invalid", http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
|
||||
log15.Info("Billing webhook received event.", "id", event.ID, "type", event.Type)
|
||||
if err := handleEvent(r.Context(), event); err != nil {
|
||||
log15.Error("Billing webhook failed to handle event.", "id", event.ID, "type", event.Type, "err", err)
|
||||
http.Error(w, "billing event handler error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// handleEvent handles a billing event (received via webhook).
|
||||
//
|
||||
// TODO(sqs): implement this so we can create invoices instead of only being able to accept
|
||||
// immediate payment.
|
||||
func handleEvent(ctx context.Context, event stripe.Event) error {
|
||||
switch event.Type {
|
||||
case "invoice.payment_succeeded":
|
||||
// noop
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,7 @@
|
||||
package productsubscription
|
||||
|
||||
import dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
|
||||
|
||||
func init() {
|
||||
dbtesting.DBNameSuffix = "productsubscription"
|
||||
}
|
||||
@ -0,0 +1,2 @@
|
||||
// Package productsubscription handles product subscriptions and licensing.
|
||||
package productsubscription
|
||||
@ -0,0 +1,5 @@
|
||||
package productsubscription
|
||||
|
||||
// ProductSubscriptionLicensingResolver implements the GraphQL Query and Mutation fields related to product
|
||||
// subscriptions and licensing.
|
||||
type ProductSubscriptionLicensingResolver struct{}
|
||||
@ -0,0 +1,190 @@
|
||||
package productsubscription
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/dotcom/billing"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
"github.com/stripe/stripe-go/invoice"
|
||||
"github.com/stripe/stripe-go/plan"
|
||||
"github.com/stripe/stripe-go/sub"
|
||||
)
|
||||
|
||||
type productSubscriptionPreviewInvoice struct {
|
||||
price int32
|
||||
amountDue int32
|
||||
prorationDate *int64
|
||||
before, after *productSubscriptionInvoiceItem
|
||||
}
|
||||
|
||||
func (r *productSubscriptionPreviewInvoice) Price() int32 { return r.price }
|
||||
func (r *productSubscriptionPreviewInvoice) AmountDue() int32 { return r.amountDue }
|
||||
func (r *productSubscriptionPreviewInvoice) ProrationDate() *string {
|
||||
if v := r.prorationDate; v != nil {
|
||||
s := time.Unix(*v, 0).Format(time.RFC3339)
|
||||
return &s
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *productSubscriptionPreviewInvoice) IsDowngradeRequiringManualIntervention() bool {
|
||||
return r.before != nil && isDowngradeRequiringManualIntervention(r.before.userCount, r.before.plan.Amount, r.after.userCount, r.after.plan.Amount)
|
||||
}
|
||||
|
||||
func isDowngradeRequiringManualIntervention(beforeUserCount int32, beforePlanPrice int64, afterUserCount int32, afterPlanPrice int64) bool {
|
||||
return afterUserCount < beforeUserCount || afterPlanPrice < beforePlanPrice
|
||||
}
|
||||
|
||||
func (r *productSubscriptionPreviewInvoice) BeforeInvoiceItem() graphqlbackend.ProductSubscriptionInvoiceItem {
|
||||
if r.before == nil {
|
||||
return nil // untyped nil is necessary for graphql-go
|
||||
}
|
||||
return r.before
|
||||
}
|
||||
|
||||
func (r *productSubscriptionPreviewInvoice) AfterInvoiceItem() graphqlbackend.ProductSubscriptionInvoiceItem {
|
||||
return r.after
|
||||
}
|
||||
|
||||
func (ProductSubscriptionLicensingResolver) PreviewProductSubscriptionInvoice(ctx context.Context, args *graphqlbackend.PreviewProductSubscriptionInvoiceArgs) (graphqlbackend.ProductSubscriptionPreviewInvoice, error) {
|
||||
// Support previewing an invoice with or without a customer ID.
|
||||
var custID string
|
||||
var accountUserID *int32
|
||||
if args.Account != nil {
|
||||
// There is a customer ID given.
|
||||
accountUser, err := graphqlbackend.UserByID(ctx, *args.Account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tmp := accountUser.SourcegraphID()
|
||||
accountUserID = &tmp
|
||||
custID, err = billing.GetOrAssignUserCustomerID(ctx, *accountUserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 🚨 SECURITY: Users may only preview invoices for their own product subscriptions. Site admins
|
||||
// may preview invoices for all product subscriptions.
|
||||
if err := backend.CheckSiteAdminOrSameUser(ctx, *accountUserID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// Support previewing an invoice without a customer ID, for unauthenticated viewers who just want
|
||||
// to see the price.
|
||||
if args.SubscriptionToUpdate != nil {
|
||||
return nil, errors.New("missing account ID argument (must be the owner of the subscriptionToUpdate)")
|
||||
}
|
||||
var err error
|
||||
custID, err = billing.GetDummyCustomerID(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Get the "before" subscription invoice item.
|
||||
planParams := &stripe.PlanParams{Params: stripe.Params{Context: ctx}}
|
||||
planParams.AddExpand("product")
|
||||
plan, err := plan.Get(args.ProductSubscription.BillingPlanID, planParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if minQuantity := billing.ProductPlanMinQuantity(plan); minQuantity != nil && args.ProductSubscription.UserCount < *minQuantity {
|
||||
args.ProductSubscription.UserCount = *minQuantity
|
||||
}
|
||||
result := productSubscriptionPreviewInvoice{
|
||||
after: &productSubscriptionInvoiceItem{
|
||||
plan: plan,
|
||||
userCount: args.ProductSubscription.UserCount,
|
||||
// The expiresAt field will be set below, not here, because its value depends on whether
|
||||
// this is a new vs. updated subscription.
|
||||
},
|
||||
}
|
||||
|
||||
params := &stripe.InvoiceParams{
|
||||
Params: stripe.Params{Context: ctx},
|
||||
Customer: stripe.String(custID),
|
||||
SubscriptionItems: []*stripe.SubscriptionItemsParams{billing.ToSubscriptionItemsParams(args.ProductSubscription)},
|
||||
}
|
||||
|
||||
if args.SubscriptionToUpdate != nil {
|
||||
// Update a subscription.
|
||||
//
|
||||
// When updating an existing subscription, craft the params to replace the existing subscription
|
||||
// item (otherwise the invoice would include both the existing and updated subscription items).
|
||||
subToUpdate, err := productSubscriptionByID(ctx, *args.SubscriptionToUpdate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 🚨 SECURITY: Only site admins and the subscription's account owner may preview invoices
|
||||
// for product subscriptions.
|
||||
if err := backend.CheckSiteAdminOrSameUser(ctx, subToUpdate.v.UserID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 🚨 SECURITY: Ensure that the subscription is owned by the account (i.e., that the
|
||||
// parameters are internally consistent). These checks are redundant for site admins, but
|
||||
// it's good to be robust against bugs.
|
||||
if subToUpdate.v.UserID != *accountUserID {
|
||||
return nil, errors.New("product subscription's account owner does not match the provided account parameter")
|
||||
}
|
||||
if subToUpdate.v.BillingSubscriptionID == nil {
|
||||
return nil, errors.New("unable to get preview invoice for product subscription that has no associated billing information")
|
||||
}
|
||||
|
||||
subParams := &stripe.SubscriptionParams{Params: stripe.Params{Context: ctx}}
|
||||
subParams.AddExpand("plan.product")
|
||||
billingSubToUpdate, err := sub.Get(*subToUpdate.v.BillingSubscriptionID, subParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
params.SubscriptionProrationDate = stripe.Int64(time.Now().Unix())
|
||||
params.Subscription = stripe.String(*subToUpdate.v.BillingSubscriptionID)
|
||||
params.SubscriptionProrate = stripe.Bool(true)
|
||||
idToReplace, err := billing.GetSubscriptionItemIDToReplace(billingSubToUpdate, custID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
params.SubscriptionItems[0].ID = stripe.String(idToReplace)
|
||||
|
||||
result.prorationDate = params.SubscriptionProrationDate
|
||||
result.before = &productSubscriptionInvoiceItem{
|
||||
plan: billingSubToUpdate.Plan,
|
||||
userCount: int32(billingSubToUpdate.Quantity),
|
||||
expiresAt: time.Unix(billingSubToUpdate.CurrentPeriodEnd, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Get the preview invoice.
|
||||
invoice, err := invoice.GetNext(params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Calculate the price and expiration.
|
||||
for _, invoiceItem := range invoice.Lines.Data {
|
||||
// When updating an existing subscription, only include invoice items that are affected by
|
||||
// the update (== whose proration date is the same as the one we set on the update params).
|
||||
if result.prorationDate != nil && invoiceItem.Period.Start != *result.prorationDate {
|
||||
continue
|
||||
}
|
||||
result.price += int32(invoiceItem.Amount)
|
||||
|
||||
// Set the period end to the farthest ahead future invoice item's end date.
|
||||
periodEnd := time.Unix(invoiceItem.Period.End, 0)
|
||||
if periodEnd.After(result.after.expiresAt) {
|
||||
result.after.expiresAt = periodEnd
|
||||
}
|
||||
}
|
||||
|
||||
// When there is no change (no new invoice lines), set the "after" state to expire at the same
|
||||
// as the before state to indicate there is no change in the expiration either.
|
||||
if result.after.expiresAt.IsZero() && result.before != nil {
|
||||
result.after.expiresAt = result.before.expiresAt
|
||||
}
|
||||
|
||||
return &result, nil
|
||||
}
|
||||
@ -0,0 +1,145 @@
|
||||
package productsubscription
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/keegancsmith/sqlf"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
|
||||
)
|
||||
|
||||
// dbLicense describes an product license row in the product_licenses DB table.
|
||||
type dbLicense struct {
|
||||
ID string // UUID
|
||||
ProductSubscriptionID string // UUID
|
||||
LicenseKey string
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// errLicenseNotFound occurs when a database operation expects a specific Sourcegraph
|
||||
// license to exist but it does not exist.
|
||||
var errLicenseNotFound = errors.New("product license not found")
|
||||
|
||||
// dbLicenses exposes product licenses in the product_licenses DB table.
|
||||
type dbLicenses struct{}
|
||||
|
||||
// Create creates a new product license entry given a license key.
|
||||
func (dbLicenses) Create(ctx context.Context, subscriptionID, licenseKey string) (id string, err error) {
|
||||
if mocks.licenses.Create != nil {
|
||||
return mocks.licenses.Create(subscriptionID, licenseKey)
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := dbconn.Global.QueryRowContext(ctx, `
|
||||
INSERT INTO product_licenses(id, product_subscription_id, license_key) VALUES($1, $2, $3) RETURNING id
|
||||
`,
|
||||
uuid, subscriptionID, licenseKey,
|
||||
).Scan(&id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetByID retrieves the product license (if any) given its ID.
|
||||
//
|
||||
// 🚨 SECURITY: The caller must ensure that the actor is permitted to view this product license.
|
||||
func (s dbLicenses) GetByID(ctx context.Context, id string) (*dbLicense, error) {
|
||||
if mocks.licenses.GetByID != nil {
|
||||
return mocks.licenses.GetByID(id)
|
||||
}
|
||||
results, err := s.list(ctx, []*sqlf.Query{sqlf.Sprintf("id=%s", id)}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return nil, errLicenseNotFound
|
||||
}
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
// GetByID retrieves the product license (if any) given its license key.
|
||||
func (s dbLicenses) GetByLicenseKey(ctx context.Context, licenseKey string) (*dbLicense, error) {
|
||||
if mocks.licenses.GetByLicenseKey != nil {
|
||||
return mocks.licenses.GetByLicenseKey(licenseKey)
|
||||
}
|
||||
results, err := s.list(ctx, []*sqlf.Query{sqlf.Sprintf("license_key=%s", licenseKey)}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return nil, errLicenseNotFound
|
||||
}
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
// dbLicensesListOptions contains options for listing product licenses.
|
||||
type dbLicensesListOptions struct {
|
||||
LicenseKeySubstring string
|
||||
ProductSubscriptionID string // only list product licenses for this subscription (by UUID)
|
||||
*db.LimitOffset
|
||||
}
|
||||
|
||||
func (o dbLicensesListOptions) sqlConditions() []*sqlf.Query {
|
||||
conds := []*sqlf.Query{sqlf.Sprintf("TRUE")}
|
||||
if o.LicenseKeySubstring != "" {
|
||||
conds = append(conds, sqlf.Sprintf("license_key LIKE %s", "%"+o.LicenseKeySubstring+"%"))
|
||||
}
|
||||
if o.ProductSubscriptionID != "" {
|
||||
conds = append(conds, sqlf.Sprintf("product_subscription_id=%s", o.ProductSubscriptionID))
|
||||
}
|
||||
return conds
|
||||
}
|
||||
|
||||
// List lists all product licenses that satisfy the options.
|
||||
func (s dbLicenses) List(ctx context.Context, opt dbLicensesListOptions) ([]*dbLicense, error) {
|
||||
return s.list(ctx, opt.sqlConditions(), opt.LimitOffset)
|
||||
}
|
||||
|
||||
func (dbLicenses) list(ctx context.Context, conds []*sqlf.Query, limitOffset *db.LimitOffset) ([]*dbLicense, error) {
|
||||
q := sqlf.Sprintf(`
|
||||
SELECT id, product_subscription_id, license_key, created_at FROM product_licenses
|
||||
WHERE (%s)
|
||||
ORDER BY created_at DESC
|
||||
%s`,
|
||||
sqlf.Join(conds, ") AND ("),
|
||||
limitOffset.SQL(),
|
||||
)
|
||||
|
||||
rows, err := dbconn.Global.QueryContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []*dbLicense
|
||||
for rows.Next() {
|
||||
var v dbLicense
|
||||
if err := rows.Scan(&v.ID, &v.ProductSubscriptionID, &v.LicenseKey, &v.CreatedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, &v)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Count counts all product licenses that satisfy the options (ignoring limit and offset).
|
||||
func (dbLicenses) Count(ctx context.Context, opt dbLicensesListOptions) (int, error) {
|
||||
q := sqlf.Sprintf("SELECT COUNT(*) FROM product_licenses WHERE (%s)", sqlf.Join(opt.sqlConditions(), ") AND ("))
|
||||
var count int
|
||||
if err := dbconn.Global.QueryRowContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...).Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
type mockLicenses struct {
|
||||
Create func(subscriptionID, licenseKey string) (id string, err error)
|
||||
GetByID func(id string) (*dbLicense, error)
|
||||
GetByLicenseKey func(licenseKey string) (*dbLicense, error)
|
||||
}
|
||||
@ -0,0 +1,127 @@
|
||||
package productsubscription
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
|
||||
)
|
||||
|
||||
func TestProductLicenses_Create(t *testing.T) {
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
u, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ps0, err := dbSubscriptions{}.Create(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
pl0, err := dbLicenses{}.Create(ctx, ps0, "k")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := dbLicenses{}.GetByID(ctx, pl0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := pl0; got.ID != want {
|
||||
t.Errorf("got %v, want %v", got.ID, want)
|
||||
}
|
||||
if want := ps0; got.ProductSubscriptionID != want {
|
||||
t.Errorf("got %v, want %v", got.ProductSubscriptionID, want)
|
||||
}
|
||||
if want := "k"; got.LicenseKey != want {
|
||||
t.Errorf("got %q, want %q", got.LicenseKey, want)
|
||||
}
|
||||
|
||||
ts, err := dbLicenses{}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: ps0})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 1; len(ts) != want {
|
||||
t.Errorf("got %d product licenses, want %d", len(ts), want)
|
||||
}
|
||||
|
||||
ts, err = dbLicenses{}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: "69da12d5-323c-4e42-9d44-cc7951639bca" /* invalid */})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 0; len(ts) != want {
|
||||
t.Errorf("got %d product licenses, want %d", len(ts), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProductLicenses_List(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
u1, err := db.Users.Create(ctx, db.NewUser{Username: "u1"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
ps0, err := dbSubscriptions{}.Create(ctx, u1.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
ps1, err := dbSubscriptions{}.Create(ctx, u1.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = dbLicenses{}.Create(ctx, ps0, "k")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = dbLicenses{}.Create(ctx, ps0, "n1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
{
|
||||
// List all product licenses.
|
||||
ts, err := dbLicenses{}.List(ctx, dbLicensesListOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 2; len(ts) != want {
|
||||
t.Errorf("got %d product licenses, want %d", len(ts), want)
|
||||
}
|
||||
count, err := dbLicenses{}.Count(ctx, dbLicensesListOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 2; count != want {
|
||||
t.Errorf("got %d, want %d", count, want)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// List ps0's product licenses.
|
||||
ts, err := dbLicenses{}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: ps0})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 2; len(ts) != want {
|
||||
t.Errorf("got %d product licenses, want %d", len(ts), want)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// List ps1's product licenses.
|
||||
ts, err := dbLicenses{}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: ps1})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 0; len(ts) != want {
|
||||
t.Errorf("got %d product licenses, want %d", len(ts), want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,201 @@
|
||||
package productsubscription
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
graphql "github.com/graph-gophers/graphql-go"
|
||||
"github.com/graph-gophers/graphql-go/relay"
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
"github.com/sourcegraph/enterprise/pkg/license"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend/graphqlutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
graphqlbackend.ProductLicenseByID = func(ctx context.Context, id graphql.ID) (graphqlbackend.ProductLicense, error) {
|
||||
return productLicenseByID(ctx, id)
|
||||
}
|
||||
}
|
||||
|
||||
// productLicense implements the GraphQL type ProductLicense.
|
||||
type productLicense struct {
|
||||
v *dbLicense
|
||||
}
|
||||
|
||||
// productLicenseByID looks up and returns the ProductLicense with the given GraphQL ID. If no such
|
||||
// ProductLicense exists, it returns a non-nil error.
|
||||
func productLicenseByID(ctx context.Context, id graphql.ID) (*productLicense, error) {
|
||||
idInt32, err := unmarshalProductLicenseID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return productLicenseByDBID(ctx, idInt32)
|
||||
}
|
||||
|
||||
// productLicenseByDBID looks up and returns the ProductLicense with the given database ID. If no
|
||||
// such ProductLicense exists, it returns a non-nil error.
|
||||
func productLicenseByDBID(ctx context.Context, id string) (*productLicense, error) {
|
||||
v, err := dbLicenses{}.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 🚨 SECURITY: Only site admins and the license's subscription's account's user may view a
|
||||
// product license.
|
||||
sub, err := productSubscriptionByDBID(ctx, v.ProductSubscriptionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := backend.CheckSiteAdminOrSameUser(ctx, sub.v.UserID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &productLicense{v: v}, nil
|
||||
}
|
||||
|
||||
func (r *productLicense) ID() graphql.ID {
|
||||
return marshalProductLicenseID(r.v.ID)
|
||||
}
|
||||
|
||||
func marshalProductLicenseID(id string) graphql.ID {
|
||||
return relay.MarshalID("ProductLicense", id)
|
||||
}
|
||||
|
||||
func unmarshalProductLicenseID(id graphql.ID) (productLicenseID string, err error) {
|
||||
err = relay.UnmarshalSpec(id, &productLicenseID)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *productLicense) Subscription(ctx context.Context) (graphqlbackend.ProductSubscription, error) {
|
||||
return productSubscriptionByDBID(ctx, r.v.ProductSubscriptionID)
|
||||
}
|
||||
|
||||
func (r *productLicense) Info() (*graphqlbackend.ProductLicenseInfo, error) {
|
||||
// Call this instead of licensing.ParseProductLicenseKey so that license info can be read from
|
||||
// license keys generated using the test license generation private key.
|
||||
info, err := licensing.ParseProductLicenseKeyWithBuiltinOrGenerationKey(r.v.LicenseKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &graphqlbackend.ProductLicenseInfo{
|
||||
TagsValue: info.Tags,
|
||||
UserCountValue: info.UserCount,
|
||||
ExpiresAtValue: info.ExpiresAt,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *productLicense) LicenseKey() string { return r.v.LicenseKey }
|
||||
|
||||
func (r *productLicense) CreatedAt() string {
|
||||
return r.v.CreatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
func generateProductLicenseForSubscription(ctx context.Context, subscriptionID string, input *graphqlbackend.ProductLicenseInput) (id string, err error) {
|
||||
licenseKey, err := licensing.GenerateProductLicenseKey(license.Info{
|
||||
Tags: input.Tags,
|
||||
UserCount: uint(input.UserCount),
|
||||
ExpiresAt: time.Unix(int64(input.ExpiresAt), 0),
|
||||
})
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return dbLicenses{}.Create(ctx, subscriptionID, licenseKey)
|
||||
}
|
||||
|
||||
func (ProductSubscriptionLicensingResolver) GenerateProductLicenseForSubscription(ctx context.Context, args *graphqlbackend.GenerateProductLicenseForSubscriptionArgs) (graphqlbackend.ProductLicense, error) {
|
||||
// 🚨 SECURITY: Only site admins may generate product licenses.
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sub, err := productSubscriptionByID(ctx, args.ProductSubscriptionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := generateProductLicenseForSubscription(ctx, sub.v.ID, args.License)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return productLicenseByDBID(ctx, id)
|
||||
}
|
||||
|
||||
func (ProductSubscriptionLicensingResolver) ProductLicenses(ctx context.Context, args *graphqlbackend.ProductLicensesArgs) (graphqlbackend.ProductLicenseConnection, error) {
|
||||
// 🚨 SECURITY: Only site admins may list product licenses.
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var sub *productSubscription
|
||||
if args.ProductSubscriptionID != nil {
|
||||
var err error
|
||||
sub, err = productSubscriptionByID(ctx, *args.ProductSubscriptionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var opt dbLicensesListOptions
|
||||
if sub != nil {
|
||||
opt.ProductSubscriptionID = sub.v.ID
|
||||
}
|
||||
if args.LicenseKeySubstring != nil {
|
||||
opt.LicenseKeySubstring = *args.LicenseKeySubstring
|
||||
}
|
||||
args.ConnectionArgs.Set(&opt.LimitOffset)
|
||||
return &productLicenseConnection{opt: opt}, nil
|
||||
}
|
||||
|
||||
// productLicenseConnection implements the GraphQL type ProductLicenseConnection.
|
||||
//
|
||||
// 🚨 SECURITY: When instantiating a productLicenseConnection value, the caller MUST
|
||||
// check permissions.
|
||||
type productLicenseConnection struct {
|
||||
opt dbLicensesListOptions
|
||||
|
||||
// cache results because they are used by multiple fields
|
||||
once sync.Once
|
||||
results []*dbLicense
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *productLicenseConnection) compute(ctx context.Context) ([]*dbLicense, error) {
|
||||
r.once.Do(func() {
|
||||
opt2 := r.opt
|
||||
if opt2.LimitOffset != nil {
|
||||
tmp := *opt2.LimitOffset
|
||||
opt2.LimitOffset = &tmp
|
||||
opt2.Limit++ // so we can detect if there is a next page
|
||||
}
|
||||
|
||||
r.results, r.err = dbLicenses{}.List(ctx, opt2)
|
||||
})
|
||||
return r.results, r.err
|
||||
}
|
||||
|
||||
func (r *productLicenseConnection) Nodes(ctx context.Context) ([]graphqlbackend.ProductLicense, error) {
|
||||
results, err := r.compute(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var l []graphqlbackend.ProductLicense
|
||||
for _, result := range results {
|
||||
l = append(l, &productLicense{v: result})
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (r *productLicenseConnection) TotalCount(ctx context.Context) (int32, error) {
|
||||
count, err := dbLicenses{}.Count(ctx, r.opt)
|
||||
return int32(count), err
|
||||
}
|
||||
|
||||
func (r *productLicenseConnection) PageInfo(ctx context.Context) (*graphqlutil.PageInfo, error) {
|
||||
results, err := r.compute(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return graphqlutil.HasNextPage(r.opt.LimitOffset != nil && len(results) > r.opt.Limit), nil
|
||||
}
|
||||
@ -0,0 +1,12 @@
|
||||
package productsubscription
|
||||
|
||||
func resetMocks() {
|
||||
mocks = dbMocks{}
|
||||
}
|
||||
|
||||
type dbMocks struct {
|
||||
subscriptions mockSubscriptions
|
||||
licenses mockLicenses
|
||||
}
|
||||
|
||||
var mocks dbMocks
|
||||
@ -0,0 +1,49 @@
|
||||
package productsubscription
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/dotcom/billing"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
"github.com/stripe/stripe-go/sub"
|
||||
)
|
||||
|
||||
func (r *productSubscription) InvoiceItem(ctx context.Context) (graphqlbackend.ProductSubscriptionInvoiceItem, error) {
|
||||
if r.v.BillingSubscriptionID == nil {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
params := &stripe.SubscriptionParams{Params: stripe.Params{Context: ctx}}
|
||||
params.AddExpand("plan.product")
|
||||
billingSub, err := sub.Get(*r.v.BillingSubscriptionID, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &productSubscriptionInvoiceItem{
|
||||
plan: billingSub.Plan,
|
||||
userCount: int32(billingSub.Quantity),
|
||||
expiresAt: time.Unix(billingSub.CurrentPeriodEnd, 0),
|
||||
}, nil
|
||||
}
|
||||
|
||||
type productSubscriptionInvoiceItem struct {
|
||||
plan *stripe.Plan
|
||||
userCount int32
|
||||
expiresAt time.Time
|
||||
}
|
||||
|
||||
var _ graphqlbackend.ProductSubscriptionInvoiceItem = &productSubscriptionInvoiceItem{}
|
||||
|
||||
func (r *productSubscriptionInvoiceItem) Plan() (graphqlbackend.ProductPlan, error) {
|
||||
return billing.ToProductPlan(r.plan)
|
||||
}
|
||||
|
||||
func (r *productSubscriptionInvoiceItem) UserCount() int32 {
|
||||
return r.userCount
|
||||
}
|
||||
|
||||
func (r *productSubscriptionInvoiceItem) ExpiresAt() string {
|
||||
return r.expiresAt.Format(time.RFC3339)
|
||||
}
|
||||
@ -0,0 +1,186 @@
|
||||
package productsubscription
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/keegancsmith/sqlf"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
|
||||
)
|
||||
|
||||
// dbSubscription describes an product subscription row in the product_subscriptions DB
|
||||
// table.
|
||||
type dbSubscription struct {
|
||||
ID string // UUID
|
||||
UserID int32
|
||||
BillingSubscriptionID *string // this subscription's ID in the billing system
|
||||
CreatedAt time.Time
|
||||
ArchivedAt *time.Time
|
||||
}
|
||||
|
||||
// errSubscriptionNotFound occurs when a database operation expects a specific Sourcegraph
|
||||
// license to exist but it does not exist.
|
||||
var errSubscriptionNotFound = errors.New("product subscription not found")
|
||||
|
||||
// dbSubscriptions exposes product subscriptions in the product_subscriptions DB table.
|
||||
type dbSubscriptions struct{}
|
||||
|
||||
// Create creates a new product subscription entry given a license key.
|
||||
func (dbSubscriptions) Create(ctx context.Context, userID int32) (id string, err error) {
|
||||
if mocks.subscriptions.Create != nil {
|
||||
return mocks.subscriptions.Create(userID)
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := dbconn.Global.QueryRowContext(ctx, `
|
||||
INSERT INTO product_subscriptions(id, user_id) VALUES($1, $2) RETURNING id
|
||||
`,
|
||||
uuid, userID,
|
||||
).Scan(&id); err != nil {
|
||||
return "", err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetByID retrieves the product subscription (if any) given its ID.
|
||||
//
|
||||
// 🚨 SECURITY: The caller must ensure that the actor is permitted to view this product subscription.
|
||||
func (s dbSubscriptions) GetByID(ctx context.Context, id string) (*dbSubscription, error) {
|
||||
if mocks.subscriptions.GetByID != nil {
|
||||
return mocks.subscriptions.GetByID(id)
|
||||
}
|
||||
results, err := s.list(ctx, []*sqlf.Query{sqlf.Sprintf("id=%s", id)}, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return nil, errSubscriptionNotFound
|
||||
}
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
// dbSubscriptionsListOptions contains options for listing product subscriptions.
|
||||
type dbSubscriptionsListOptions struct {
|
||||
UserID int32 // only list product subscriptions for this user
|
||||
IncludeArchived bool
|
||||
*db.LimitOffset
|
||||
}
|
||||
|
||||
func (o dbSubscriptionsListOptions) sqlConditions() []*sqlf.Query {
|
||||
conds := []*sqlf.Query{sqlf.Sprintf("TRUE")}
|
||||
if o.UserID != 0 {
|
||||
conds = append(conds, sqlf.Sprintf("user_id=%d", o.UserID))
|
||||
}
|
||||
if !o.IncludeArchived {
|
||||
conds = append(conds, sqlf.Sprintf("archived_at IS NULL"))
|
||||
}
|
||||
return conds
|
||||
}
|
||||
|
||||
// List lists all product subscriptions that satisfy the options.
|
||||
func (s dbSubscriptions) List(ctx context.Context, opt dbSubscriptionsListOptions) ([]*dbSubscription, error) {
|
||||
return s.list(ctx, opt.sqlConditions(), opt.LimitOffset)
|
||||
}
|
||||
|
||||
func (dbSubscriptions) list(ctx context.Context, conds []*sqlf.Query, limitOffset *db.LimitOffset) ([]*dbSubscription, error) {
|
||||
q := sqlf.Sprintf(`
|
||||
SELECT id, user_id, billing_subscription_id, created_at, archived_at FROM product_subscriptions
|
||||
WHERE (%s)
|
||||
ORDER BY archived_at DESC NULLS FIRST, created_at DESC
|
||||
%s`,
|
||||
sqlf.Join(conds, ") AND ("),
|
||||
limitOffset.SQL(),
|
||||
)
|
||||
|
||||
rows, err := dbconn.Global.QueryContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []*dbSubscription
|
||||
for rows.Next() {
|
||||
var v dbSubscription
|
||||
if err := rows.Scan(&v.ID, &v.UserID, &v.BillingSubscriptionID, &v.CreatedAt, &v.ArchivedAt); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
results = append(results, &v)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Count counts all product subscriptions that satisfy the options (ignoring limit and offset).
|
||||
func (dbSubscriptions) Count(ctx context.Context, opt dbSubscriptionsListOptions) (int, error) {
|
||||
q := sqlf.Sprintf("SELECT COUNT(*) FROM product_subscriptions WHERE (%s)", sqlf.Join(opt.sqlConditions(), ") AND ("))
|
||||
var count int
|
||||
if err := dbconn.Global.QueryRowContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...).Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// dbSubscriptionsUpdate represents an update to a product subscription in the database. Each field
|
||||
// represents an update to the corresponding database field if the Go value is non-nil. If the Go
|
||||
// value is nil, the field remains unchanged in the database.
|
||||
type dbSubscriptionUpdate struct {
|
||||
billingSubscriptionID *sql.NullString
|
||||
}
|
||||
|
||||
// Update updates a product subscription.
|
||||
func (dbSubscriptions) Update(ctx context.Context, id string, update dbSubscriptionUpdate) error {
|
||||
fieldUpdates := []*sqlf.Query{
|
||||
sqlf.Sprintf("updated_at=now()"), // always update updated_at timestamp
|
||||
}
|
||||
if v := update.billingSubscriptionID; v != nil {
|
||||
fieldUpdates = append(fieldUpdates, sqlf.Sprintf("billing_subscription_id=%s", *v))
|
||||
}
|
||||
|
||||
query := sqlf.Sprintf("UPDATE product_subscriptions SET %s WHERE id=%s", sqlf.Join(fieldUpdates, ", "), id)
|
||||
res, err := dbconn.Global.ExecContext(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nrows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nrows == 0 {
|
||||
return errSubscriptionNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Archive marks a product subscription as archived given its ID.
|
||||
//
|
||||
// 🚨 SECURITY: The caller must ensure that the actor is permitted to archive the token.
|
||||
func (dbSubscriptions) Archive(ctx context.Context, id string) error {
|
||||
if mocks.subscriptions.Archive != nil {
|
||||
return mocks.subscriptions.Archive(id)
|
||||
}
|
||||
q := sqlf.Sprintf("UPDATE product_subscriptions SET archived_at=now(), updated_at=now() WHERE id=%s AND archived_at IS NULL", id)
|
||||
res, err := dbconn.Global.ExecContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nrows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nrows == 0 {
|
||||
return errSubscriptionNotFound
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type mockSubscriptions struct {
|
||||
Create func(userID int32) (id string, err error)
|
||||
GetByID func(id string) (*dbSubscription, error)
|
||||
Archive func(id string) error
|
||||
}
|
||||
@ -0,0 +1,176 @@
|
||||
package productsubscription
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
|
||||
)
|
||||
|
||||
func TestProductSubscriptions_Create(t *testing.T) {
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
u, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sub0, err := dbSubscriptions{}.Create(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
got, err := dbSubscriptions{}.GetByID(ctx, sub0)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := sub0; got.ID != want {
|
||||
t.Errorf("got %v, want %v", got.ID, want)
|
||||
}
|
||||
if want := u.ID; got.UserID != want {
|
||||
t.Errorf("got %v, want %v", got.UserID, want)
|
||||
}
|
||||
if got.BillingSubscriptionID != nil {
|
||||
t.Errorf("got %v, want nil", got.BillingSubscriptionID)
|
||||
}
|
||||
|
||||
ts, err := dbSubscriptions{}.List(ctx, dbSubscriptionsListOptions{UserID: u.ID})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 1; len(ts) != want {
|
||||
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
|
||||
}
|
||||
|
||||
ts, err = dbSubscriptions{}.List(ctx, dbSubscriptionsListOptions{UserID: 123 /* invalid */})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 0; len(ts) != want {
|
||||
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
|
||||
}
|
||||
}
|
||||
|
||||
func TestProductSubscriptions_List(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
u1, err := db.Users.Create(ctx, db.NewUser{Username: "u1"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
u2, err := db.Users.Create(ctx, db.NewUser{Username: "u2"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = dbSubscriptions{}.Create(ctx, u1.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
_, err = dbSubscriptions{}.Create(ctx, u1.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
{
|
||||
// List all product subscriptions.
|
||||
ts, err := dbSubscriptions{}.List(ctx, dbSubscriptionsListOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 2; len(ts) != want {
|
||||
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
|
||||
}
|
||||
count, err := dbSubscriptions{}.Count(ctx, dbSubscriptionsListOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 2; count != want {
|
||||
t.Errorf("got %d, want %d", count, want)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// List u1's product subscriptions.
|
||||
ts, err := dbSubscriptions{}.List(ctx, dbSubscriptionsListOptions{UserID: u1.ID})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 2; len(ts) != want {
|
||||
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
|
||||
}
|
||||
}
|
||||
|
||||
{
|
||||
// List u2's product subscriptions.
|
||||
ts, err := dbSubscriptions{}.List(ctx, dbSubscriptionsListOptions{UserID: u2.ID})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := 0; len(ts) != want {
|
||||
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestProductSubscriptions_Update(t *testing.T) {
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
u, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
sub0, err := dbSubscriptions{}.Create(ctx, u.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, err := (dbSubscriptions{}).GetByID(ctx, sub0); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got.BillingSubscriptionID != nil {
|
||||
t.Errorf("got %q, want nil", *got.BillingSubscriptionID)
|
||||
}
|
||||
|
||||
// Set non-null value.
|
||||
if err := (dbSubscriptions{}).Update(ctx, sub0, dbSubscriptionUpdate{
|
||||
billingSubscriptionID: &sql.NullString{
|
||||
String: "x",
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, err := (dbSubscriptions{}).GetByID(ctx, sub0); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if want := "x"; got.BillingSubscriptionID == nil || *got.BillingSubscriptionID != want {
|
||||
t.Errorf("got %v, want %q", got.BillingSubscriptionID, want)
|
||||
}
|
||||
|
||||
// Update no fields.
|
||||
if err := (dbSubscriptions{}).Update(ctx, sub0, dbSubscriptionUpdate{billingSubscriptionID: nil}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, err := (dbSubscriptions{}).GetByID(ctx, sub0); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if want := "x"; got.BillingSubscriptionID == nil || *got.BillingSubscriptionID != want {
|
||||
t.Errorf("got %v, want %q", got.BillingSubscriptionID, want)
|
||||
}
|
||||
|
||||
// Set null value.
|
||||
if err := (dbSubscriptions{}).Update(ctx, sub0, dbSubscriptionUpdate{
|
||||
billingSubscriptionID: &sql.NullString{Valid: false},
|
||||
}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if got, err := (dbSubscriptions{}).GetByID(ctx, sub0); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if got.BillingSubscriptionID != nil {
|
||||
t.Errorf("got %q, want nil", *got.BillingSubscriptionID)
|
||||
}
|
||||
}
|
||||
|
||||
func strptr(s string) *string { return &s }
|
||||
@ -0,0 +1,520 @@
|
||||
package productsubscription
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
graphql "github.com/graph-gophers/graphql-go"
|
||||
"github.com/graph-gophers/graphql-go/relay"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/dotcom/billing"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
|
||||
db_ "github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend/graphqlutil"
|
||||
stripe "github.com/stripe/stripe-go"
|
||||
"github.com/stripe/stripe-go/customer"
|
||||
"github.com/stripe/stripe-go/event"
|
||||
"github.com/stripe/stripe-go/invoice"
|
||||
"github.com/stripe/stripe-go/plan"
|
||||
"github.com/stripe/stripe-go/sub"
|
||||
)
|
||||
|
||||
func init() {
|
||||
graphqlbackend.ProductSubscriptionByID = func(ctx context.Context, id graphql.ID) (graphqlbackend.ProductSubscription, error) {
|
||||
return productSubscriptionByID(ctx, id)
|
||||
}
|
||||
}
|
||||
|
||||
// productSubscription implements the GraphQL type ProductSubscription.
|
||||
type productSubscription struct {
|
||||
v *dbSubscription
|
||||
}
|
||||
|
||||
// productSubscriptionByID looks up and returns the ProductSubscription with the given GraphQL
|
||||
// ID. If no such ProductSubscription exists, it returns a non-nil error.
|
||||
func productSubscriptionByID(ctx context.Context, id graphql.ID) (*productSubscription, error) {
|
||||
idString, err := unmarshalProductSubscriptionID(id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return productSubscriptionByDBID(ctx, idString)
|
||||
}
|
||||
|
||||
// productSubscriptionByDBID looks up and returns the ProductSubscription with the given database
|
||||
// ID. If no such ProductSubscription exists, it returns a non-nil error.
|
||||
func productSubscriptionByDBID(ctx context.Context, id string) (*productSubscription, error) {
|
||||
v, err := dbSubscriptions{}.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// 🚨 SECURITY: Only site admins and the subscription account's user may view a product subscription.
|
||||
if err := backend.CheckSiteAdminOrSameUser(ctx, v.UserID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &productSubscription{v: v}, nil
|
||||
}
|
||||
|
||||
func (r *productSubscription) ID() graphql.ID {
|
||||
return marshalProductSubscriptionID(r.v.ID)
|
||||
}
|
||||
|
||||
func marshalProductSubscriptionID(id string) graphql.ID {
|
||||
return relay.MarshalID("ProductSubscription", id)
|
||||
}
|
||||
|
||||
func unmarshalProductSubscriptionID(id graphql.ID) (productSubscriptionID string, err error) {
|
||||
err = relay.UnmarshalSpec(id, &productSubscriptionID)
|
||||
return
|
||||
}
|
||||
|
||||
func (r *productSubscription) UUID() string {
|
||||
return r.v.ID
|
||||
}
|
||||
|
||||
func (r *productSubscription) Name() string {
|
||||
return fmt.Sprintf("L-%s", strings.ToUpper(strings.Replace(r.v.ID, "-", "", -1)[:10]))
|
||||
}
|
||||
|
||||
func (r *productSubscription) Account(ctx context.Context) (*graphqlbackend.UserResolver, error) {
|
||||
return graphqlbackend.UserByIDInt32(ctx, r.v.UserID)
|
||||
}
|
||||
|
||||
func (r *productSubscription) Events(ctx context.Context) ([]graphqlbackend.ProductSubscriptionEvent, error) {
|
||||
if r.v.BillingSubscriptionID == nil {
|
||||
return []graphqlbackend.ProductSubscriptionEvent{}, nil
|
||||
}
|
||||
|
||||
// List all events related to this subscription. The related_object parameter is an undocumented
|
||||
// Stripe API.
|
||||
params := &stripe.EventListParams{
|
||||
ListParams: stripe.ListParams{Context: ctx},
|
||||
}
|
||||
params.Filters.AddFilter("related_object", "", *r.v.BillingSubscriptionID)
|
||||
events := event.List(params)
|
||||
var gqlEvents []graphqlbackend.ProductSubscriptionEvent
|
||||
for events.Next() {
|
||||
gqlEvent, okToShowUser := billing.ToProductSubscriptionEvent(events.Event())
|
||||
if okToShowUser {
|
||||
gqlEvents = append(gqlEvents, gqlEvent)
|
||||
}
|
||||
}
|
||||
if err := events.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return gqlEvents, nil
|
||||
}
|
||||
|
||||
func (r *productSubscription) ActiveLicense(ctx context.Context) (graphqlbackend.ProductLicense, error) {
|
||||
// Return newest license.
|
||||
licenses, err := dbLicenses{}.List(ctx, dbLicensesListOptions{
|
||||
ProductSubscriptionID: r.v.ID,
|
||||
LimitOffset: &db_.LimitOffset{Limit: 1},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(licenses) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return &productLicense{v: licenses[0]}, nil
|
||||
}
|
||||
|
||||
func (r *productSubscription) ProductLicenses(ctx context.Context, args *graphqlutil.ConnectionArgs) (graphqlbackend.ProductLicenseConnection, error) {
|
||||
// 🚨 SECURITY: Only site admins may list historical product licenses (to reduce confusion
|
||||
// around old license reuse). Other viewers should use ProductSubscription.activeLicense.
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
opt := dbLicensesListOptions{ProductSubscriptionID: r.v.ID}
|
||||
args.Set(&opt.LimitOffset)
|
||||
return &productLicenseConnection{opt: opt}, nil
|
||||
}
|
||||
|
||||
func (r *productSubscription) CreatedAt() string {
|
||||
return r.v.CreatedAt.Format(time.RFC3339)
|
||||
}
|
||||
|
||||
func (r *productSubscription) IsArchived() bool { return r.v.ArchivedAt != nil }
|
||||
|
||||
func (r *productSubscription) URL(ctx context.Context) (string, error) {
|
||||
accountUser, err := r.Account(ctx)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return accountUser.URL() + "/subscriptions/" + string(r.v.ID), nil
|
||||
}
|
||||
|
||||
func (r *productSubscription) URLForSiteAdmin(ctx context.Context) *string {
|
||||
// 🚨 SECURITY: Only site admins may see this URL. Currently it does not contain any sensitive
|
||||
// info, but there is no need to show it to non-site admins.
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil
|
||||
}
|
||||
u := fmt.Sprintf("/site-admin/dotcom/product/subscriptions/%s", r.v.ID)
|
||||
return &u
|
||||
}
|
||||
|
||||
func (r *productSubscription) URLForSiteAdminBilling(ctx context.Context) (*string, error) {
|
||||
// 🚨 SECURITY: Only site admins may see this URL, which might contain the subscription's billing ID.
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if id := r.v.BillingSubscriptionID; id != nil {
|
||||
u := billing.SubscriptionURL(*id)
|
||||
return &u, nil
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (ProductSubscriptionLicensingResolver) CreateProductSubscription(ctx context.Context, args *graphqlbackend.CreateProductSubscriptionArgs) (graphqlbackend.ProductSubscription, error) {
|
||||
// 🚨 SECURITY: Only site admins may create product subscriptions.
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
user, err := graphqlbackend.UserByID(ctx, args.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
id, err := dbSubscriptions{}.Create(ctx, user.SourcegraphID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return productSubscriptionByDBID(ctx, id)
|
||||
}
|
||||
|
||||
func (ProductSubscriptionLicensingResolver) SetProductSubscriptionBilling(ctx context.Context, args *graphqlbackend.SetProductSubscriptionBillingArgs) (*graphqlbackend.EmptyResponse, error) {
|
||||
// 🚨 SECURITY: Only site admins may update product subscriptions.
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Ensure the args refer to valid subscriptions in the database and in the billing system.
|
||||
dbSub, err := productSubscriptionByID(ctx, args.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if args.BillingSubscriptionID != nil {
|
||||
if _, err := sub.Get(*args.BillingSubscriptionID, &stripe.SubscriptionParams{Params: stripe.Params{Context: ctx}}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
stringValue := func(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
||||
if err := (dbSubscriptions{}).Update(ctx, dbSub.v.ID, dbSubscriptionUpdate{
|
||||
billingSubscriptionID: &sql.NullString{
|
||||
String: stringValue(args.BillingSubscriptionID),
|
||||
Valid: args.BillingSubscriptionID != nil,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &graphqlbackend.EmptyResponse{}, nil
|
||||
}
|
||||
|
||||
func (ProductSubscriptionLicensingResolver) CreatePaidProductSubscription(ctx context.Context, args *graphqlbackend.CreatePaidProductSubscriptionArgs) (*graphqlbackend.CreatePaidProductSubscriptionResult, error) {
|
||||
user, err := graphqlbackend.UserByID(ctx, args.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 🚨 SECURITY: Users may only create paid product subscriptions for themselves. Site admins may
|
||||
// create them for any user.
|
||||
if err := backend.CheckSiteAdminOrSameUser(ctx, user.SourcegraphID()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine which license tags and min quantity to use for the purchased plan. Do this early on
|
||||
// because it's the most likely place for a stupid mistake to cause a bug, and doing it early
|
||||
// means the user hasn't been charged if there is an error.
|
||||
licenseTags, minQuantity, err := billing.InfoForProductPlan(ctx, args.ProductSubscription.BillingPlanID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if minQuantity != nil && args.ProductSubscription.UserCount < *minQuantity {
|
||||
args.ProductSubscription.UserCount = *minQuantity
|
||||
}
|
||||
|
||||
// Create the subscription in our database first, before processing payment. If payment fails,
|
||||
// users can retry payment on the already created subscription.
|
||||
subID, err := dbSubscriptions{}.Create(ctx, user.SourcegraphID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Get the billing customer for the current user, and update it to use the payment source
|
||||
// provided to us.
|
||||
custID, err := billing.GetOrAssignUserCustomerID(ctx, user.SourcegraphID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
custUpdateParams := &stripe.CustomerParams{
|
||||
Params: stripe.Params{Context: ctx},
|
||||
}
|
||||
custUpdateParams.SetSource(args.PaymentToken)
|
||||
if _, err := customer.Update(custID, custUpdateParams); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Create the billing subscription.
|
||||
billingSub, err := sub.New(&stripe.SubscriptionParams{
|
||||
Params: stripe.Params{Context: ctx},
|
||||
Customer: stripe.String(custID),
|
||||
Items: []*stripe.SubscriptionItemsParams{billing.ToSubscriptionItemsParams(args.ProductSubscription)},
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Link the billing subscription with the subscription in our database.
|
||||
if err := (dbSubscriptions{}).Update(ctx, subID, dbSubscriptionUpdate{
|
||||
billingSubscriptionID: &sql.NullString{
|
||||
String: billingSub.ID,
|
||||
Valid: true,
|
||||
},
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate a new license key for the subscription.
|
||||
if _, err := generateProductLicenseForSubscription(ctx, subID, &graphqlbackend.ProductLicenseInput{
|
||||
Tags: licenseTags,
|
||||
UserCount: args.ProductSubscription.UserCount,
|
||||
ExpiresAt: int32(billingSub.CurrentPeriodEnd),
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sub, err := productSubscriptionByDBID(ctx, subID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &graphqlbackend.CreatePaidProductSubscriptionResult{ProductSubscriptionValue: sub}, nil
|
||||
}
|
||||
|
||||
func (ProductSubscriptionLicensingResolver) UpdatePaidProductSubscription(ctx context.Context, args *graphqlbackend.UpdatePaidProductSubscriptionArgs) (*graphqlbackend.UpdatePaidProductSubscriptionResult, error) {
|
||||
subToUpdate, err := productSubscriptionByID(ctx, args.SubscriptionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// 🚨 SECURITY: Only site admins and the subscription's account owner may update product
|
||||
// subscriptions.
|
||||
if err := backend.CheckSiteAdminOrSameUser(ctx, subToUpdate.v.UserID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Determine which license tags and min quantity to use for the purchased plan. Do this early on
|
||||
// because it's the most likely place for a stupid mistake to cause a bug, and doing it early
|
||||
// means the user hasn't been charged if there is an error.
|
||||
licenseTags, minQuantity, err := billing.InfoForProductPlan(ctx, args.Update.BillingPlanID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if minQuantity != nil && args.Update.UserCount < *minQuantity {
|
||||
args.Update.UserCount = *minQuantity
|
||||
}
|
||||
|
||||
params := &stripe.SubscriptionParams{
|
||||
Params: stripe.Params{Context: ctx},
|
||||
Items: []*stripe.SubscriptionItemsParams{billing.ToSubscriptionItemsParams(args.Update)},
|
||||
Prorate: stripe.Bool(true),
|
||||
}
|
||||
|
||||
// Get the billing customer for the current user, and update it to use the payment source
|
||||
// provided to us.
|
||||
custID, err := billing.GetOrAssignUserCustomerID(ctx, subToUpdate.v.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
custUpdateParams := &stripe.CustomerParams{
|
||||
Params: stripe.Params{Context: ctx},
|
||||
}
|
||||
custUpdateParams.SetSource(args.PaymentToken)
|
||||
if _, err := customer.Update(custID, custUpdateParams); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if subToUpdate.v.BillingSubscriptionID == nil {
|
||||
return nil, errors.New("unable to update product subscription that has no associated billing information")
|
||||
}
|
||||
subParams := &stripe.SubscriptionParams{Params: stripe.Params{Context: ctx}}
|
||||
subParams.AddExpand("plan")
|
||||
billingSubToUpdate, err := sub.Get(*subToUpdate.v.BillingSubscriptionID, subParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idToReplace, err := billing.GetSubscriptionItemIDToReplace(billingSubToUpdate, custID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
params.Items[0].ID = stripe.String(idToReplace)
|
||||
|
||||
// Forbid self-service downgrades. (Reason: We can't revoke licenses, so we want to manually
|
||||
// intervene to ensure that customers who downgrade are not using the previous license.)
|
||||
{
|
||||
planParams := &stripe.PlanParams{Params: stripe.Params{Context: ctx}}
|
||||
afterPlan, err := plan.Get(args.Update.BillingPlanID, planParams)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if isDowngradeRequiringManualIntervention(int32(billingSubToUpdate.Quantity), billingSubToUpdate.Plan.Amount, args.Update.UserCount, afterPlan.Amount) {
|
||||
return nil, errors.New("self-service downgrades are not yet supported")
|
||||
}
|
||||
}
|
||||
|
||||
// Update the billing subscription.
|
||||
billingSub, err := sub.Update(*subToUpdate.v.BillingSubscriptionID, params)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Generate an invoice and charge so that payment is performed immediately. See
|
||||
// https://stripe.com/docs/billing/subscriptions/upgrading-downgrading.
|
||||
//
|
||||
// TODO(sqs): use webhooks to ensure the subscription is rolled back if the invoice payment
|
||||
// fails.
|
||||
{
|
||||
inv, err := invoice.New(&stripe.InvoiceParams{
|
||||
Params: stripe.Params{Context: ctx},
|
||||
Customer: stripe.String(custID),
|
||||
Subscription: stripe.String(*subToUpdate.v.BillingSubscriptionID),
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := invoice.Pay(inv.ID, &stripe.InvoicePayParams{
|
||||
Params: stripe.Params{Context: ctx},
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// Generate a new license key for the subscription with the updated parameters.
|
||||
if _, err := generateProductLicenseForSubscription(ctx, subToUpdate.v.ID, &graphqlbackend.ProductLicenseInput{
|
||||
Tags: licenseTags,
|
||||
UserCount: args.Update.UserCount,
|
||||
ExpiresAt: int32(billingSub.CurrentPeriodEnd),
|
||||
}); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &graphqlbackend.UpdatePaidProductSubscriptionResult{ProductSubscriptionValue: subToUpdate}, nil
|
||||
}
|
||||
|
||||
func (ProductSubscriptionLicensingResolver) ArchiveProductSubscription(ctx context.Context, args *graphqlbackend.ArchiveProductSubscriptionArgs) (*graphqlbackend.EmptyResponse, error) {
|
||||
// 🚨 SECURITY: Only site admins may archive product subscriptions.
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sub, err := productSubscriptionByID(ctx, args.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := (dbSubscriptions{}).Archive(ctx, sub.v.ID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &graphqlbackend.EmptyResponse{}, nil
|
||||
}
|
||||
|
||||
func (ProductSubscriptionLicensingResolver) ProductSubscription(ctx context.Context, args *graphqlbackend.ProductSubscriptionArgs) (graphqlbackend.ProductSubscription, error) {
|
||||
// 🚨 SECURITY: Only site admins and the subscription's account owner may get a product
|
||||
// subscription. This check is performed in productSubscriptionByDBID.
|
||||
return productSubscriptionByDBID(ctx, args.UUID)
|
||||
}
|
||||
|
||||
func (ProductSubscriptionLicensingResolver) ProductSubscriptions(ctx context.Context, args *graphqlbackend.ProductSubscriptionsArgs) (graphqlbackend.ProductSubscriptionConnection, error) {
|
||||
var accountUser *graphqlbackend.UserResolver
|
||||
if args.Account != nil {
|
||||
var err error
|
||||
accountUser, err = graphqlbackend.UserByID(ctx, *args.Account)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
// 🚨 SECURITY: Users may only list their own product subscriptions. Site admins may list
|
||||
// licenses for all users, or for any other user.
|
||||
if accountUser == nil {
|
||||
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
if err := backend.CheckSiteAdminOrSameUser(ctx, accountUser.SourcegraphID()); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var opt dbSubscriptionsListOptions
|
||||
if accountUser != nil {
|
||||
opt.UserID = accountUser.SourcegraphID()
|
||||
}
|
||||
args.ConnectionArgs.Set(&opt.LimitOffset)
|
||||
return &productSubscriptionConnection{opt: opt}, nil
|
||||
}
|
||||
|
||||
// productSubscriptionConnection implements the GraphQL type ProductSubscriptionConnection.
|
||||
//
|
||||
// 🚨 SECURITY: When instantiating a productSubscriptionConnection value, the caller MUST
|
||||
// check permissions.
|
||||
type productSubscriptionConnection struct {
|
||||
opt dbSubscriptionsListOptions
|
||||
|
||||
// cache results because they are used by multiple fields
|
||||
once sync.Once
|
||||
results []*dbSubscription
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *productSubscriptionConnection) compute(ctx context.Context) ([]*dbSubscription, error) {
|
||||
r.once.Do(func() {
|
||||
opt2 := r.opt
|
||||
if opt2.LimitOffset != nil {
|
||||
tmp := *opt2.LimitOffset
|
||||
opt2.LimitOffset = &tmp
|
||||
opt2.Limit++ // so we can detect if there is a next page
|
||||
}
|
||||
|
||||
r.results, r.err = dbSubscriptions{}.List(ctx, opt2)
|
||||
})
|
||||
return r.results, r.err
|
||||
}
|
||||
|
||||
func (r *productSubscriptionConnection) Nodes(ctx context.Context) ([]graphqlbackend.ProductSubscription, error) {
|
||||
results, err := r.compute(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var l []graphqlbackend.ProductSubscription
|
||||
for _, result := range results {
|
||||
l = append(l, &productSubscription{v: result})
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (r *productSubscriptionConnection) TotalCount(ctx context.Context) (int32, error) {
|
||||
count, err := dbSubscriptions{}.Count(ctx, r.opt)
|
||||
return int32(count), err
|
||||
}
|
||||
|
||||
func (r *productSubscriptionConnection) PageInfo(ctx context.Context) (*graphqlutil.PageInfo, error) {
|
||||
results, err := r.compute(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return graphqlutil.HasNextPage(r.opt.LimitOffset != nil && len(results) > r.opt.Limit), nil
|
||||
}
|
||||
18
enterprise/cmd/frontend/internal/graphqlbackend/dotcom.go
Normal file
18
enterprise/cmd/frontend/internal/graphqlbackend/dotcom.go
Normal file
@ -0,0 +1,18 @@
|
||||
package graphqlbackend
|
||||
|
||||
import (
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/dotcom/billing"
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/dotcom/productsubscription"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
)
|
||||
|
||||
func init() {
|
||||
// Contribute the GraphQL types DotcomMutation and DotcomQuery.
|
||||
graphqlbackend.Dotcom = dotcomResolver{}
|
||||
}
|
||||
|
||||
// dotcomResolver implements the GraphQL types DotcomMutation and DotcomQuery.
|
||||
type dotcomResolver struct {
|
||||
productsubscription.ProductSubscriptionLicensingResolver
|
||||
billing.BillingResolver
|
||||
}
|
||||
3
enterprise/cmd/frontend/internal/httpapi/doc.go
Normal file
3
enterprise/cmd/frontend/internal/httpapi/doc.go
Normal file
@ -0,0 +1,3 @@
|
||||
// Package httpapi should be imported for side-effects to register enterprise-specific hooks into
|
||||
// corresponding httpapi package in the open-source repository.
|
||||
package httpapi
|
||||
16
enterprise/cmd/frontend/internal/httpapi/register.go
Normal file
16
enterprise/cmd/frontend/internal/httpapi/register.go
Normal file
@ -0,0 +1,16 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/httpapi"
|
||||
"github.com/sourcegraph/sourcegraph/xlang"
|
||||
)
|
||||
|
||||
func init() {
|
||||
httpapi.XLangNewClient = func() (httpapi.XLangClient, error) {
|
||||
c, err := xlang.UnsafeNewDefaultClient()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &xclient{Client: c}, nil
|
||||
}
|
||||
}
|
||||
394
enterprise/cmd/frontend/internal/httpapi/xclient.go
Normal file
394
enterprise/cmd/frontend/internal/httpapi/xclient.go
Normal file
@ -0,0 +1,394 @@
|
||||
package httpapi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/url"
|
||||
|
||||
opentracing "github.com/opentracing/opentracing-go"
|
||||
"github.com/opentracing/opentracing-go/ext"
|
||||
otlog "github.com/opentracing/opentracing-go/log"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/go-langserver/pkg/lsp"
|
||||
"github.com/sourcegraph/go-langserver/pkg/lspext"
|
||||
vcsurl "github.com/sourcegraph/go-vcsurl"
|
||||
"github.com/sourcegraph/jsonrpc2"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/api"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/errcode"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/vcs"
|
||||
"github.com/sourcegraph/sourcegraph/xlang"
|
||||
xlang_lspext "github.com/sourcegraph/sourcegraph/xlang/lspext"
|
||||
)
|
||||
|
||||
// xclient is an LSP client that transparently wraps xlang.Client,
|
||||
// except that it translates textDocument/definition requests into a
|
||||
// series of requests that computes the cross-repo jump-to-definition
|
||||
// result.
|
||||
type xclient struct {
|
||||
*xlang.Client
|
||||
|
||||
hasXDefinitionAndXPackages bool
|
||||
hasCrossRepoHover bool
|
||||
mode string
|
||||
}
|
||||
|
||||
// Call transparently wraps xlang.Client.Call *except* for `textDocument/definition` if the language
|
||||
// server is a textDocument/xdefinition provider. In that case, this method invokes
|
||||
// `textDocument/xdefinition` instead. If the result contains a non-zero `Location` field, then that
|
||||
// is returned to the client as if it came from `textDocument/definition`. If the location is zero,
|
||||
// then that means the definition did not exist locally. The method will locate the definition in an
|
||||
// external repository and return that to the client as if it came from a single
|
||||
// `textDocument/definition` call.
|
||||
//
|
||||
// SECURITY NOTE: Call also verifies permissions for cross-repo jumps. Any changes to this method
|
||||
// should preserve this property.
|
||||
func (c *xclient) Call(ctx context.Context, method string, params, result interface{}, opt ...jsonrpc2.CallOption) (err error) {
|
||||
span, ctx := opentracing.StartSpanFromContext(ctx, "xclient.Call")
|
||||
defer func() {
|
||||
if err != nil {
|
||||
ext.Error.Set(span, true)
|
||||
span.SetTag("err", err.Error())
|
||||
}
|
||||
span.Finish()
|
||||
}()
|
||||
span.SetTag("Method", method)
|
||||
|
||||
// marshalResult takes an existing interface and marshals it into result
|
||||
// via JSON.
|
||||
marshalResult := func(v interface{}, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
b, err := json.Marshal(v)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "marshaling result")
|
||||
}
|
||||
return json.Unmarshal(b, result)
|
||||
}
|
||||
|
||||
switch {
|
||||
case method == "initialize":
|
||||
var init xlang_lspext.ClientProxyInitializeParams
|
||||
if err := json.Unmarshal(*params.(*json.RawMessage), &init); err != nil {
|
||||
return err
|
||||
}
|
||||
c.mode = init.InitializationOptions.Mode
|
||||
if c.mode == "" {
|
||||
// DEPRECATED: Use old Mode field if the new one is not set.
|
||||
c.mode = init.Mode
|
||||
}
|
||||
var resultRaw json.RawMessage
|
||||
if err := c.Client.Call(ctx, method, params, &resultRaw, opt...); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// We only care about the XDefinitionProvider. Right now it implies
|
||||
// the support of XPackages as well :'(
|
||||
var initResultSubset struct {
|
||||
Capabilities struct {
|
||||
// XDefinitionProvider indicates the server provides support for
|
||||
// textDocument/xdefinition. This is a Sourcegraph extension.
|
||||
XDefinitionProvider bool `json:"xdefinitionProvider,omitempty"`
|
||||
} `json:"capabilities,omitempty"`
|
||||
}
|
||||
if err := json.Unmarshal(resultRaw, &initResultSubset); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.hasXDefinitionAndXPackages = initResultSubset.Capabilities.XDefinitionProvider
|
||||
//_, c.hasXDefinitionAndXPackages = xlang.HasXDefinitionAndXPackages[c.mode]
|
||||
_, c.hasCrossRepoHover = xlang.HasCrossRepoHover[c.mode]
|
||||
|
||||
return json.Unmarshal(resultRaw, result)
|
||||
|
||||
case !c.hasXDefinitionAndXPackages:
|
||||
break
|
||||
case method == "textDocument/definition":
|
||||
span.SetTag("LocationAbsent", "true")
|
||||
return marshalResult(c.jumpToDefCrossRepo(ctx, params, opt...))
|
||||
case method == "textDocument/hover" && c.hasCrossRepoHover:
|
||||
return marshalResult(c.hoverCrossRepo(ctx, params, opt...))
|
||||
|
||||
case method == "xsymbol/hover":
|
||||
// Federation. This will only run on sourcegraph.com
|
||||
var syms []lspext.SymbolLocationInformation
|
||||
if err := json.Unmarshal(*params.(*json.RawMessage), &syms); err != nil {
|
||||
return err
|
||||
}
|
||||
return marshalResult(c.symbolHover(ctx, syms))
|
||||
|
||||
case method == "xsymbol/definition":
|
||||
// Federation. This will only run on sourcegraph.com
|
||||
var syms []lspext.SymbolLocationInformation
|
||||
if err := json.Unmarshal(*params.(*json.RawMessage), &syms); err != nil {
|
||||
return err
|
||||
}
|
||||
return marshalResult(c.symbolDefinition(ctx, syms))
|
||||
}
|
||||
return c.Client.Call(ctx, method, params, result, opt...)
|
||||
}
|
||||
|
||||
func (c *xclient) Notify(ctx context.Context, method string, params interface{}, opt ...jsonrpc2.CallOption) error {
|
||||
return c.Client.Notify(ctx, method, params, opt...)
|
||||
}
|
||||
|
||||
func (c *xclient) Close() error {
|
||||
return c.Client.Close()
|
||||
}
|
||||
|
||||
func (c *xclient) xdefQuery(ctx context.Context, syms []lspext.SymbolLocationInformation) (map[lsp.DocumentURI][]lsp.SymbolInformation, error) {
|
||||
span := opentracing.SpanFromContext(ctx)
|
||||
|
||||
symInfos := make(map[lsp.DocumentURI][]lsp.SymbolInformation)
|
||||
// For each symbol in the xdefinition-result-derived query, compute the symbol information for that symbol
|
||||
for _, sym := range syms {
|
||||
|
||||
var rootURIs []lsp.DocumentURI
|
||||
// If we can extract the repository URL from the symbol metadata, do so
|
||||
if repoURL := xlang.SymbolRepoURL(sym.Symbol); repoURL != "" {
|
||||
span.LogFields(otlog.String("event", "extracted repo directly from symbol metadata"),
|
||||
otlog.String("repoURL", repoURL))
|
||||
|
||||
repoInfo, err := vcsurl.Parse(repoURL)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "extract repo URL from symbol metadata")
|
||||
}
|
||||
repoURI := api.RepoURI(string(repoInfo.RepoHost) + "/" + repoInfo.FullName)
|
||||
|
||||
// We issue a workspace/symbols on the URL, so ensure we have the repo / it exists.
|
||||
repo, err := backend.Repos.GetByURI(ctx, repoURI)
|
||||
if err != nil {
|
||||
span.LogFields(otlog.Error(err))
|
||||
if _, isSeeOther := err.(backend.ErrRepoSeeOther); isSeeOther || errcode.IsNotFound(err) {
|
||||
span.LogFields(otlog.String("event", "ignoring not found error"))
|
||||
continue
|
||||
}
|
||||
return nil, errors.Wrap(err, "extract repo URL from symbol metadata")
|
||||
}
|
||||
rev, err := backend.Repos.ResolveRev(ctx, repo, "")
|
||||
if err != nil {
|
||||
span.LogFields(otlog.Error(err))
|
||||
if vcs.IsRepoNotExist(err) {
|
||||
span.LogFields(otlog.String("event", "ignoring not found error"))
|
||||
continue
|
||||
}
|
||||
return nil, errors.Wrap(err, "extract repo URL from symbol metadata")
|
||||
}
|
||||
rootURIs = append(rootURIs, lsp.DocumentURI(string(repoInfo.VCS)+"://"+string(repoURI)+"?"+string(rev)))
|
||||
} else { // if we can't extract the repository URL directly, we have to consult the pkgs database
|
||||
pkgDescriptor, ok := xlang.SymbolPackageDescriptor(sym.Symbol, c.mode)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
span.LogFields(otlog.String("event", "cross-repo jump to def"))
|
||||
pkgs, err := db.Pkgs.ListPackages(ctx, &api.ListPackagesOp{PkgQuery: pkgDescriptor, Lang: c.mode, Limit: 1})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "getting repo by package db query")
|
||||
}
|
||||
span.LogFields(otlog.String("event", "listed repository packages"))
|
||||
for _, pkg := range pkgs {
|
||||
repo, err := backend.Repos.Get(ctx, pkg.RepoID)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "fetch repo for package")
|
||||
}
|
||||
var commit api.CommitID
|
||||
if repo.IndexedRevision != nil {
|
||||
commit = *repo.IndexedRevision
|
||||
} else {
|
||||
var err error
|
||||
commit, err = backend.Repos.ResolveRev(ctx, repo, "")
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "resolve revision for package repo")
|
||||
}
|
||||
}
|
||||
rootURIs = append(rootURIs, lsp.DocumentURI("git://"+string(repo.URI)+"?"+string(commit)))
|
||||
}
|
||||
span.LogFields(otlog.String("event", "resolved rootURIs"))
|
||||
}
|
||||
|
||||
// Issue a workspace/symbol for each repository that provides a definition for the symbol
|
||||
for _, rootURI := range rootURIs {
|
||||
params := &lspext.WorkspaceSymbolParams{Symbol: sym.Symbol, Limit: 10}
|
||||
var repoSymInfos []lsp.SymbolInformation
|
||||
if err := xlang.UnsafeOneShotClientRequest(ctx, c.mode, rootURI, "workspace/symbol", params, &repoSymInfos); err != nil {
|
||||
return nil, errors.Wrap(err, "resolving symbol to location")
|
||||
}
|
||||
symInfos[rootURI] = repoSymInfos
|
||||
}
|
||||
span.LogFields(otlog.String("event", "done issuing workspace/symbol requests"))
|
||||
}
|
||||
return symInfos, nil
|
||||
}
|
||||
|
||||
func (c *xclient) jumpToDefCrossRepo(ctx context.Context, params interface{}, opt ...jsonrpc2.CallOption) ([]lsp.Location, error) {
|
||||
// Issue xdefinition request
|
||||
var syms []lspext.SymbolLocationInformation
|
||||
if err := c.Client.Call(ctx, "textDocument/xdefinition", params, &syms, opt...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
locs := make([]lsp.Location, 0, len(syms))
|
||||
|
||||
var nolocSyms []lspext.SymbolLocationInformation
|
||||
for _, sym := range syms {
|
||||
// If a concrete location is already present, just use that
|
||||
if sym.Location != (lsp.Location{}) {
|
||||
locs = append(locs, sym.Location)
|
||||
} else {
|
||||
nolocSyms = append(nolocSyms, sym)
|
||||
}
|
||||
}
|
||||
|
||||
symLocs, err := c.symbolDefinition(ctx, nolocSyms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
locs = append(locs, symLocs...)
|
||||
|
||||
// Failed to find the definition locally, try symbolDefinition on Sourcegraph.com
|
||||
// which may have indexed the OSS repo used.
|
||||
if len(locs) == 0 && len(nolocSyms) > 0 && conf.JumpToDefOSSIndexEnabled() {
|
||||
// HACK we need a valid rootURI, even though we are doing symbol queries.
|
||||
rootURI := lsp.DocumentURI("git://github.com/gorilla/mux?4dbd923b0c9e99ff63ad54b0e9705ff92d3cdb06")
|
||||
err := xlang.RemoteOneShotClientRequest(ctx, sourcegraphDotComBaseURL, c.mode, rootURI, "xsymbol/definition", nolocSyms, &locs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return locs, nil
|
||||
}
|
||||
|
||||
return locs, nil
|
||||
}
|
||||
|
||||
func (c *xclient) symbolDefinition(ctx context.Context, syms []lspext.SymbolLocationInformation) ([]lsp.Location, error) {
|
||||
symInfos, err := c.xdefQuery(ctx, syms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
locs := make([]lsp.Location, 0, len(symInfos))
|
||||
for _, repoSymInfos := range symInfos {
|
||||
for _, symInfo := range repoSymInfos {
|
||||
locs = append(locs, symInfo.Location)
|
||||
}
|
||||
}
|
||||
return locs, nil
|
||||
}
|
||||
|
||||
// hoverCrossRepo translates hover requests in the current repository to a
|
||||
// hover request on the definition in the definition's repository.
|
||||
//
|
||||
// Algorithm:
|
||||
//
|
||||
// 1. If we are hovering over a symbol in the current repository, use the
|
||||
// normal textDocument/hover.
|
||||
// 2. Use textDocument/xdefinition (sg extension) to retrieve symbol information.
|
||||
// 3. symbolHover: Using the symbols, use the first successful hover in the
|
||||
// definition repos.
|
||||
// 4. If we do not find a non-empty hover and federation is enabled, we send
|
||||
// the package query to Sourcegraph.com's xlang API. The assumption is the
|
||||
// dependency is an OSS package so we can consult our public index. If the
|
||||
// response is non-empty we return it.
|
||||
// 5. If we do not find a non-empty hover, fallback to the normal hover.
|
||||
func (c *xclient) hoverCrossRepo(ctx context.Context, params interface{}, opt ...jsonrpc2.CallOption) (*lsp.Hover, error) {
|
||||
// Note: we can't parallelize the hover and xdefinition requests
|
||||
// without breaking the request cancellation logic used by LSP
|
||||
// proxy
|
||||
|
||||
// xdefinition request
|
||||
var syms []lspext.SymbolLocationInformation
|
||||
if err := c.Client.Call(ctx, "textDocument/xdefinition", params, &syms, opt...); err != nil {
|
||||
return nil, errors.Wrap(err, "hoverCrossRepo: textDocument/xdefinition error")
|
||||
}
|
||||
|
||||
// hover request
|
||||
var hover lsp.Hover
|
||||
if err := c.Client.Call(ctx, "textDocument/hover", params, &hover, opt...); err != nil {
|
||||
return nil, errors.Wrap(err, "hoverCrossRepo: textDocument/hover error")
|
||||
}
|
||||
|
||||
// return local hover if local definition found
|
||||
for _, sym := range syms {
|
||||
if sym.Location != (lsp.Location{}) {
|
||||
return &hover, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Cross repo hover is done via the symbols only.
|
||||
xhov, err := c.symbolHover(ctx, syms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(xhov.Contents) > 0 {
|
||||
// Range is for the queried token, so we need to use the local
|
||||
// hover range.
|
||||
xhov.Range = hover.Range
|
||||
return xhov, nil
|
||||
}
|
||||
|
||||
// Failed to find the hover locally, try symbolHover on Sourcegraph.com
|
||||
// which may have indexed the OSS repo used.
|
||||
if conf.JumpToDefOSSIndexEnabled() {
|
||||
// HACK we need a valid rootURI, even though we are doing symbol queries.
|
||||
rootURI := lsp.DocumentURI("git://github.com/gorilla/mux?4dbd923b0c9e99ff63ad54b0e9705ff92d3cdb06")
|
||||
var remoteHov lsp.Hover
|
||||
err := xlang.RemoteOneShotClientRequest(ctx, sourcegraphDotComBaseURL, c.mode, rootURI, "xsymbol/hover", syms, &remoteHov)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(remoteHov.Contents) > 0 {
|
||||
// Range is for the queried token, so we need to use the local
|
||||
// hover range.
|
||||
remoteHov.Range = hover.Range
|
||||
return &remoteHov, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to local hover contents.
|
||||
return &hover, nil
|
||||
}
|
||||
|
||||
// symbolHover finds a hover contents for the given symbols.
|
||||
//
|
||||
// Algorithm:
|
||||
//
|
||||
// 1. xdefQuery: Consult our packages index to find potential repositories
|
||||
// containing the symbols.
|
||||
// 2. xdefQuery: For each potential repository use workspace/symbol with our
|
||||
// symbol query (sg extension).
|
||||
// 3. For each symbol do a textDocument/hover. The first non-empty hover
|
||||
// content we return to the user.
|
||||
func (c *xclient) symbolHover(ctx context.Context, syms []lspext.SymbolLocationInformation) (*lsp.Hover, error) {
|
||||
symInfos, err := c.xdefQuery(ctx, syms)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// return first hover found
|
||||
for rootURI, repoSymInfos := range symInfos {
|
||||
for _, symInfo := range repoSymInfos {
|
||||
pos := symInfo.Location.Range.Start
|
||||
pos.Character++
|
||||
p := lsp.TextDocumentPositionParams{
|
||||
TextDocument: lsp.TextDocumentIdentifier{URI: symInfo.Location.URI},
|
||||
Position: pos,
|
||||
}
|
||||
var xhov lsp.Hover
|
||||
if err := xlang.UnsafeOneShotClientRequest(ctx, c.mode, rootURI, "textDocument/hover", p, &xhov); err != nil {
|
||||
return nil, errors.Wrap(err, "hoverCrossRepo: external textDocument/hover error")
|
||||
}
|
||||
if len(xhov.Contents) > 0 {
|
||||
return &xhov, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// nothing found, so empty response
|
||||
return &lsp.Hover{}, nil
|
||||
}
|
||||
|
||||
var sourcegraphDotComBaseURL = &url.URL{Scheme: "https", Host: "sourcegraph.com"}
|
||||
3
enterprise/cmd/frontend/internal/licensing/doc.go
Normal file
3
enterprise/cmd/frontend/internal/licensing/doc.go
Normal file
@ -0,0 +1,3 @@
|
||||
// Package licensing handles parsing, verifying, and enforcing the product subscription (specified in
|
||||
// site configuration).
|
||||
package licensing
|
||||
125
enterprise/cmd/frontend/internal/licensing/enforcement.go
Normal file
125
enterprise/cmd/frontend/internal/licensing/enforcement.go
Normal file
@ -0,0 +1,125 @@
|
||||
package licensing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"html"
|
||||
"net/http"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/hooks"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/errcode"
|
||||
log15 "gopkg.in/inconshreveable/log15.v2"
|
||||
)
|
||||
|
||||
// Enforce the use of a valid license key by preventing all HTTP requests if the license is invalid
|
||||
// (due to a error in parsing or verification, or because the license has expired).
|
||||
func init() {
|
||||
hooks.PreAuthMiddleware = func(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
info, err := GetConfiguredProductLicenseInfo()
|
||||
if err != nil {
|
||||
log15.Error("Error reading license key for Sourcegraph subscription.", "err", err)
|
||||
WriteSubscriptionErrorResponse(w, http.StatusInternalServerError, "Error reading Sourcegraph license key", "Site admins may check the logs for more information.")
|
||||
return
|
||||
}
|
||||
if info != nil && info.IsExpiredWithGracePeriod() {
|
||||
WriteSubscriptionErrorResponse(w, http.StatusForbidden, "Sourcegraph license expired", "To continue using Sourcegraph, a site admin must renew the Sourcegraph license (or downgrade to only using Sourcegraph Core features).")
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// WriteSubscriptionErrorResponseForFeature is a wrapper around WriteSubscriptionErrorResponse that
|
||||
// generates the error title and message indicating that the current license does not active the
|
||||
// given feature.
|
||||
func WriteSubscriptionErrorResponseForFeature(w http.ResponseWriter, featureNameHumanReadable string) {
|
||||
WriteSubscriptionErrorResponse(
|
||||
w, http.StatusForbidden,
|
||||
fmt.Sprintf("License is not valid for %s", featureNameHumanReadable),
|
||||
fmt.Sprintf("To use the %s feature, a site admin must upgrade the Sourcegraph license. (The site admin may also remove the site configuration that enables this feature to dismiss this message.)", featureNameHumanReadable),
|
||||
)
|
||||
}
|
||||
|
||||
// WriteSubscriptionErrorResponse writes an HTTP response that displays a standalone error page to
|
||||
// the user.
|
||||
//
|
||||
// The title and message should be full sentences that describe the problem and how to fix it. Use
|
||||
// WriteSubscriptionErrorResponseForFeature to generate these for the common case of a failed
|
||||
// license feature check.
|
||||
func WriteSubscriptionErrorResponse(w http.ResponseWriter, statusCode int, title, message string) {
|
||||
w.WriteHeader(statusCode)
|
||||
w.Header().Set("Cache-Control", "no-cache")
|
||||
w.Header().Set("Content-Type", "text/html; charset=utf-8")
|
||||
// Inline all styles and resources because those requests will fail (our middleware
|
||||
// intercepts all HTTP requests).
|
||||
fmt.Fprintln(w, `
|
||||
<title>`+html.EscapeString(title)+` - Sourcegraph</title>
|
||||
<style>
|
||||
.bg {
|
||||
position: absolute;
|
||||
user-select: none;
|
||||
pointer-events: none;
|
||||
z-index: -1;
|
||||
top: 0;
|
||||
bottom: 0;
|
||||
left: 0;
|
||||
right: 0;
|
||||
/* The Sourcegraph logo in SVG. */
|
||||
background-image: url('data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 124 127"><g fill="none" fill-rule="evenodd"><path fill="%23F96316" d="M35.942 16.276L63.528 117.12c1.854 6.777 8.85 10.768 15.623 8.912 6.778-1.856 10.765-8.854 8.91-15.63L60.47 9.555C58.615 2.78 51.62-1.212 44.847.645c-6.772 1.853-10.76 8.853-8.905 15.63z"/><path fill="%23B200F8" d="M87.024 15.894L17.944 93.9c-4.66 5.26-4.173 13.303 1.082 17.964 5.255 4.66 13.29 4.174 17.95-1.084l69.08-78.005c4.66-5.26 4.173-13.3-1.082-17.962-5.257-4.664-13.294-4.177-17.95 1.08z"/><path fill="%2300B4F2" d="M8.75 59.12l98.516 32.595c6.667 2.205 13.86-1.414 16.065-8.087 2.21-6.672-1.41-13.868-8.08-16.076L16.738 34.96c-6.67-2.207-13.86 1.412-16.066 8.085-2.204 6.672 1.416 13.87 8.08 16.075z"/></g></svg>');
|
||||
background-repeat: repeat;
|
||||
background-size: 5rem;
|
||||
opacity: 0.1;
|
||||
}
|
||||
|
||||
.msg {
|
||||
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
|
||||
max-width: 30rem;
|
||||
margin: 20vh auto 0;
|
||||
border: solid 2px rgba(255,0,0,0.8);
|
||||
background-color: rgba(255,255,255,0.8);
|
||||
color: rgb(30, 0, 0);
|
||||
padding: 1rem 2rem;
|
||||
}
|
||||
|
||||
h1 {
|
||||
font-size: 1.5rem;
|
||||
}
|
||||
</style>
|
||||
<meta name="robots" content="noindex">
|
||||
<body>
|
||||
<div class=bg></div>
|
||||
<div class=msg><h1>`+html.EscapeString(title)+`</h1><p>`+html.EscapeString(message)+`</p><p>See <a href="https://about.sourcegraph.com/pricing">about.sourcegraph.com</a> for more information.</p></div>`)
|
||||
}
|
||||
|
||||
// Enforce the license's max user count by preventing the creation of new users when the max is
|
||||
// reached.
|
||||
func init() {
|
||||
db.Users.PreCreateUser = func(ctx context.Context) error {
|
||||
info, err := GetConfiguredProductLicenseInfo()
|
||||
if info == nil || err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Block creation of a new user beyond the licensed user count (unless true-up is allowed).
|
||||
userCount, err := db.Users.Count(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Be conservative and treat 0 as unlimited. We don't plan to intentionally generate
|
||||
// licenses with UserCount == 0, but that might result from a bug in license decoding, and
|
||||
// we don't want that to immediately disable Sourcegraph instances.
|
||||
if info.UserCount > 0 && userCount >= int(info.UserCount) {
|
||||
if info.HasTag(TrueUpUserCountTag) {
|
||||
log15.Info("Licensed user count exceeded, but license supports true-up and will not block creation of new user. The new user will be retroactively charged for in the next billing period. Contact sales@sourcegraph.com for help.", "activeUserCount", userCount, "licensedUserCount", info.UserCount)
|
||||
} else {
|
||||
return errcode.NewPresentationError("Unable to create user account: the Sourcegraph license's maximum user count has been reached. A site admin must upgrade the Sourcegraph subscription to allow for more users.")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,86 @@
|
||||
package licensing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/enterprise/pkg/license"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
)
|
||||
|
||||
func TestEnforcementPreCreateUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
license *license.Info
|
||||
activeUserCount uint
|
||||
wantErr bool
|
||||
}{
|
||||
// See the impl for why we treat UserCount == 0 as unlimited.
|
||||
{
|
||||
license: &license.Info{UserCount: 0},
|
||||
activeUserCount: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
|
||||
// Non-true-up licenses.
|
||||
{
|
||||
license: &license.Info{UserCount: 10},
|
||||
activeUserCount: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
license: &license.Info{UserCount: 10},
|
||||
activeUserCount: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
license: &license.Info{UserCount: 10},
|
||||
activeUserCount: 9,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
license: &license.Info{UserCount: 10},
|
||||
activeUserCount: 10,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
license: &license.Info{UserCount: 10},
|
||||
activeUserCount: 11,
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
license: &license.Info{UserCount: 10},
|
||||
activeUserCount: 12,
|
||||
wantErr: true,
|
||||
},
|
||||
|
||||
// True-up licenses.
|
||||
{
|
||||
license: &license.Info{Tags: []string{TrueUpUserCountTag}, UserCount: 10},
|
||||
activeUserCount: 5,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
license: &license.Info{Tags: []string{TrueUpUserCountTag}, UserCount: 10},
|
||||
activeUserCount: 15,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("license %s with %d active users", test.license, test.activeUserCount), func(t *testing.T) {
|
||||
MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
|
||||
return test.license, nil
|
||||
}
|
||||
defer func() { MockGetConfiguredProductLicenseInfo = nil }()
|
||||
db.Mocks.Users.Count = func(context.Context, *db.UsersListOptions) (int, error) {
|
||||
return int(test.activeUserCount), nil
|
||||
}
|
||||
defer func() { db.Mocks = db.MockStores{} }()
|
||||
|
||||
err := db.Users.PreCreateUser(context.Background())
|
||||
if gotErr := (err != nil); gotErr != test.wantErr {
|
||||
t.Errorf("got error %v, want %v", gotErr, test.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
102
enterprise/cmd/frontend/internal/licensing/features.go
Normal file
102
enterprise/cmd/frontend/internal/licensing/features.go
Normal file
@ -0,0 +1,102 @@
|
||||
package licensing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/enterprise/pkg/license"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/errcode"
|
||||
)
|
||||
|
||||
// Feature is a product feature that is selectively activated based on the current license key.
|
||||
type Feature string
|
||||
|
||||
// The list of features. For each feature, add a new const here and the checking logic in
|
||||
// isFeatureEnabled below.
|
||||
const (
|
||||
// FeatureExternalAuthProvider is whether external user authentication providers (aka "SSO") may
|
||||
// be used.
|
||||
FeatureExternalAuthProvider Feature = "sso-external-user-auth-provider"
|
||||
|
||||
// FeatureExtensionRegistry is whether publishing extensions to this Sourcegraph instance is
|
||||
// allowed. If not, then extensions must be published to Sourcegraph.com. All instances may use
|
||||
// extensions published to Sourcegraph.com.
|
||||
FeatureExtensionRegistry Feature = "private-extension-registry"
|
||||
|
||||
// FeatureRemoteExtensionsAllowDisallow is whether the site admin may explictly specify a list
|
||||
// of allowed remote extensions and prevent any other remote extensions from being used. It does
|
||||
// not apply to locally published extensions.
|
||||
FeatureRemoteExtensionsAllowDisallow = "remote-extensions-allow-disallow"
|
||||
)
|
||||
|
||||
func isFeatureEnabled(info license.Info, feature Feature) bool {
|
||||
// Add feature-specific logic here.
|
||||
switch feature {
|
||||
case FeatureExternalAuthProvider:
|
||||
// Enterprise Starter and Enterprise both allow SSO. Core doesn't, but this func is only
|
||||
// called when there is a valid license.
|
||||
return true
|
||||
case FeatureExtensionRegistry:
|
||||
// Enterprise Starter does not support a local extension registry.
|
||||
return !info.HasTag(EnterpriseStarterTag)
|
||||
case FeatureRemoteExtensionsAllowDisallow:
|
||||
// Enterprise Starter does not support explictly allowing/disallowing remote extensions by
|
||||
// extension ID.
|
||||
return !info.HasTag(EnterpriseStarterTag)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// CheckFeature checks whether the feature is activated based on the current license. If it is
|
||||
// disabled, it returns a non-nil error.
|
||||
//
|
||||
// The returned error may implement errcode.PresentationError to indicate that it can be displayed
|
||||
// directly to the user. Use IsFeatureNotActivated to distinguish between the error reasons.
|
||||
func CheckFeature(feature Feature) error {
|
||||
info, err := GetConfiguredProductLicenseInfo()
|
||||
if err != nil {
|
||||
return errors.WithMessage(err, fmt.Sprintf("checking feature %q activation", feature))
|
||||
}
|
||||
if info == nil {
|
||||
return newFeatureNotActivatedError(fmt.Sprintf("The feature %q is not activated because it requires a valid Sourcegraph license. Purchase a Sourcegraph subscription to activate this feature.", feature))
|
||||
}
|
||||
if !isFeatureEnabled(*info, feature) {
|
||||
return newFeatureNotActivatedError(fmt.Sprintf("The feature %q is not activated for Sourcegraph Enterprise Starter. Upgrade to Sourcegraph Enterprise to use this feature.", feature))
|
||||
}
|
||||
return nil // feature is activated for current license
|
||||
}
|
||||
|
||||
func newFeatureNotActivatedError(message string) featureNotActivatedError {
|
||||
e := errcode.NewPresentationError(message).(errcode.PresentationError)
|
||||
return featureNotActivatedError{e}
|
||||
}
|
||||
|
||||
type featureNotActivatedError struct{ errcode.PresentationError }
|
||||
|
||||
// IsFeatureNotActivated reports whether err indicates that the license is valid but does not
|
||||
// activate the feature.
|
||||
//
|
||||
// It is used to distinguish between the multiple reasons for errors from CheckFeature: either
|
||||
// failed license verification, or a valid license that does not activate a feature (e.g.,
|
||||
// Enterprise Starter not including an Enterprise-only feature).
|
||||
func IsFeatureNotActivated(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := err.(featureNotActivatedError)
|
||||
if !ok {
|
||||
// Also check for the pointer type to guard against stupid mistakes.
|
||||
_, ok = err.(*featureNotActivatedError)
|
||||
}
|
||||
return ok
|
||||
}
|
||||
|
||||
// IsFeatureEnabledLenient reports whether the current license enables the given feature. If there
|
||||
// is an error reading the license, it is lenient and returns true.
|
||||
//
|
||||
// This is useful for callers who don't want to handle errors (usually because the user would be
|
||||
// prevented from getting to this point if license verification had failed, so it's not necessary to
|
||||
// handle license verification errors here).
|
||||
func IsFeatureEnabledLenient(feature Feature) bool {
|
||||
return !IsFeatureNotActivated(CheckFeature(feature))
|
||||
}
|
||||
27
enterprise/cmd/frontend/internal/licensing/features_test.go
Normal file
27
enterprise/cmd/frontend/internal/licensing/features_test.go
Normal file
@ -0,0 +1,27 @@
|
||||
package licensing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/enterprise/pkg/license"
|
||||
)
|
||||
|
||||
func TestIsFeatureEnabled(t *testing.T) {
|
||||
check := func(t *testing.T, feature Feature, licenseTags []string, wantEnabled bool) {
|
||||
t.Helper()
|
||||
got := isFeatureEnabled(license.Info{Tags: licenseTags}, feature)
|
||||
if got != wantEnabled {
|
||||
t.Errorf("got %v, want %v", got, wantEnabled)
|
||||
}
|
||||
}
|
||||
|
||||
t.Run(string(FeatureExternalAuthProvider), func(t *testing.T) {
|
||||
check(t, FeatureExternalAuthProvider, EnterpriseStarterTags, true)
|
||||
check(t, FeatureExternalAuthProvider, EnterpriseTags, true)
|
||||
})
|
||||
|
||||
t.Run(string(FeatureExtensionRegistry), func(t *testing.T) {
|
||||
check(t, FeatureExtensionRegistry, EnterpriseStarterTags, false)
|
||||
check(t, FeatureExtensionRegistry, EnterpriseTags, true)
|
||||
})
|
||||
}
|
||||
162
enterprise/cmd/frontend/internal/licensing/licensing.go
Normal file
162
enterprise/cmd/frontend/internal/licensing/licensing.go
Normal file
@ -0,0 +1,162 @@
|
||||
package licensing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"sync"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/enterprise/pkg/license"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/env"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// publicKey is the public key used to verify product license keys.
|
||||
//
|
||||
// It is hardcoded here intentionally (we only have one private signing key, and we don't yet
|
||||
// support/need key rotation). The corresponding private key is at
|
||||
// https://team-sourcegraph.1password.com/vaults/dnrhbauihkhjs5ag6vszsme45a/allitems/zkdx6gpw4uqejs3flzj7ef5j4i
|
||||
// and set below in SOURCEGRAPH_LICENSE_GENERATION_KEY.
|
||||
var publicKey = func() ssh.PublicKey {
|
||||
// To convert PKCS#8 format (which `openssl rsa -in key.pem -pubout` produces) to the format
|
||||
// that ssh.ParseAuthorizedKey reads here, use `ssh-keygen -i -mPKCS8 -f key.pub`.
|
||||
const publicKeyData = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDUUd9r83fGmYVLzcqQp5InyAoJB5lLxlM7s41SUUtxfnG6JpmvjNd+WuEptJGk0C/Zpyp/cCjCV4DljDs8Z7xjRbvJYW+vklFFxXrMTBs/+HjpIBKlYTmG8SqTyXyu1s4485Kh1fEC5SK6z2IbFaHuSHUXgDi/IepSOg1QudW4n8J91gPtT2E30/bPCBRq8oz/RVwJSDMvYYjYVb//LhV0Mx3O6hg4xzUNuwiCtNjCJ9t4YU2sV87+eJwWtQNbSQ8TelQa8WjG++XSnXUHw12bPDe7wGL/7/EJb7knggKSAMnpYpCyV35dyi4DsVc46c+b6P0gbVSosh3Uc3BJHSWF`
|
||||
var err error
|
||||
publicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(publicKeyData))
|
||||
if err != nil {
|
||||
panic("failed to parse public key for license verification: " + err.Error())
|
||||
}
|
||||
return publicKey
|
||||
}()
|
||||
|
||||
// ParseProductLicenseKey parses and verifies the license key using the license verification public
|
||||
// key (publicKey in this package).
|
||||
func ParseProductLicenseKey(licenseKey string) (*license.Info, error) {
|
||||
return license.ParseSignedKey(licenseKey, publicKey)
|
||||
}
|
||||
|
||||
// ParseProductLicenseKeyWithBuiltinOrGenerationKey is like ParseProductLicenseKey, except it tries
|
||||
// parsing and verifying the license key with the license generation key (if set), instead of always
|
||||
// using the builtin license key.
|
||||
//
|
||||
// It is useful for local development when using a test license generation key (whose signatures
|
||||
// aren't considered valid when verified using the builtin public key).
|
||||
func ParseProductLicenseKeyWithBuiltinOrGenerationKey(licenseKey string) (*license.Info, error) {
|
||||
var k ssh.PublicKey
|
||||
if licenseGenerationPrivateKey != nil {
|
||||
k = licenseGenerationPrivateKey.PublicKey()
|
||||
} else {
|
||||
k = publicKey
|
||||
}
|
||||
return license.ParseSignedKey(licenseKey, k)
|
||||
}
|
||||
|
||||
// Cache the parsing of the license key because public key crypto can be slow.
|
||||
var (
|
||||
mu sync.Mutex
|
||||
lastKeyText string
|
||||
lastInfo *license.Info
|
||||
)
|
||||
|
||||
var MockGetConfiguredProductLicenseInfo func() (*license.Info, error)
|
||||
|
||||
// GetConfiguredProductLicenseInfo returns information about the current product license key
|
||||
// specified in site configuration.
|
||||
func GetConfiguredProductLicenseInfo() (*license.Info, error) {
|
||||
if MockGetConfiguredProductLicenseInfo != nil {
|
||||
return MockGetConfiguredProductLicenseInfo()
|
||||
}
|
||||
|
||||
// Support reading the license key from the environment (intended for development, because we
|
||||
// don't want to commit a valid license key to dev/config.json in the OSS repo).
|
||||
keyText := os.Getenv("SOURCEGRAPH_LICENSE_KEY")
|
||||
if keyText == "" {
|
||||
keyText = conf.Get().LicenseKey
|
||||
}
|
||||
|
||||
if keyText != "" {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
|
||||
var info *license.Info
|
||||
if keyText == lastKeyText {
|
||||
info = lastInfo
|
||||
} else {
|
||||
var err error
|
||||
info, err = ParseProductLicenseKey(keyText)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
lastKeyText = keyText
|
||||
lastInfo = info
|
||||
}
|
||||
return info, nil
|
||||
}
|
||||
|
||||
// No license key.
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Make the Site.productSubscription GraphQL field return the actual info about the product license,
|
||||
// if any.
|
||||
func init() {
|
||||
graphqlbackend.GetConfiguredProductLicenseInfo = func() (*graphqlbackend.ProductLicenseInfo, error) {
|
||||
info, err := GetConfiguredProductLicenseInfo()
|
||||
if info == nil || err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &graphqlbackend.ProductLicenseInfo{
|
||||
TagsValue: info.Tags,
|
||||
UserCountValue: info.UserCount,
|
||||
ExpiresAtValue: info.ExpiresAt,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// licenseGenerationPrivateKeyURL is the URL where Sourcegraph staff can find the private key for
|
||||
// generating licenses.
|
||||
//
|
||||
// NOTE: If you change this, use text search to replace other instances of it (in source code
|
||||
// comments).
|
||||
const licenseGenerationPrivateKeyURL = "https://team-sourcegraph.1password.com/vaults/dnrhbauihkhjs5ag6vszsme45a/allitems/zkdx6gpw4uqejs3flzj7ef5j4i"
|
||||
|
||||
// envLicenseGenerationPrivateKey (the env var SOURCEGRAPH_LICENSE_GENERATION_KEY) is the
|
||||
// PEM-encoded form of the private key used to sign product license keys. It is stored at
|
||||
// https://team-sourcegraph.1password.com/vaults/dnrhbauihkhjs5ag6vszsme45a/allitems/zkdx6gpw4uqejs3flzj7ef5j4i.
|
||||
var envLicenseGenerationPrivateKey = env.Get("SOURCEGRAPH_LICENSE_GENERATION_KEY", "", "the PEM-encoded form of the private key used to sign product license keys ("+licenseGenerationPrivateKeyURL+")")
|
||||
|
||||
// licenseGenerationPrivateKey is the private key used to generate license keys.
|
||||
var licenseGenerationPrivateKey = func() ssh.Signer {
|
||||
if envLicenseGenerationPrivateKey == "" {
|
||||
// Most Sourcegraph instances don't use/need this key. Generally only Sourcegraph.com and
|
||||
// local dev will have this key set.
|
||||
return nil
|
||||
}
|
||||
privateKey, err := ssh.ParsePrivateKey([]byte(envLicenseGenerationPrivateKey))
|
||||
if err != nil {
|
||||
log.Fatalf("Failed to parse private key in SOURCEGRAPH_LICENSE_GENERATION_KEY env var: %s.", err)
|
||||
}
|
||||
return privateKey
|
||||
}()
|
||||
|
||||
// GenerateProductLicenseKey generates a product license key using the license generation private
|
||||
// key configured in site configuration.
|
||||
func GenerateProductLicenseKey(info license.Info) (string, error) {
|
||||
if envLicenseGenerationPrivateKey == "" {
|
||||
const msg = "no product license generation private key was configured"
|
||||
if env.InsecureDev {
|
||||
// Show more helpful error message in local dev.
|
||||
return "", fmt.Errorf("%s (for testing by Sourcegraph staff: set the SOURCEGRAPH_LICENSE_GENERATION_KEY env var to the key obtained at %s)", msg, licenseGenerationPrivateKeyURL)
|
||||
}
|
||||
return "", errors.New(msg)
|
||||
}
|
||||
|
||||
licenseKey, err := license.GenerateSignedKey(info, licenseGenerationPrivateKey)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return licenseKey, nil
|
||||
}
|
||||
67
enterprise/cmd/frontend/internal/licensing/tags.go
Normal file
67
enterprise/cmd/frontend/internal/licensing/tags.go
Normal file
@ -0,0 +1,67 @@
|
||||
package licensing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
)
|
||||
|
||||
// Make the Site.productSubscription.productNameWithBrand GraphQL field (and other places) use the
|
||||
// proper product name.
|
||||
func init() {
|
||||
graphqlbackend.GetProductNameWithBrand = productNameWithBrand
|
||||
}
|
||||
|
||||
const (
|
||||
// EnterpriseStarterTag is the license tag for Enterprise Starter (which includes only a subset
|
||||
// of Enterprise features).
|
||||
EnterpriseStarterTag = "starter"
|
||||
|
||||
// TrueUpUserCountTag is the license tag that indicates that the licensed user count can be
|
||||
// exceeded and will be charged later.
|
||||
TrueUpUserCountTag = "true-up"
|
||||
)
|
||||
|
||||
var (
|
||||
// EnterpriseStarterTags is the license tags for Enterprise Starter.
|
||||
EnterpriseStarterTags = []string{EnterpriseStarterTag}
|
||||
|
||||
// EnterpriseTags is the license tags for Enterprise (intentionally empty because it has no
|
||||
// feature restrictions)
|
||||
EnterpriseTags = []string{}
|
||||
)
|
||||
|
||||
// productNameWithBrand returns the product name with brand (e.g., "Sourcegraph Enterprise") based
|
||||
// on the license info.
|
||||
func productNameWithBrand(hasLicense bool, licenseTags []string) string {
|
||||
if !hasLicense {
|
||||
return "Sourcegraph Core"
|
||||
}
|
||||
|
||||
hasTag := func(tag string) bool {
|
||||
for _, t := range licenseTags {
|
||||
if tag == t {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var name string
|
||||
if hasTag("starter") {
|
||||
name = " Starter"
|
||||
}
|
||||
|
||||
var misc []string
|
||||
if hasTag("trial") {
|
||||
misc = append(misc, "trial")
|
||||
}
|
||||
if hasTag("dev") {
|
||||
misc = append(misc, "dev use only")
|
||||
}
|
||||
if len(misc) > 0 {
|
||||
name += " (" + strings.Join(misc, ", ") + ")"
|
||||
}
|
||||
|
||||
return "Sourcegraph Enterprise" + name
|
||||
}
|
||||
33
enterprise/cmd/frontend/internal/licensing/tags_test.go
Normal file
33
enterprise/cmd/frontend/internal/licensing/tags_test.go
Normal file
@ -0,0 +1,33 @@
|
||||
package licensing
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestProductNameWithBrand(t *testing.T) {
|
||||
tests := []struct {
|
||||
hasLicense bool
|
||||
licenseTags []string
|
||||
want string
|
||||
}{
|
||||
{hasLicense: false, want: "Sourcegraph Core"},
|
||||
{hasLicense: true, licenseTags: nil, want: "Sourcegraph Enterprise"},
|
||||
{hasLicense: true, licenseTags: []string{}, want: "Sourcegraph Enterprise"},
|
||||
{hasLicense: true, licenseTags: []string{"x"}, want: "Sourcegraph Enterprise"}, // unrecognized tag "x" is ignored
|
||||
{hasLicense: true, licenseTags: []string{"starter"}, want: "Sourcegraph Enterprise Starter"},
|
||||
{hasLicense: true, licenseTags: []string{"trial"}, want: "Sourcegraph Enterprise (trial)"},
|
||||
{hasLicense: true, licenseTags: []string{"dev"}, want: "Sourcegraph Enterprise (dev use only)"},
|
||||
{hasLicense: true, licenseTags: []string{"starter", "trial"}, want: "Sourcegraph Enterprise Starter (trial)"},
|
||||
{hasLicense: true, licenseTags: []string{"starter", "dev"}, want: "Sourcegraph Enterprise Starter (dev use only)"},
|
||||
{hasLicense: true, licenseTags: []string{"starter", "trial", "dev"}, want: "Sourcegraph Enterprise Starter (trial, dev use only)"},
|
||||
{hasLicense: true, licenseTags: []string{"trial", "dev"}, want: "Sourcegraph Enterprise (trial, dev use only)"},
|
||||
}
|
||||
for _, test := range tests {
|
||||
t.Run(fmt.Sprintf("hasLicense=%v licenseTags=%v", test.hasLicense, test.licenseTags), func(t *testing.T) {
|
||||
if got := productNameWithBrand(test.hasLicense, test.licenseTags); got != test.want {
|
||||
t.Errorf("got %q, want %q", got, test.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
58
enterprise/cmd/frontend/internal/registry/allow.go
Normal file
58
enterprise/cmd/frontend/internal/registry/allow.go
Normal file
@ -0,0 +1,58 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
frontendregistry.IsRemoteExtensionAllowed = func(extensionID string) bool {
|
||||
allowedExtensions := getAllowedExtensionsFromSiteConfig()
|
||||
if allowedExtensions == nil {
|
||||
// Default is to allow all extensions.
|
||||
return true
|
||||
}
|
||||
|
||||
for _, x := range allowedExtensions {
|
||||
if extensionID == x {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
frontendregistry.FilterRemoteExtensions = func(extensions []*registry.Extension) []*registry.Extension {
|
||||
allowedExtensions := getAllowedExtensionsFromSiteConfig()
|
||||
if allowedExtensions == nil {
|
||||
// Default is to allow all extensions.
|
||||
return extensions
|
||||
}
|
||||
|
||||
allow := make(map[string]interface{})
|
||||
for _, id := range allowedExtensions {
|
||||
allow[id] = struct{}{}
|
||||
}
|
||||
var keep []*registry.Extension
|
||||
for _, x := range extensions {
|
||||
if _, ok := allow[x.ExtensionID]; ok {
|
||||
keep = append(keep, x)
|
||||
}
|
||||
}
|
||||
return keep
|
||||
}
|
||||
}
|
||||
|
||||
func getAllowedExtensionsFromSiteConfig() []string {
|
||||
// If the remote extension allow/disallow feature is not enabled, all remote extensions are
|
||||
// allowed. This is achieved by a nil list.
|
||||
if !licensing.IsFeatureEnabledLenient(licensing.FeatureRemoteExtensionsAllowDisallow) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if c := conf.Get().Extensions; c != nil {
|
||||
return c.AllowRemoteExtensions
|
||||
}
|
||||
return nil
|
||||
}
|
||||
93
enterprise/cmd/frontend/internal/registry/allow_test.go
Normal file
93
enterprise/cmd/frontend/internal/registry/allow_test.go
Normal file
@ -0,0 +1,93 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"sort"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
"github.com/sourcegraph/enterprise/pkg/license"
|
||||
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/registry"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
func TestIsRemoteExtensionAllowed(t *testing.T) {
|
||||
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
|
||||
return &license.Info{Tags: licensing.EnterpriseTags}, nil
|
||||
}
|
||||
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
|
||||
defer conf.Mock(nil)
|
||||
|
||||
if !frontendregistry.IsRemoteExtensionAllowed("a") {
|
||||
t.Errorf("want %q to be allowed", "a")
|
||||
}
|
||||
|
||||
conf.Mock(&schema.SiteConfiguration{Extensions: &schema.Extensions{AllowRemoteExtensions: nil}})
|
||||
if !frontendregistry.IsRemoteExtensionAllowed("a") {
|
||||
t.Errorf("want %q to be allowed", "a")
|
||||
}
|
||||
|
||||
conf.Mock(&schema.SiteConfiguration{Extensions: &schema.Extensions{AllowRemoteExtensions: []string{}}})
|
||||
if frontendregistry.IsRemoteExtensionAllowed("a") {
|
||||
t.Errorf("want %q to be disallowed", "a")
|
||||
}
|
||||
|
||||
conf.Mock(&schema.SiteConfiguration{Extensions: &schema.Extensions{AllowRemoteExtensions: []string{"a"}}})
|
||||
if !frontendregistry.IsRemoteExtensionAllowed("a") {
|
||||
t.Errorf("want %q to be allowed", "a")
|
||||
}
|
||||
}
|
||||
|
||||
func sameElements(a, b []string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
aCopy := make([]string, len(a))
|
||||
bCopy := make([]string, len(b))
|
||||
|
||||
copy(aCopy, a)
|
||||
copy(bCopy, b)
|
||||
|
||||
sort.Strings(aCopy)
|
||||
sort.Strings(bCopy)
|
||||
|
||||
return reflect.DeepEqual(aCopy, bCopy)
|
||||
}
|
||||
|
||||
func TestFilterRemoteExtensions(t *testing.T) {
|
||||
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
|
||||
return &license.Info{Tags: licensing.EnterpriseTags}, nil
|
||||
}
|
||||
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
|
||||
|
||||
run := func(allowRemoteExtensions *[]string, extensions []string, want []string) {
|
||||
t.Helper()
|
||||
if allowRemoteExtensions != nil {
|
||||
conf.Mock(&schema.SiteConfiguration{Extensions: &schema.Extensions{AllowRemoteExtensions: *allowRemoteExtensions}})
|
||||
defer conf.Mock(nil)
|
||||
}
|
||||
var xs []*registry.Extension
|
||||
for _, id := range extensions {
|
||||
xs = append(xs, ®istry.Extension{ExtensionID: id})
|
||||
}
|
||||
got := []string{}
|
||||
for _, x := range frontendregistry.FilterRemoteExtensions(xs) {
|
||||
got = append(got, x.ExtensionID)
|
||||
}
|
||||
if !sameElements(got, want) {
|
||||
t.Errorf("want %+v got %+v", want, got)
|
||||
}
|
||||
}
|
||||
|
||||
run(nil, []string{}, []string{})
|
||||
run(nil, []string{"a"}, []string{"a"})
|
||||
run(&[]string{}, []string{}, []string{})
|
||||
run(&[]string{"a"}, []string{}, []string{})
|
||||
run(&[]string{}, []string{"a"}, []string{})
|
||||
run(&[]string{"a"}, []string{"b"}, []string{})
|
||||
run(&[]string{"a"}, []string{"a"}, []string{"a"})
|
||||
run(&[]string{"b", "c"}, []string{"a", "b", "c"}, []string{"b", "c"})
|
||||
}
|
||||
2
enterprise/cmd/frontend/internal/registry/doc.go
Normal file
2
enterprise/cmd/frontend/internal/registry/doc.go
Normal file
@ -0,0 +1,2 @@
|
||||
// Package registry contains the implementation of the extension registry.
|
||||
package registry
|
||||
@ -0,0 +1,96 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/errcode"
|
||||
)
|
||||
|
||||
func init() {
|
||||
frontendregistry.HandleRegistryExtensionBundle = handleRegistryExtensionBundle
|
||||
}
|
||||
|
||||
// handleRegistryExtensionBundle serves the bundled JavaScript source file or the source map for an
|
||||
// extension in the registry as a raw JavaScript or JSON file.
|
||||
func handleRegistryExtensionBundle(w http.ResponseWriter, r *http.Request) {
|
||||
if conf.Extensions() == nil {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
filename := mux.Vars(r)["RegistryExtensionReleaseFilename"]
|
||||
ext := filepath.Ext(filename)
|
||||
wantSourceMap := ext == ".map"
|
||||
releaseIDStr := strings.TrimSuffix(filename, ext)
|
||||
releaseID, err := strconv.ParseInt(releaseIDStr, 10, 64)
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusBadRequest)
|
||||
return
|
||||
}
|
||||
bundle, sourceMap, err := dbReleases{}.GetArtifacts(r.Context(), releaseID)
|
||||
if errcode.IsNotFound(err) {
|
||||
http.Error(w, err.Error(), http.StatusNotFound)
|
||||
return
|
||||
} else if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
// 🚨 SECURITY: Prevent this URL from being used in a <script> tag from other sites, because
|
||||
// hosting user-provided scripts on this domain would let attackers steal sensitive data from
|
||||
// anyone they lure to the attacker's site.
|
||||
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
|
||||
w.Header().Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline'; sandbox")
|
||||
w.Header().Set("X-Frame-Options", "deny")
|
||||
w.Header().Set("X-Content-Type-Options", "nosniff")
|
||||
w.Header().Set("X-XSS-Protection", "1; mode=block")
|
||||
|
||||
// Allow downstream Sourcegraph sites' clients to access this file directly.
|
||||
w.Header().Del("Access-Control-Allow-Credentials") // credentials are not needed
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// We want to cache forever because an extension release is immutable, except that if the
|
||||
// database is reset and and the registry_extension_releases.id sequence starts over, we don't
|
||||
// want stale data from pre-reset. So, assume that the presence of a query string means that it
|
||||
// includes some identifier that changes when the database is reset.
|
||||
if r.URL.RawQuery != "" {
|
||||
w.Header().Set("Cache-Control", "max-age=604800, private, immutable")
|
||||
}
|
||||
var data []byte
|
||||
if wantSourceMap {
|
||||
if sourceMap == nil {
|
||||
http.Error(w, "extension has no source map", http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
data = sourceMap
|
||||
} else {
|
||||
data = bundle
|
||||
}
|
||||
w.Write(data)
|
||||
|
||||
if !wantSourceMap && sourceMap != nil {
|
||||
// Append `//# sourceMappingURL=` directive to JS bundle if we have a source map. It is
|
||||
// necessary to provide the absolute URL because the JS bundle is not loaded directly (e.g.,
|
||||
// via importScripts); it is saved to a blob URL and then executed, which means any relative
|
||||
// source map URL would be interpreted relative to the blob URL (so a relative URL wouldn't
|
||||
// work). Also, we can't rely on the original sourceMappingURL directive (if provided at
|
||||
// publish time) because it has no way of knowing the absolute URL to the source map.
|
||||
//
|
||||
// This implementation is not ideal because it means the JS bundle's contents depend on the
|
||||
// app URL, which makes it technically not immutable. But given the blob URL constraint
|
||||
// mentioned above, it's the best known solution.
|
||||
if appURL, _ := url.Parse(conf.Get().AppURL); appURL != nil {
|
||||
sourceMapURL := appURL.ResolveReference(&url.URL{Path: path.Join(path.Dir(r.URL.Path), releaseIDStr+".map")}).String()
|
||||
fmt.Fprintf(w, "\n//# sourceMappingURL=%s", sourceMapURL)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,81 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
registry.ListLocalRegistryExtensions = listLocalRegistryExtensions
|
||||
registry.CountLocalRegistryExtensions = countLocalRegistryExtensions
|
||||
}
|
||||
|
||||
func listLocalRegistryExtensions(ctx context.Context, args graphqlbackend.RegistryExtensionConnectionArgs) ([]graphqlbackend.RegistryExtension, error) {
|
||||
if args.PrioritizeExtensionIDs != nil {
|
||||
ids := filterStripLocalExtensionIDs(*args.PrioritizeExtensionIDs)
|
||||
args.PrioritizeExtensionIDs = &ids
|
||||
}
|
||||
opt, err := toDBExtensionsListOptions(args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
xs, err := dbExtensions{}.List(ctx, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := prefixLocalExtensionID(xs...); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
xs2 := make([]graphqlbackend.RegistryExtension, len(xs))
|
||||
for i, x := range xs {
|
||||
xs2[i] = &extensionDBResolver{v: x}
|
||||
}
|
||||
return xs2, nil
|
||||
}
|
||||
|
||||
func countLocalRegistryExtensions(ctx context.Context, args graphqlbackend.RegistryExtensionConnectionArgs) (int, error) {
|
||||
opt, err := toDBExtensionsListOptions(args)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return dbExtensions{}.Count(ctx, opt)
|
||||
}
|
||||
|
||||
func toDBExtensionsListOptions(args graphqlbackend.RegistryExtensionConnectionArgs) (dbExtensionsListOptions, error) {
|
||||
var opt dbExtensionsListOptions
|
||||
args.ConnectionArgs.Set(&opt.LimitOffset)
|
||||
if args.Publisher != nil {
|
||||
p, err := unmarshalRegistryPublisherID(*args.Publisher)
|
||||
if err != nil {
|
||||
return opt, err
|
||||
}
|
||||
opt.Publisher.UserID = p.userID
|
||||
opt.Publisher.OrgID = p.orgID
|
||||
}
|
||||
if args.Query != nil {
|
||||
opt.Query = *args.Query
|
||||
}
|
||||
if args.PrioritizeExtensionIDs != nil {
|
||||
opt.PrioritizeExtensionIDs = *args.PrioritizeExtensionIDs
|
||||
}
|
||||
return opt, nil
|
||||
}
|
||||
|
||||
// filterStripLocalExtensionIDs filters to local extension IDs and strips the
|
||||
// host prefix.
|
||||
func filterStripLocalExtensionIDs(extensionIDs []string) []string {
|
||||
prefix := registry.GetLocalRegistryExtensionIDPrefix()
|
||||
local := []string{}
|
||||
for _, id := range extensionIDs {
|
||||
parts := strings.SplitN(id, "/", 3)
|
||||
if prefix != nil && len(parts) == 3 && parts[0] == *prefix {
|
||||
local = append(local, parts[1]+"/"+parts[2])
|
||||
} else if (prefix == nil || *prefix == "") && len(parts) == 2 {
|
||||
local = append(local, id)
|
||||
}
|
||||
}
|
||||
return local
|
||||
}
|
||||
@ -0,0 +1,35 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/globals"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/schema"
|
||||
)
|
||||
|
||||
func TestFilteringExtensionIDs(t *testing.T) {
|
||||
t.Run("filterStripLocalExtensionIDs on localhost", func(t *testing.T) {
|
||||
conf.Mock(&schema.SiteConfiguration{AppURL: "http://localhost:3080"})
|
||||
defer conf.Mock(nil)
|
||||
input := []string{"localhost:3080/owner1/name1", "owner2/name2"}
|
||||
want := []string{"owner1/name1"}
|
||||
got := filterStripLocalExtensionIDs(input)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("got %+v, want %+v", got, want)
|
||||
}
|
||||
})
|
||||
t.Run("filterStripLocalExtensionIDs on Sourcegraph.com", func(t *testing.T) {
|
||||
oldAppURL := globals.AppURL
|
||||
globals.AppURL = &url.URL{Scheme: "https", Host: "sourcegraph.com"}
|
||||
defer func() { globals.AppURL = oldAppURL }()
|
||||
input := []string{"localhost:3080/owner1/name1", "owner2/name2"}
|
||||
want := []string{"owner2/name2"}
|
||||
got := filterStripLocalExtensionIDs(input)
|
||||
if !reflect.DeepEqual(got, want) {
|
||||
t.Fatalf("got %+v, want %+v", got, want)
|
||||
}
|
||||
})
|
||||
}
|
||||
@ -0,0 +1,72 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
graphql "github.com/graph-gophers/graphql-go"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
|
||||
)
|
||||
|
||||
// extensionDBResolver implements the GraphQL type RegistryExtension.
|
||||
type extensionDBResolver struct {
|
||||
v *dbExtension
|
||||
}
|
||||
|
||||
func (r *extensionDBResolver) ID() graphql.ID {
|
||||
return registry.MarshalRegistryExtensionID(registry.RegistryExtensionID{LocalID: r.v.ID})
|
||||
}
|
||||
|
||||
func (r *extensionDBResolver) UUID() string { return r.v.UUID }
|
||||
func (r *extensionDBResolver) ExtensionID() string { return r.v.NonCanonicalExtensionID }
|
||||
func (r *extensionDBResolver) ExtensionIDWithoutRegistry() string {
|
||||
if r.v.NonCanonicalRegistry != "" {
|
||||
return strings.TrimPrefix(r.v.NonCanonicalExtensionID, r.v.NonCanonicalRegistry+"/")
|
||||
}
|
||||
return r.v.NonCanonicalExtensionID
|
||||
}
|
||||
|
||||
func (r *extensionDBResolver) Publisher(ctx context.Context) (graphqlbackend.RegistryPublisher, error) {
|
||||
return getRegistryPublisher(ctx, r.v.Publisher)
|
||||
}
|
||||
|
||||
func (r *extensionDBResolver) Name() string { return r.v.Name }
|
||||
func (r *extensionDBResolver) Manifest(ctx context.Context) (graphqlbackend.ExtensionManifest, error) {
|
||||
manifest, err := getExtensionManifestWithBundleURL(ctx, r.v.NonCanonicalExtensionID, r.v.ID, "release")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return registry.NewExtensionManifest(manifest), nil
|
||||
}
|
||||
|
||||
func (r *extensionDBResolver) CreatedAt() *string {
|
||||
return strptr(r.v.CreatedAt.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
func (r *extensionDBResolver) UpdatedAt() *string {
|
||||
return strptr(r.v.UpdatedAt.Format(time.RFC3339))
|
||||
}
|
||||
|
||||
func (r *extensionDBResolver) URL() string {
|
||||
return registry.ExtensionURL(r.v.NonCanonicalExtensionID)
|
||||
}
|
||||
func (r *extensionDBResolver) RemoteURL() *string { return nil }
|
||||
|
||||
func (r *extensionDBResolver) RegistryName() (string, error) {
|
||||
return r.v.NonCanonicalRegistry, nil
|
||||
}
|
||||
|
||||
func (r *extensionDBResolver) IsLocal() bool { return true }
|
||||
|
||||
func (r *extensionDBResolver) ViewerCanAdminister(ctx context.Context) (bool, error) {
|
||||
err := toRegistryPublisherID(r.v).viewerCanAdminister(ctx)
|
||||
if err == backend.ErrMustBeSiteAdmin || err == backend.ErrNotAnOrgMember || err == backend.ErrNotAuthenticated {
|
||||
return false, nil
|
||||
}
|
||||
return err == nil, err
|
||||
}
|
||||
|
||||
func strptr(s string) *string { return &s }
|
||||
@ -0,0 +1,71 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"path"
|
||||
"strconv"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/errcode"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/jsonc"
|
||||
)
|
||||
|
||||
// validateExtensionManifest validates a JSON extension manifest for syntax.
|
||||
//
|
||||
// TODO(sqs): Also validate it against the JSON Schema.
|
||||
func validateExtensionManifest(text string) error {
|
||||
var o interface{}
|
||||
return jsonc.Unmarshal(text, &o)
|
||||
}
|
||||
|
||||
// getExtensionManifestWithBundleURL returns the extension manifest as JSON. If there are no
|
||||
// releases, it returns a nil manifest. If the manifest has no "url" field itself, a "url" field
|
||||
// pointing to the extension's bundle is inserted.
|
||||
func getExtensionManifestWithBundleURL(ctx context.Context, extensionID string, registryExtensionID int32, releaseTag string) (*string, error) {
|
||||
var manifest *string
|
||||
release, err := dbReleases{}.GetLatest(ctx, registryExtensionID, releaseTag, false)
|
||||
if err != nil && !errcode.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
if release != nil {
|
||||
// Add URL to bundle if necessary.
|
||||
var o map[string]interface{}
|
||||
if err := jsonc.Unmarshal(release.Manifest, &o); err != nil {
|
||||
return nil, fmt.Errorf("parsing extension manifest for extension with ID %d (release tag %q): %s", registryExtensionID, releaseTag, err)
|
||||
}
|
||||
if o == nil {
|
||||
o = map[string]interface{}{}
|
||||
}
|
||||
urlStr, _ := o["url"].(string)
|
||||
if urlStr == "" {
|
||||
// Insert "url" field with link to bundle file on this site.
|
||||
bundleURL, err := makeExtensionBundleURL(release.ID, release.CreatedAt.UnixNano(), extensionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o["url"] = bundleURL
|
||||
b, err := json.MarshalIndent(o, "", " ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
release.Manifest = string(b)
|
||||
}
|
||||
|
||||
manifest = &release.Manifest
|
||||
}
|
||||
|
||||
return manifest, nil
|
||||
}
|
||||
|
||||
func makeExtensionBundleURL(registryExtensionReleaseID int64, timestamp int64, extensionIDHint string) (string, error) {
|
||||
u, err := url.Parse(conf.Get().AppURL)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
u.Path = path.Join(u.Path, fmt.Sprintf("/-/static/extension/%d.js", registryExtensionReleaseID))
|
||||
u.RawQuery = extensionIDHint + "--" + strconv.FormatInt(timestamp, 36) // meaningless value, just for cache-busting
|
||||
return u.String(), nil
|
||||
}
|
||||
@ -0,0 +1,63 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestGetExtensionManifestWithBundleURL(t *testing.T) {
|
||||
resetMocks()
|
||||
ctx := context.Background()
|
||||
|
||||
nilOrEmpty := func(s *string) string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return *s
|
||||
}
|
||||
|
||||
t.Run(`manifest with "url"`, func(t *testing.T) {
|
||||
mocks.releases.GetLatest = func(registryExtensionID int32, releaseTag string, includeArtifacts bool) (*dbRelease, error) {
|
||||
return &dbRelease{
|
||||
Manifest: `{"name":"x","url":"u"}`,
|
||||
}, nil
|
||||
}
|
||||
defer func() { mocks.releases.GetLatest = nil }()
|
||||
manifest, err := getExtensionManifestWithBundleURL(ctx, "x", 1, "t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := `{"name":"x","url":"u"}`; manifest == nil || !jsonDeepEqual(*manifest, want) {
|
||||
t.Errorf("got %q, want %q", nilOrEmpty(manifest), want)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run(`manifest without "url"`, func(t *testing.T) {
|
||||
mocks.releases.GetLatest = func(registryExtensionID int32, releaseTag string, includeArtifacts bool) (*dbRelease, error) {
|
||||
return &dbRelease{
|
||||
Manifest: `{"name":"x"}`,
|
||||
}, nil
|
||||
}
|
||||
defer func() { mocks.releases.GetLatest = nil }()
|
||||
manifest, err := getExtensionManifestWithBundleURL(ctx, "x", 1, "t")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := `{"name":"x","url":"/-/static/extension/0.js?x---1fmlvpbbdw2yo"}`; manifest == nil || !jsonDeepEqual(*manifest, want) {
|
||||
t.Errorf("got %q, want %q", nilOrEmpty(manifest), want)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func jsonDeepEqual(a, b string) bool {
|
||||
var va, vb interface{}
|
||||
if err := json.Unmarshal([]byte(a), &va); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
if err := json.Unmarshal([]byte(b), &vb); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return reflect.DeepEqual(va, vb)
|
||||
}
|
||||
37
enterprise/cmd/frontend/internal/registry/extensions.go
Normal file
37
enterprise/cmd/frontend/internal/registry/extensions.go
Normal file
@ -0,0 +1,37 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
)
|
||||
|
||||
func init() {
|
||||
conf.DefaultRemoteRegistry = "https://sourcegraph.com/.api/registry"
|
||||
registry.GetLocalExtensionByExtensionID = func(ctx context.Context, extensionIDWithoutPrefix string) (graphqlbackend.RegistryExtension, error) {
|
||||
x, err := dbExtensions{}.GetByExtensionID(ctx, extensionIDWithoutPrefix)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := prefixLocalExtensionID(x); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &extensionDBResolver{v: x}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// prefixLocalExtensionID adds the local registry's extension ID prefix (from
|
||||
// GetLocalRegistryExtensionIDPrefix) to all extensions' extension IDs in the list.
|
||||
func prefixLocalExtensionID(xs ...*dbExtension) error {
|
||||
prefix := registry.GetLocalRegistryExtensionIDPrefix()
|
||||
if prefix == nil {
|
||||
return nil
|
||||
}
|
||||
for _, x := range xs {
|
||||
x.NonCanonicalExtensionID = *prefix + "/" + x.NonCanonicalExtensionID
|
||||
x.NonCanonicalRegistry = *prefix
|
||||
}
|
||||
return nil
|
||||
}
|
||||
323
enterprise/cmd/frontend/internal/registry/extensions_db.go
Normal file
323
enterprise/cmd/frontend/internal/registry/extensions_db.go
Normal file
@ -0,0 +1,323 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/keegancsmith/sqlf"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
|
||||
)
|
||||
|
||||
// dbExtension describes an extension in the extension registry.
|
||||
//
|
||||
// It is the internal form of github.com/sourcegraph/sourcegraph/pkg/registry.Extension (which is
|
||||
// the external API type). These types should generally be kept in sync, but registry.Extension
|
||||
// updates require backcompat.
|
||||
type dbExtension struct {
|
||||
ID int32
|
||||
UUID string
|
||||
Publisher dbPublisher
|
||||
Name string
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
|
||||
// NonCanonicalExtensionID is the denormalized fully qualified extension ID
|
||||
// ("[registry/]publisher/name" format), using the username/name of the extension's publisher
|
||||
// (joined from another table) as of when the query executed. Do not persist this, because the
|
||||
// (denormalized) registry and publisher names can change.
|
||||
//
|
||||
// If this value is obtained directly from a method on RegistryExtensions, this field will never
|
||||
// contain the registry name prefix (which is necessary to distinguish local extensions from
|
||||
// remote extensions). Call prefixLocalExtensionID to add it. The recommended way to apply this
|
||||
// automatically (when needed) is to use registry.GetExtensionByExtensionID instead of
|
||||
// (dbExtensions).GetByExtensionID.
|
||||
NonCanonicalExtensionID string
|
||||
|
||||
// NonCanonicalRegistry is the denormalized registry name (as of when this field was set). This
|
||||
// field is only set by prefixLocalExtensionID and is always empty if this value is obtained
|
||||
// directly from a method on RegistryExtensions. Do not persist this value, because the
|
||||
// (denormalized) registry name can change.
|
||||
NonCanonicalRegistry string
|
||||
}
|
||||
|
||||
type dbExtensions struct{}
|
||||
|
||||
// extensionNotFoundError occurs when an extension is not found in the extension registry.
|
||||
type extensionNotFoundError struct {
|
||||
args []interface{}
|
||||
}
|
||||
|
||||
// NotFound implements errcode.NotFounder.
|
||||
func (err extensionNotFoundError) NotFound() bool { return true }
|
||||
|
||||
func (err extensionNotFoundError) Error() string {
|
||||
return fmt.Sprintf("registry extension not found: %v", err.args)
|
||||
}
|
||||
|
||||
// Create creates a new extension in the extension registry. Exactly 1 of publisherUserID and publisherOrgID must be nonzero.
|
||||
func (s dbExtensions) Create(ctx context.Context, publisherUserID, publisherOrgID int32, name string) (id int32, err error) {
|
||||
if mocks.extensions.Create != nil {
|
||||
return mocks.extensions.Create(publisherUserID, publisherOrgID, name)
|
||||
}
|
||||
|
||||
if publisherUserID != 0 && publisherOrgID != 0 {
|
||||
return 0, errors.New("at most 1 of the publisher user/org may be set")
|
||||
}
|
||||
|
||||
uuid, err := uuid.NewRandom()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if err := dbconn.Global.QueryRowContext(ctx,
|
||||
// Include users/orgs table query (with "FOR UPDATE") to ensure that the publisher user/org
|
||||
// not been deleted. If it was deleted, the query will return an error.
|
||||
`
|
||||
INSERT INTO registry_extensions(uuid, publisher_user_id, publisher_org_id, name)
|
||||
VALUES(
|
||||
$1,
|
||||
(SELECT id FROM users WHERE id=$2 AND deleted_at IS NULL FOR UPDATE),
|
||||
(SELECT id FROM orgs WHERE id=$3 AND deleted_at IS NULL FOR UPDATE),
|
||||
$4
|
||||
)
|
||||
RETURNING id
|
||||
`,
|
||||
uuid, publisherUserID, publisherOrgID, name,
|
||||
).Scan(&id); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return id, nil
|
||||
}
|
||||
|
||||
// GetByID retrieves the registry extension (if any) given its ID.
|
||||
//
|
||||
// 🚨 SECURITY: The caller must ensure that the actor is permitted to view this registry extension.
|
||||
func (s dbExtensions) GetByID(ctx context.Context, id int32) (*dbExtension, error) {
|
||||
if mocks.extensions.GetByID != nil {
|
||||
return mocks.extensions.GetByID(id)
|
||||
}
|
||||
|
||||
results, err := s.list(ctx, []*sqlf.Query{sqlf.Sprintf("x.id=%d", id)}, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return nil, extensionNotFoundError{[]interface{}{id}}
|
||||
}
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
// GetByUUID retrieves the registry extension (if any) given its UUID.
|
||||
//
|
||||
// 🚨 SECURITY: The caller must ensure that the actor is permitted to view this registry extension.
|
||||
func (s dbExtensions) GetByUUID(ctx context.Context, uuid string) (*dbExtension, error) {
|
||||
if mocks.extensions.GetByUUID != nil {
|
||||
return mocks.extensions.GetByUUID(uuid)
|
||||
}
|
||||
|
||||
results, err := s.list(ctx, []*sqlf.Query{sqlf.Sprintf("x.uuid=%d", uuid)}, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return nil, extensionNotFoundError{[]interface{}{uuid}}
|
||||
}
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
const (
|
||||
// extensionPublisherNameExpr is the SQL expression for the extension's publisher's name (using
|
||||
// the table aliases created by (dbExtensions).listCountSQL.
|
||||
extensionPublisherNameExpr = "COALESCE(users.username, orgs.name)"
|
||||
|
||||
// extensionIDExpr is the SQL expression for the extension ID (using the table aliases created by
|
||||
// (dbExtensions).listCountSQL.
|
||||
extensionIDExpr = "CONCAT(" + extensionPublisherNameExpr + ", '/', x.name)"
|
||||
)
|
||||
|
||||
// GetByExtensionID retrieves the registry extension (if any) given its extension ID, which is the
|
||||
// concatenation of the publisher name, a slash ("/"), and the extension name.
|
||||
//
|
||||
// 🚨 SECURITY: The caller must ensure that the actor is permitted to view this registry extension.
|
||||
func (s dbExtensions) GetByExtensionID(ctx context.Context, extensionID string) (*dbExtension, error) {
|
||||
if mocks.extensions.GetByExtensionID != nil {
|
||||
return mocks.extensions.GetByExtensionID(extensionID)
|
||||
}
|
||||
|
||||
// TODO(sqs): prevent the creation of an org with the same name as a user so that there is no
|
||||
// ambiguity as to whether the publisher refers to a user or org by the given name
|
||||
// (https://github.com/sourcegraph/sourcegraph/issues/12068).
|
||||
parts := strings.SplitN(extensionID, "/", 2)
|
||||
if len(parts) < 2 {
|
||||
return nil, extensionNotFoundError{[]interface{}{fmt.Sprintf("extensionID %q", extensionID)}}
|
||||
}
|
||||
publisherName := parts[0]
|
||||
extensionName := parts[1]
|
||||
|
||||
results, err := s.list(ctx, []*sqlf.Query{
|
||||
sqlf.Sprintf("x.name=%s", extensionName),
|
||||
sqlf.Sprintf("(users.username=%s OR orgs.name=%s)", publisherName, publisherName),
|
||||
}, nil, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(results) == 0 {
|
||||
return nil, extensionNotFoundError{[]interface{}{fmt.Sprintf("extensionID %q", extensionID)}}
|
||||
}
|
||||
return results[0], nil
|
||||
}
|
||||
|
||||
// dbExtensionsListOptions contains options for listing registry extensions.
|
||||
type dbExtensionsListOptions struct {
|
||||
Publisher dbPublisher
|
||||
Query string // matches the extension ID
|
||||
PrioritizeExtensionIDs []string
|
||||
*db.LimitOffset
|
||||
}
|
||||
|
||||
func (o dbExtensionsListOptions) sqlConditions() []*sqlf.Query {
|
||||
var conds []*sqlf.Query
|
||||
if o.Publisher.UserID != 0 {
|
||||
conds = append(conds, sqlf.Sprintf("x.publisher_user_id=%d", o.Publisher.UserID))
|
||||
}
|
||||
if o.Publisher.OrgID != 0 {
|
||||
conds = append(conds, sqlf.Sprintf("x.publisher_org_id=%d", o.Publisher.OrgID))
|
||||
}
|
||||
if o.Query != "" {
|
||||
conds = append(conds, sqlf.Sprintf(extensionIDExpr+" ILIKE %s", "%"+strings.Replace(strings.ToLower(o.Query), " ", "%", -1)+"%"))
|
||||
}
|
||||
if len(conds) == 0 {
|
||||
conds = append(conds, sqlf.Sprintf("TRUE"))
|
||||
}
|
||||
return conds
|
||||
}
|
||||
|
||||
func (o dbExtensionsListOptions) sqlOrder() []*sqlf.Query {
|
||||
ids := make([]*sqlf.Query, len(o.PrioritizeExtensionIDs)+1)
|
||||
for i, id := range o.PrioritizeExtensionIDs {
|
||||
ids[i] = sqlf.Sprintf("%v", string(id))
|
||||
}
|
||||
ids[len(o.PrioritizeExtensionIDs)] = sqlf.Sprintf("NULL")
|
||||
return []*sqlf.Query{sqlf.Sprintf(extensionIDExpr+` IN (%v) ASC`, sqlf.Join(ids, ","))}
|
||||
}
|
||||
|
||||
// List lists all registry extensions that satisfy the options.
|
||||
//
|
||||
// 🚨 SECURITY: The caller must ensure that the actor is permitted to list with the specified
|
||||
// options.
|
||||
func (s dbExtensions) List(ctx context.Context, opt dbExtensionsListOptions) ([]*dbExtension, error) {
|
||||
return s.list(ctx, opt.sqlConditions(), opt.sqlOrder(), opt.LimitOffset)
|
||||
}
|
||||
|
||||
func (dbExtensions) listCountSQL(conds []*sqlf.Query) *sqlf.Query {
|
||||
return sqlf.Sprintf(`
|
||||
FROM registry_extensions x
|
||||
LEFT JOIN users ON users.id=publisher_user_id AND users.deleted_at IS NULL
|
||||
LEFT JOIN orgs ON orgs.id=publisher_org_id AND orgs.deleted_at IS NULL
|
||||
WHERE (%s) AND x.deleted_at IS NULL`,
|
||||
sqlf.Join(conds, ") AND ("))
|
||||
}
|
||||
|
||||
func (s dbExtensions) list(ctx context.Context, conds, order []*sqlf.Query, limitOffset *db.LimitOffset) ([]*dbExtension, error) {
|
||||
order = append(order, sqlf.Sprintf("TRUE"))
|
||||
q := sqlf.Sprintf(`
|
||||
SELECT x.id, x.uuid, x.publisher_user_id, x.publisher_org_id, x.name, x.created_at, x.updated_at,
|
||||
`+extensionIDExpr+` AS non_canonical_extension_id, `+extensionPublisherNameExpr+` AS non_canonical_publisher_name
|
||||
%s
|
||||
ORDER BY %s, x.id ASC
|
||||
%s`,
|
||||
s.listCountSQL(conds),
|
||||
sqlf.Join(order, ","),
|
||||
limitOffset.SQL(),
|
||||
)
|
||||
|
||||
rows, err := dbconn.Global.QueryContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
var results []*dbExtension
|
||||
for rows.Next() {
|
||||
var t dbExtension
|
||||
var publisherUserID, publisherOrgID sql.NullInt64
|
||||
if err := rows.Scan(&t.ID, &t.UUID, &publisherUserID, &publisherOrgID, &t.Name, &t.CreatedAt, &t.UpdatedAt, &t.NonCanonicalExtensionID, &t.Publisher.NonCanonicalName); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
t.Publisher.UserID = int32(publisherUserID.Int64)
|
||||
t.Publisher.OrgID = int32(publisherOrgID.Int64)
|
||||
results = append(results, &t)
|
||||
}
|
||||
return results, nil
|
||||
}
|
||||
|
||||
// Count counts all registry extensions that satisfy the options (ignoring limit and offset).
|
||||
//
|
||||
// 🚨 SECURITY: The caller must ensure that the actor is permitted to count the results.
|
||||
func (s dbExtensions) Count(ctx context.Context, opt dbExtensionsListOptions) (int, error) {
|
||||
q := sqlf.Sprintf("SELECT COUNT(*) %s", s.listCountSQL(opt.sqlConditions()))
|
||||
var count int
|
||||
if err := dbconn.Global.QueryRowContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...).Scan(&count); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
// Update updates information about the registry extension.
|
||||
func (dbExtensions) Update(ctx context.Context, id int32, name *string) error {
|
||||
if mocks.extensions.Update != nil {
|
||||
return mocks.extensions.Update(id, name)
|
||||
}
|
||||
|
||||
res, err := dbconn.Global.ExecContext(ctx,
|
||||
"UPDATE registry_extensions SET name=COALESCE($2, name), updated_at=now() WHERE id=$1 AND deleted_at IS NULL",
|
||||
id, name,
|
||||
)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nrows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nrows == 0 {
|
||||
return extensionNotFoundError{[]interface{}{id}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Delete marks an registry extension as deleted.
|
||||
func (dbExtensions) Delete(ctx context.Context, id int32) error {
|
||||
if mocks.extensions.Delete != nil {
|
||||
return mocks.extensions.Delete(id)
|
||||
}
|
||||
|
||||
res, err := dbconn.Global.ExecContext(ctx, "UPDATE registry_extensions SET deleted_at=now() WHERE id=$1 AND deleted_at IS NULL", id)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
nrows, err := res.RowsAffected()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if nrows == 0 {
|
||||
return extensionNotFoundError{[]interface{}{id}}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// mockExtensions mocks the registry extensions store.
|
||||
type mockExtensions struct {
|
||||
Create func(publisherUserID, publisherOrgID int32, name string) (int32, error)
|
||||
GetByID func(id int32) (*dbExtension, error)
|
||||
GetByUUID func(uuid string) (*dbExtension, error)
|
||||
GetByExtensionID func(extensionID string) (*dbExtension, error)
|
||||
Update func(id int32, name *string) error
|
||||
Delete func(id int32) error
|
||||
}
|
||||
273
enterprise/cmd/frontend/internal/registry/extensions_db_test.go
Normal file
273
enterprise/cmd/frontend/internal/registry/extensions_db_test.go
Normal file
@ -0,0 +1,273 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/errcode"
|
||||
)
|
||||
|
||||
// registryExtensionNamesForTests is a list of test cases containing valid and invalid registry
|
||||
// extension names.
|
||||
var registryExtensionNamesForTests = []struct {
|
||||
name string
|
||||
wantValid bool
|
||||
}{
|
||||
{"", false},
|
||||
{"a", true},
|
||||
{"-a", false},
|
||||
{"a-", false},
|
||||
{"a-b", true},
|
||||
{"a--b", false},
|
||||
{"a---b", false},
|
||||
{"a.b", true},
|
||||
{"a..b", false},
|
||||
{"a...b", false},
|
||||
{"a_b", true},
|
||||
{"a__b", false},
|
||||
{"a___b", false},
|
||||
{"a-.b", false},
|
||||
{"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", false},
|
||||
}
|
||||
|
||||
func TestRegistryExtensions_validUsernames(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
user, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
for _, test := range registryExtensionNamesForTests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
valid := true
|
||||
if _, err := (dbExtensions{}).Create(ctx, user.ID, 0, test.name); err != nil {
|
||||
if e, ok := err.(*pq.Error); ok && (e.Constraint == "registry_extensions_name_valid_chars" || e.Constraint == "registry_extensions_name_length") {
|
||||
valid = false
|
||||
} else {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
if valid != test.wantValid {
|
||||
t.Errorf("%q: got valid %v, want %v", test.name, valid, test.wantValid)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRegistryExtensions(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip()
|
||||
}
|
||||
ctx := dbtesting.TestContext(t)
|
||||
|
||||
testGetByID := func(t *testing.T, id int32, want *dbExtension, wantPublisherName string) {
|
||||
t.Helper()
|
||||
x, err := dbExtensions{}.GetByID(ctx, id)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(x, want) {
|
||||
t.Errorf("got %+v, want %+v", x, want)
|
||||
}
|
||||
if x.Publisher.NonCanonicalName != wantPublisherName {
|
||||
t.Errorf("got publisher name %q, want %q", x.Publisher.NonCanonicalName, wantPublisherName)
|
||||
}
|
||||
}
|
||||
testGetByExtensionID := func(t *testing.T, extensionID string, want *dbExtension) {
|
||||
t.Helper()
|
||||
x, err := dbExtensions{}.GetByExtensionID(ctx, extensionID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(x, want) {
|
||||
t.Errorf("got %+v, want %+v", x, want)
|
||||
}
|
||||
if x.NonCanonicalExtensionID != extensionID {
|
||||
t.Errorf("got extension ID %q, want %q", x.NonCanonicalExtensionID, extensionID)
|
||||
}
|
||||
}
|
||||
testList := func(t *testing.T, opt dbExtensionsListOptions, want []*dbExtension) {
|
||||
t.Helper()
|
||||
if ois, err := (dbExtensions{}).List(ctx, opt); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if !reflect.DeepEqual(ois, want) {
|
||||
t.Errorf("got %s, want %s", asJSON(t, ois), asJSON(t, want))
|
||||
}
|
||||
}
|
||||
testListCount := func(t *testing.T, opt dbExtensionsListOptions, want []*dbExtension) {
|
||||
t.Helper()
|
||||
testList(t, opt, want)
|
||||
if n, err := (dbExtensions{}).Count(ctx, opt); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if want := len(want); n != want {
|
||||
t.Errorf("got %d, want %d", n, want)
|
||||
}
|
||||
}
|
||||
|
||||
user, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
org, err := db.Orgs.Create(ctx, "o", nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
createAndGet := func(t *testing.T, publisherUserID, publisherOrgID int32, name string) *dbExtension {
|
||||
t.Helper()
|
||||
xID, err := dbExtensions{}.Create(ctx, publisherUserID, publisherOrgID, name)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
x, err := dbExtensions{}.GetByID(ctx, xID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return x
|
||||
}
|
||||
xu := createAndGet(t, user.ID, 0, "xu")
|
||||
xo := createAndGet(t, 0, org.ID, "xo")
|
||||
|
||||
t.Run("List/Count/Get publishers", func(t *testing.T) {
|
||||
publishers, err := dbExtensions{}.ListPublishers(ctx, dbPublishersListOptions{})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if want := []*dbPublisher{
|
||||
&xo.Publisher,
|
||||
&xu.Publisher,
|
||||
}; !reflect.DeepEqual(publishers, want) {
|
||||
t.Errorf("got publishers %+v, want %+v", publishers, want)
|
||||
}
|
||||
|
||||
if n, err := (dbExtensions{}).CountPublishers(ctx, dbPublishersListOptions{}); err != nil {
|
||||
t.Fatal(err)
|
||||
} else if want := 2; n != 2 {
|
||||
t.Errorf("got count %d, want %d", n, want)
|
||||
}
|
||||
|
||||
for _, p := range []*dbPublisher{&xo.Publisher, &xu.Publisher} {
|
||||
got, err := dbExtensions{}.GetPublisher(ctx, p.NonCanonicalName)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if !reflect.DeepEqual(got, p) {
|
||||
t.Errorf("got %+v, want %+v", got, p)
|
||||
}
|
||||
}
|
||||
if _, err := (dbExtensions{}).GetPublisher(ctx, "doesntexist"); !errcode.IsNotFound(err) {
|
||||
t.Errorf("got err %v, want errcode.IsNotFound", err)
|
||||
}
|
||||
})
|
||||
|
||||
publishers := map[string]struct {
|
||||
publisherUserID, publisherOrgID int32
|
||||
publisherName string
|
||||
}{
|
||||
"user": {publisherUserID: user.ID, publisherName: "u"},
|
||||
"org": {publisherOrgID: org.ID, publisherName: "o"},
|
||||
}
|
||||
for name, c := range publishers {
|
||||
t.Run(name+" publisher", func(t *testing.T) {
|
||||
x := createAndGet(t, c.publisherUserID, c.publisherOrgID, "x")
|
||||
|
||||
t.Run("GetByID", func(t *testing.T) {
|
||||
testGetByID(t, x.ID, x, c.publisherName)
|
||||
if _, err := (dbExtensions{}).GetByID(ctx, 12345 /* doesn't exist */); !errcode.IsNotFound(err) {
|
||||
t.Errorf("got err %v, want errcode.IsNotFound", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("GetByExtensionID", func(t *testing.T) {
|
||||
testGetByExtensionID(t, c.publisherName+"/"+x.Name, x)
|
||||
if _, err := (dbExtensions{}).GetByExtensionID(ctx, "foo.bar"); !errcode.IsNotFound(err) {
|
||||
t.Errorf("got err %v, want errcode.IsNotFound", err)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("List/Count all", func(t *testing.T) {
|
||||
testListCount(t, dbExtensionsListOptions{}, []*dbExtension{xu, xo, x})
|
||||
})
|
||||
wantByPublisherUser := []*dbExtension{xu}
|
||||
wantByPublisherOrg := []*dbExtension{xo}
|
||||
var wantByCurrent []*dbExtension
|
||||
if c.publisherUserID != 0 {
|
||||
wantByPublisherUser = append(wantByPublisherUser, x)
|
||||
wantByCurrent = wantByPublisherUser
|
||||
} else {
|
||||
wantByPublisherOrg = append(wantByPublisherOrg, x)
|
||||
wantByCurrent = wantByPublisherOrg
|
||||
}
|
||||
t.Run("List/Count by PublisherUserID", func(t *testing.T) {
|
||||
testListCount(t, dbExtensionsListOptions{Publisher: dbPublisher{UserID: user.ID}}, wantByPublisherUser)
|
||||
})
|
||||
t.Run("List/Count by Publisher.OrgID", func(t *testing.T) {
|
||||
testListCount(t, dbExtensionsListOptions{Publisher: dbPublisher{OrgID: org.ID}}, wantByPublisherOrg)
|
||||
})
|
||||
t.Run("List/Count by Publisher.Query all", func(t *testing.T) {
|
||||
testListCount(t, dbExtensionsListOptions{Query: "x"}, []*dbExtension{xu, xo, x})
|
||||
})
|
||||
t.Run("List/Count by Publisher.Query one", func(t *testing.T) {
|
||||
testListCount(t, dbExtensionsListOptions{Query: c.publisherName + "/" + x.Name}, wantByCurrent)
|
||||
})
|
||||
t.Run("List/Count with prioritizeExtensionIDs", func(t *testing.T) {
|
||||
testList(t, dbExtensionsListOptions{PrioritizeExtensionIDs: []string{xu.NonCanonicalExtensionID}, LimitOffset: &db.LimitOffset{Limit: 1}}, []*dbExtension{xu})
|
||||
testList(t, dbExtensionsListOptions{PrioritizeExtensionIDs: []string{xo.NonCanonicalExtensionID}, LimitOffset: &db.LimitOffset{Limit: 1}}, []*dbExtension{xo})
|
||||
})
|
||||
|
||||
if err := (dbExtensions{}).Delete(ctx, x.ID); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if err := (dbExtensions{}).Delete(ctx, x.ID); !errcode.IsNotFound(err) {
|
||||
t.Errorf("2nd Delete: got err %v, want errcode.IsNotFound", err)
|
||||
}
|
||||
if _, err := (dbExtensions{}).GetByID(ctx, x.ID); !errcode.IsNotFound(err) {
|
||||
t.Errorf("GetByID after Delete: got err %v, want errcode.IsNotFound", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
t.Run("Update", func(t *testing.T) {
|
||||
x := xu
|
||||
if err := (dbExtensions{}).Update(ctx, x.ID, nil); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
x1, err := dbExtensions{}.GetByID(ctx, x.ID)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if time.Since(x1.UpdatedAt) > 1*time.Minute {
|
||||
t.Errorf("got UpdatedAt %v, want recent", x1.UpdatedAt)
|
||||
}
|
||||
if x1.Name != x.Name {
|
||||
t.Errorf("got name %q, want %q", x1.Name, x.Name)
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("Create with same publisher and name", func(t *testing.T) {
|
||||
_, err := dbExtensions{}.Create(ctx, user.ID, 0, "zzz")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
if _, err := (dbExtensions{}).Create(ctx, user.ID, 0, "zzz"); err == nil {
|
||||
t.Fatal("err == nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func asJSON(t *testing.T, v interface{}) string {
|
||||
b, err := json.MarshalIndent(v, "", " ")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
return string(b)
|
||||
}
|
||||
239
enterprise/cmd/frontend/internal/registry/http_api.go
Normal file
239
enterprise/cmd/frontend/internal/registry/http_api.go
Normal file
@ -0,0 +1,239 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/prometheus/client_golang/prometheus"
|
||||
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/conf"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/errcode"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/honey"
|
||||
"github.com/sourcegraph/sourcegraph/pkg/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
frontendregistry.HandleRegistry = handleRegistry
|
||||
}
|
||||
|
||||
// Funcs called by serveRegistry to get registry data. If fakeRegistryData is set, it is used as
|
||||
// the data source instead of the database.
|
||||
var (
|
||||
registryList = func(ctx context.Context, opt dbExtensionsListOptions) ([]*registry.Extension, error) {
|
||||
vs, err := dbExtensions{}.List(ctx, opt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
xs := make([]*registry.Extension, len(vs))
|
||||
for i, v := range vs {
|
||||
x, err := toRegistryAPIExtension(ctx, v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
xs[i] = x
|
||||
}
|
||||
return xs, nil
|
||||
}
|
||||
|
||||
registryGetByUUID = func(ctx context.Context, uuid string) (*registry.Extension, error) {
|
||||
x, err := dbExtensions{}.GetByUUID(ctx, uuid)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toRegistryAPIExtension(ctx, x)
|
||||
}
|
||||
|
||||
registryGetByExtensionID = func(ctx context.Context, extensionID string) (*registry.Extension, error) {
|
||||
x, err := dbExtensions{}.GetByExtensionID(ctx, extensionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return toRegistryAPIExtension(ctx, x)
|
||||
}
|
||||
)
|
||||
|
||||
func toRegistryAPIExtension(ctx context.Context, v *dbExtension) (*registry.Extension, error) {
|
||||
manifest, err := getExtensionManifestWithBundleURL(ctx, v.NonCanonicalExtensionID, v.ID, "release")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
baseURL := strings.TrimSuffix(conf.Get().AppURL, "/")
|
||||
return ®istry.Extension{
|
||||
UUID: v.UUID,
|
||||
ExtensionID: v.NonCanonicalExtensionID,
|
||||
Publisher: registry.Publisher{
|
||||
Name: v.Publisher.NonCanonicalName,
|
||||
URL: baseURL + frontendregistry.PublisherExtensionsURL(v.Publisher.UserID != 0, v.Publisher.OrgID != 0, v.Publisher.NonCanonicalName),
|
||||
},
|
||||
Name: v.Name,
|
||||
Manifest: manifest,
|
||||
CreatedAt: v.CreatedAt,
|
||||
UpdatedAt: v.UpdatedAt,
|
||||
URL: baseURL + frontendregistry.ExtensionURL(v.NonCanonicalExtensionID),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// handleRegistry serves the external HTTP API for the extension registry.
|
||||
func handleRegistry(w http.ResponseWriter, r *http.Request) (err error) {
|
||||
if conf.Extensions() == nil {
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return
|
||||
}
|
||||
|
||||
builder := honey.Builder("registry")
|
||||
builder.AddField("api_version", r.Header.Get("Accept"))
|
||||
builder.AddField("url", r.URL.String())
|
||||
ev := builder.NewEvent()
|
||||
defer func() {
|
||||
ev.AddField("success", err == nil)
|
||||
if err == nil {
|
||||
registryRequestsSuccessCounter.Inc()
|
||||
} else {
|
||||
registryRequestsErrorCounter.Inc()
|
||||
ev.AddField("error", err.Error())
|
||||
}
|
||||
ev.Send()
|
||||
}()
|
||||
|
||||
// Identify this response as coming from the registry API.
|
||||
w.Header().Set(registry.MediaTypeHeaderName, registry.MediaType)
|
||||
w.Header().Set("Vary", registry.MediaTypeHeaderName)
|
||||
|
||||
// Validate API version.
|
||||
if v := r.Header.Get("Accept"); v != registry.AcceptHeader {
|
||||
http.Error(w, fmt.Sprintf("invalid Accept header: expected %q", registry.AcceptHeader), http.StatusBadRequest)
|
||||
return nil
|
||||
}
|
||||
|
||||
// This handler can be mounted at either /.internal or /.api.
|
||||
urlPath := r.URL.Path
|
||||
switch {
|
||||
case strings.HasPrefix(urlPath, "/.internal"):
|
||||
urlPath = strings.TrimPrefix(urlPath, "/.internal")
|
||||
case strings.HasPrefix(urlPath, "/.api"):
|
||||
urlPath = strings.TrimPrefix(urlPath, "/.api")
|
||||
}
|
||||
|
||||
const extensionsPath = "/registry/extensions"
|
||||
var result interface{}
|
||||
switch {
|
||||
case urlPath == extensionsPath:
|
||||
query := r.URL.Query().Get("q")
|
||||
ev.AddField("query", query)
|
||||
xs, err := registryList(r.Context(), dbExtensionsListOptions{Query: query})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ev.AddField("results_count", len(xs))
|
||||
result = xs
|
||||
|
||||
case strings.HasPrefix(urlPath, extensionsPath+"/"):
|
||||
var (
|
||||
spec = strings.TrimPrefix(urlPath, extensionsPath+"/")
|
||||
x *registry.Extension
|
||||
err error
|
||||
)
|
||||
switch {
|
||||
case strings.HasPrefix(spec, "uuid/"):
|
||||
x, err = registryGetByUUID(r.Context(), strings.TrimPrefix(spec, "uuid/"))
|
||||
case strings.HasPrefix(spec, "extension-id/"):
|
||||
x, err = registryGetByExtensionID(r.Context(), strings.TrimPrefix(spec, "extension-id/"))
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return nil
|
||||
}
|
||||
if x == nil || err != nil {
|
||||
if x == nil || errcode.IsNotFound(err) {
|
||||
w.Header().Set("Cache-Control", "max-age=5, private")
|
||||
http.Error(w, "extension not found", http.StatusNotFound)
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
ev.AddField("extension-id", x.ExtensionID)
|
||||
result = x
|
||||
|
||||
default:
|
||||
w.WriteHeader(http.StatusNotFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
w.Header().Set("Cache-Control", "max-age=30, private")
|
||||
data, err := json.Marshal(result)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
w.Write(data)
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
registryRequestsSuccessCounter = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "src",
|
||||
Subsystem: "registry",
|
||||
Name: "requests_success",
|
||||
Help: "Number of successful requests (HTTP 200) to the HTTP registry API",
|
||||
})
|
||||
registryRequestsErrorCounter = prometheus.NewCounter(prometheus.CounterOpts{
|
||||
Namespace: "src",
|
||||
Subsystem: "registry",
|
||||
Name: "requests_error",
|
||||
Help: "Number of failed (non-HTTP 200) requests to the HTTP registry API",
|
||||
})
|
||||
)
|
||||
|
||||
func init() {
|
||||
prometheus.MustRegister(registryRequestsSuccessCounter)
|
||||
prometheus.MustRegister(registryRequestsErrorCounter)
|
||||
}
|
||||
|
||||
func init() {
|
||||
// Allow providing fake registry data for local dev (intended for use in local dev only).
|
||||
//
|
||||
// If FAKE_REGISTRY is set and refers to a valid JSON file (of []*registry.Extension), is used
|
||||
// by serveRegistry (instead of the DB) as the source for registry data.
|
||||
path := os.Getenv("FAKE_REGISTRY")
|
||||
if path == "" {
|
||||
return
|
||||
}
|
||||
|
||||
readFakeExtensions := func() ([]*registry.Extension, error) {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var xs []*registry.Extension
|
||||
if err := json.Unmarshal(data, &xs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return xs, nil
|
||||
}
|
||||
|
||||
registryList = func(ctx context.Context, opt dbExtensionsListOptions) ([]*registry.Extension, error) {
|
||||
xs, err := readFakeExtensions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return frontendregistry.FilterRegistryExtensions(xs, opt.Query), nil
|
||||
}
|
||||
registryGetByUUID = func(ctx context.Context, uuid string) (*registry.Extension, error) {
|
||||
xs, err := readFakeExtensions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return frontendregistry.FindRegistryExtension(xs, "uuid", uuid), nil
|
||||
}
|
||||
registryGetByExtensionID = func(ctx context.Context, extensionID string) (*registry.Extension, error) {
|
||||
xs, err := readFakeExtensions()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return frontendregistry.FindRegistryExtension(xs, "extensionID", extensionID), nil
|
||||
}
|
||||
}
|
||||
12
enterprise/cmd/frontend/internal/registry/mock_db.go
Normal file
12
enterprise/cmd/frontend/internal/registry/mock_db.go
Normal file
@ -0,0 +1,12 @@
|
||||
package registry
|
||||
|
||||
func resetMocks() {
|
||||
mocks = dbMocks{}
|
||||
}
|
||||
|
||||
type dbMocks struct {
|
||||
extensions mockExtensions
|
||||
releases mockReleases
|
||||
}
|
||||
|
||||
var mocks dbMocks
|
||||
@ -0,0 +1,74 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend/graphqlutil"
|
||||
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
frontendregistry.ExtensionRegistry.PublishersFunc = extensionRegistryPublishers
|
||||
}
|
||||
|
||||
func extensionRegistryPublishers(ctx context.Context, args *graphqlutil.ConnectionArgs) (graphqlbackend.RegistryPublisherConnection, error) {
|
||||
var opt dbPublishersListOptions
|
||||
args.Set(&opt.LimitOffset)
|
||||
return ®istryPublisherConnection{opt: opt}, nil
|
||||
}
|
||||
|
||||
// registryPublisherConnection resolves a list of registry publishers.
|
||||
type registryPublisherConnection struct {
|
||||
opt dbPublishersListOptions
|
||||
|
||||
// cache results because they are used by multiple fields
|
||||
once sync.Once
|
||||
registryPublishers []*dbPublisher
|
||||
err error
|
||||
}
|
||||
|
||||
func (r *registryPublisherConnection) compute(ctx context.Context) ([]*dbPublisher, error) {
|
||||
r.once.Do(func() {
|
||||
opt2 := r.opt
|
||||
if opt2.LimitOffset != nil {
|
||||
tmp := *opt2.LimitOffset
|
||||
opt2.LimitOffset = &tmp
|
||||
opt2.Limit++ // so we can detect if there is a next page
|
||||
}
|
||||
|
||||
r.registryPublishers, r.err = dbExtensions{}.ListPublishers(ctx, opt2)
|
||||
})
|
||||
return r.registryPublishers, r.err
|
||||
}
|
||||
|
||||
func (r *registryPublisherConnection) Nodes(ctx context.Context) ([]graphqlbackend.RegistryPublisher, error) {
|
||||
publishers, err := r.compute(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var l []graphqlbackend.RegistryPublisher
|
||||
for _, publisher := range publishers {
|
||||
p, err := getRegistryPublisher(ctx, *publisher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l = append(l, p)
|
||||
}
|
||||
return l, nil
|
||||
}
|
||||
|
||||
func (r *registryPublisherConnection) TotalCount(ctx context.Context) (int32, error) {
|
||||
count, err := dbExtensions{}.CountPublishers(ctx, r.opt)
|
||||
return int32(count), err
|
||||
}
|
||||
|
||||
func (r *registryPublisherConnection) PageInfo(ctx context.Context) (*graphqlutil.PageInfo, error) {
|
||||
publishers, err := r.compute(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return graphqlutil.HasNextPage(r.opt.LimitOffset != nil && len(publishers) > r.opt.Limit), nil
|
||||
}
|
||||
148
enterprise/cmd/frontend/internal/registry/publisher_graphql.go
Normal file
148
enterprise/cmd/frontend/internal/registry/publisher_graphql.go
Normal file
@ -0,0 +1,148 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
graphql "github.com/graph-gophers/graphql-go"
|
||||
"github.com/graph-gophers/graphql-go/relay"
|
||||
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
|
||||
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
|
||||
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
|
||||
)
|
||||
|
||||
func init() {
|
||||
frontendregistry.ExtensionRegistry.ViewerPublishersFunc = extensionRegistryViewerPublishers
|
||||
}
|
||||
|
||||
func extensionRegistryViewerPublishers(ctx context.Context) ([]graphqlbackend.RegistryPublisher, error) {
|
||||
// The feature check here makes it so the any "New extension" form will show an error, so the
|
||||
// user finds out before trying to submit the form that the feature is disabled.
|
||||
if err := licensing.CheckFeature(licensing.FeatureExtensionRegistry); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var publishers []graphqlbackend.RegistryPublisher
|
||||
user, err := graphqlbackend.CurrentUser(ctx)
|
||||
if err != nil || user == nil {
|
||||
return nil, err
|
||||
}
|
||||
publishers = append(publishers, ®istryPublisher{user: user})
|
||||
|
||||
orgs, err := db.Orgs.GetByUserID(ctx, user.SourcegraphID())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for _, org := range orgs {
|
||||
publishers = append(publishers, ®istryPublisher{org: graphqlbackend.NewOrg(org)})
|
||||
}
|
||||
return publishers, nil
|
||||
}
|
||||
|
||||
// registryPublisher implements the GraphQL type RegistryPublisher.
|
||||
type registryPublisher struct {
|
||||
user *graphqlbackend.UserResolver
|
||||
org *graphqlbackend.OrgResolver
|
||||
}
|
||||
|
||||
var _ graphqlbackend.RegistryPublisher = ®istryPublisher{}
|
||||
|
||||
func (r *registryPublisher) ToUser() (*graphqlbackend.UserResolver, bool) {
|
||||
return r.user, r.user != nil
|
||||
}
|
||||
func (r *registryPublisher) ToOrg() (*graphqlbackend.OrgResolver, bool) { return r.org, r.org != nil }
|
||||
|
||||
func (r *registryPublisher) toDBRegistryPublisher() dbPublisher {
|
||||
switch {
|
||||
case r.user != nil:
|
||||
return dbPublisher{UserID: r.user.SourcegraphID(), NonCanonicalName: r.user.Username()}
|
||||
case r.org != nil:
|
||||
return dbPublisher{OrgID: r.org.OrgID(), NonCanonicalName: r.org.Name()}
|
||||
default:
|
||||
return dbPublisher{}
|
||||
}
|
||||
}
|
||||
|
||||
func (r *registryPublisher) RegistryExtensionConnectionURL() (*string, error) {
|
||||
p := r.toDBRegistryPublisher()
|
||||
url := frontendregistry.PublisherExtensionsURL(p.UserID != 0, p.OrgID != 0, p.NonCanonicalName)
|
||||
if url == "" {
|
||||
return nil, errRegistryUnknownPublisher
|
||||
}
|
||||
return &url, nil
|
||||
}
|
||||
|
||||
var errRegistryUnknownPublisher = errors.New("unknown registry extension publisher")
|
||||
|
||||
func getRegistryPublisher(ctx context.Context, publisher dbPublisher) (*registryPublisher, error) {
|
||||
switch {
|
||||
case publisher.UserID != 0:
|
||||
user, err := graphqlbackend.UserByIDInt32(ctx, publisher.UserID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ®istryPublisher{user: user}, nil
|
||||
case publisher.OrgID != 0:
|
||||
org, err := graphqlbackend.OrgByIDInt32(ctx, publisher.OrgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ®istryPublisher{org: org}, nil
|
||||
default:
|
||||
return nil, errRegistryUnknownPublisher
|
||||
}
|
||||
}
|
||||
|
||||
type registryPublisherID struct {
|
||||
userID, orgID int32
|
||||
}
|
||||
|
||||
func toRegistryPublisherID(extension *dbExtension) *registryPublisherID {
|
||||
return ®istryPublisherID{
|
||||
userID: extension.Publisher.UserID,
|
||||
orgID: extension.Publisher.OrgID,
|
||||
}
|
||||
}
|
||||
|
||||
// unmarshalRegistryPublisherID unmarshals the GraphQL ID into the possible publisher ID
|
||||
// types.
|
||||
//
|
||||
// 🚨 SECURITY
|
||||
func unmarshalRegistryPublisherID(id graphql.ID) (*registryPublisherID, error) {
|
||||
var (
|
||||
p registryPublisherID
|
||||
err error
|
||||
)
|
||||
switch kind := relay.UnmarshalKind(id); kind {
|
||||
case "User":
|
||||
p.userID, err = graphqlbackend.UnmarshalUserID(id)
|
||||
case "Org":
|
||||
p.orgID, err = graphqlbackend.UnmarshalOrgID(id)
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown registry extension publisher type: %q", kind)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &p, nil
|
||||
}
|
||||
|
||||
// viewerCanAdminister returns whether the current user is allowed to perform mutations on a
|
||||
// registry extension with the given publisher.
|
||||
//
|
||||
// 🚨 SECURITY
|
||||
func (p *registryPublisherID) viewerCanAdminister(ctx context.Context) error {
|
||||
switch {
|
||||
case p.userID != 0:
|
||||
// 🚨 SECURITY: Check that the current user is either the publisher or a site admin.
|
||||
return backend.CheckSiteAdminOrSameUser(ctx, p.userID)
|
||||
case p.orgID != 0:
|
||||
// 🚨 SECURITY: Check that the current user is a member of the publisher org.
|
||||
return backend.CheckOrgAccess(ctx, p.orgID)
|
||||
default:
|
||||
return errRegistryUnknownPublisher
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user