add enterprise/ directory

This commit is contained in:
Beyang Liu 2018-10-26 15:59:32 -07:00
parent 200fb1bed9
commit ffd2ccfc84
315 changed files with 64956 additions and 0 deletions

18
enterprise/.gitignore vendored Normal file
View File

@ -0,0 +1,18 @@
cmd/server/dockerfile.go
# Web
.nyc_output/
coverage/
node_modules/
out/
package-lock.json
package-lock.json
puppeteer/
src/backend/graphqlschema.ts
yarn-error.log
/ui/assets/*
!/ui/assets/img
!/ui/assets/img/*
/.bin
/cmd/frontend/internal/assets/distassets_vfsdata.go
/vendor/.bin

View File

@ -0,0 +1,17 @@
.bin/
*.bundle.*
client/phabricator/scripts/loader.js
cmd/frontend/internal/db/schema.md
cmd/xlang-python/python-langserver/
package-lock.json
package.json
ui/assets/
vendor/
.nyc_output/
coverage/
out/
src/backend/graphqlschema.ts
src/schema/
ts-node-*
xlang/testdata
cmd/xlang-go/internal/server/testdata/

View File

@ -0,0 +1,3 @@
{
"extends": ["@sourcegraph/stylelint-config"]
}

12
enterprise/.vscode/settings.json vendored Normal file
View File

@ -0,0 +1,12 @@
{
"typescript.tsdk": "node_modules/typescript/lib",
"files.associations": {
"**/dev/config.json": "jsonc"
},
"json.schemas": [
{
"fileMatch": ["dev/config.json"],
"url": "/vendor/github.com/sourcegraph/sourcegraph/schema/site.schema.json"
}
]
}

78
enterprise/README.dev.md Normal file
View File

@ -0,0 +1,78 @@
# Sourcegraph development process
1. Configure your repository with two remotes, `oss` and `ent`, by running the following:
```bash
enterprise/dev/init-repo.sh
```
After running this script, the following should hold:
- `oss` should point to `https://github.com/sourcegraph/sourcegraph`.
- `ent` should point to `https://github.com/sourcegraph/enterprise`.
- There should be no `origin` remote.
2. Decide whether you will branch off the open-source (`oss`) or enterprise (`ent`) repo. When in
doubt, prefer `oss`.
## Developing off `oss`
1. Run `dev/launch.sh` from the root of this repository.
1. Push your branch up to `oss` and open a PR.
1. Wait for CI to pass, resolve any merge conflicts.
1. Merge the PR into `oss` `master`. This will trigger another PR for you in `ent` that contains the
commits you just merged. This PR will be automatically merged if `ent` CI passes and there are no
merge conflicts.
1. If the `ent` PR cannot be automatically merged, update the PR to resolve any CI errors or merge
conflicts.
- If you make any changes to OSS code in the `ent` repository, you are responsible for ensuring
these changes make it into `oss`.
## Developing off `ent`
If a substantial subset of your change can be made as an independent change to `oss`, prefer to do
that.
1. Run `dev/start.sh` from `enterprise` directory.
1. Push your branch up to `ent` and open a PR.
1. Wait for CI to pass, resolve any merge conflicts.
1. Merge the PR into `ent` `master`.
1. Cherry-pick your changes onto `oss/master` using `dev/prune-pick.sh`.
1. If there are no conflicts, push directly to `oss/master`. If there are conflicts, resolve them,
push to a `oss` branch, wait for CI, and then merge into `oss/master`.
- If you make any additional changes, you are responsible for syncing these into `ent`.
### Build notes
**IMPORTANT:** Commands that build enterprise targets (e.g., `go build`, `yarn`,
`enterprise/dev/go-install.sh`) should always be run with the `enterprise` directory as the current
working directory. Otherwise, build tools like `yarn` and `go` may try to update the root
`package.json` and `go.mod` files as a side effect, instead of updating `enterprise/package.json`
and `enterprise/go.mod`.
The OSS web app is `yarn link`ed into `enterprise/node_modules`. It will run both the build of the
enterprise webapp as well as the part of the build for the OSS repo that generates the distributed
files for the npm package.
## Fallback syncing
Following the above instructions should prevent most long-term divergence between `oss` and
`ent`. Divergence between the two can easily be tested by running `dev/git-diff-no-enterprise.sh oss/master ent/master`.
If `oss` and `ent` diverge too severely to the point where it becomes onerous to cherry-pick
specific commits between the two, syncing can be accomplished by merging `oss/master` into
`ent/master`:
```
git fetch oss
git fetch ent
git checkout ent/master -b ent-master
git merge oss/master --no-commit -X theirs # this allows us to view an accurate history of oss code in the enterprise repo
git checkout oss/master -- . # overwrite all oss code in the enterprise repo with the state from oss repo
git reset enterprise && git checkout enterprise
git commit -m"Sync oss to ent $(date '+%Y-%m-%d')"
git push ent HEAD:master
```
Note that this means that `oss` is the source of truth for all code outside the `enterprise`
directory.

25
enterprise/README.md Normal file
View File

@ -0,0 +1,25 @@
# Sourcegraph Enterprise
[![build](https://badge.buildkite.com/f0e47ba39d32616d973b38e846f8e1aa25893920047221738e.svg?branch=master)](https://buildkite.com/sourcegraph/enterprise)
[![codecov](https://codecov.io/gh/sourcegraph/enterprise/branch/master/graph/badge.svg?token=itk6ydR7l3)](https://codecov.io/gh/sourcegraph/enterprise)
[![code style: prettier](https://img.shields.io/badge/code_style-prettier-ff69b4.svg)](https://github.com/prettier/prettier)
This repository contains all of the Sourcegraph Enterprise code.
## Project layout
- The main Sourcegraph codebase is open source, see [github.com/sourcegraph/sourcegraph](https://github.com/sourcegraph/sourcegraph).
- This codebase just wraps the open source codebase and links in some private code for enterprise features.
- Only the enterprise codebase is published to e.g. Docker Hub. Enterprise features are behind paywalls (or good-faith). The open-source codebase is not published on Docker Hub (this avoids confusion and keeps the upgrade/downgrade process from open source <-> enterprise easy).
## Dev
See [README.dev.md](README.dev.md).
### Updating dependencies
- `go get -u $MODULE` to update `$MODULE` and all its transitive dependencies to their latest version.
- `go mod edit -replace $MODULE@$VERSION` to update `$MODULE` to a specific version.
- `go mod tidy` if updates to `go.mod` or `go.sum` have been made as a result of other build
invocations during development and you wish now to update `go.mod` and `go.sum` to be consistent
with how the build will run in CI.

View File

@ -0,0 +1,17 @@
// @ts-check
/** @type {import('@babel/core').TransformOptions} */
const config = {
plugins: ['@babel/plugin-syntax-dynamic-import', 'babel-plugin-lodash'],
presets: [
[
'@babel/preset-env',
{
useBuiltIns: 'entry',
modules: false,
},
],
],
}
module.exports = config

View File

@ -0,0 +1,38 @@
package httpheader
import (
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
)
// getProviderConfig returns the HTTP header auth provider config. At most 1 can be specified in
// site config; if there is more than 1, it returns multiple == true (which the caller should handle
// by returning an error and refusing to proceed with auth).
func getProviderConfig() (pc *schema.HTTPHeaderAuthProvider, multiple bool) {
for _, p := range conf.Get().AuthProviders {
if p.HttpHeader != nil {
if pc != nil {
return pc, true // multiple http-header auth providers
}
pc = p.HttpHeader
}
}
return pc, false
}
func init() {
conf.ContributeValidator(validateConfig)
}
func validateConfig(c schema.SiteConfiguration) (problems []string) {
var httpHeaderAuthProviders int
for _, p := range c.AuthProviders {
if p.HttpHeader != nil {
httpHeaderAuthProviders++
}
}
if httpHeaderAuthProviders >= 2 {
problems = append(problems, `at most 1 http-header auth provider may be used`)
}
return problems
}

View File

@ -0,0 +1,38 @@
package httpheader
import (
"testing"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
)
func TestValidateCustom(t *testing.T) {
tests := map[string]struct {
input schema.SiteConfiguration
wantProblems []string
}{
"single": {
input: schema.SiteConfiguration{
AuthProviders: []schema.AuthProviders{
{HttpHeader: &schema.HTTPHeaderAuthProvider{Type: "http-header"}},
},
},
wantProblems: nil,
},
"multiple": {
input: schema.SiteConfiguration{
AuthProviders: []schema.AuthProviders{
{HttpHeader: &schema.HTTPHeaderAuthProvider{Type: "http-header"}},
{HttpHeader: &schema.HTTPHeaderAuthProvider{Type: "http-header"}},
},
},
wantProblems: []string{"at most 1"},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
conf.TestValidator(t, test.input, validateConfig, test.wantProblems)
})
}
}

View File

@ -0,0 +1,49 @@
package httpheader
import (
"reflect"
"sync"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
log15 "gopkg.in/inconshreveable/log15.v2"
)
// Watch for configuration changes related to the http-header auth provider.
func init() {
var (
init = true
mu sync.Mutex
pc *schema.HTTPHeaderAuthProvider
pi auth.Provider
)
conf.Watch(func() {
mu.Lock()
defer mu.Unlock()
// Only react when the config changes.
newPC, _ := getProviderConfig()
if reflect.DeepEqual(newPC, pc) {
return
}
if !init {
log15.Info("Reloading changed http-header authentication provider configuration.")
}
updates := map[auth.Provider]bool{}
var newPI auth.Provider
if newPC != nil {
newPI = &provider{c: newPC}
updates[newPI] = true
}
if pi != nil {
updates[pi] = false
}
auth.UpdateProviders(updates)
pc = newPC
pi = newPI
})
init = false
}

View File

@ -0,0 +1,96 @@
package httpheader
import (
"net/http"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/actor"
log15 "gopkg.in/inconshreveable/log15.v2"
)
const providerType = "http-header"
// Middleware is the same for both the app and API because the HTTP proxy is assumed to wrap
// requests to both the app and API and add headers.
//
// See the "func middleware" docs for more information.
var Middleware = &auth.Middleware{
API: middleware,
App: middleware,
}
// middleware is middleware that checks for an HTTP header from an auth proxy that specifies the
// client's authenticated username. It's for use with auth proxies like
// https://github.com/bitly/oauth2_proxy and is configured with the http-header auth provider in
// site config.
//
// TESTING: Use the testproxy test program to test HTTP auth proxy behavior. For example, run `go
// run cmd/frontend/external/auth/httpheader/testproxy.go -username=alice` then go to
// http://localhost:4080. See `-h` for flag help.
//
// 🚨 SECURITY
func middleware(next http.Handler) http.Handler {
const misconfiguredMessage = "Misconfigured http-header auth provider."
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authProvider, multiple := getProviderConfig()
if multiple {
log15.Error("At most 1 HTTP header auth provider may be set in site config.")
http.Error(w, misconfiguredMessage, http.StatusInternalServerError)
return
}
if authProvider == nil {
next.ServeHTTP(w, r)
return
}
if authProvider.UsernameHeader == "" {
log15.Error("No HTTP header set for username (set the http-header auth provider's usernameHeader property).")
http.Error(w, "misconfigured http-header auth provider", http.StatusInternalServerError)
return
}
headerValue := r.Header.Get(authProvider.UsernameHeader)
// Continue onto next auth provider if no header is set (in case the auth proxy allows
// unauthenticated users to bypass it, which some do). Also respect already authenticated
// actors (e.g., via access token).
//
// It would NOT add any additional security to return an error here, because a user who can
// access this HTTP endpoint directly can just as easily supply a fake username whose
// identity to assume.
if headerValue == "" || actor.FromContext(r.Context()).IsAuthenticated() {
next.ServeHTTP(w, r)
return
}
// License check.
if !licensing.IsFeatureEnabledLenient(licensing.FeatureExternalAuthProvider) {
licensing.WriteSubscriptionErrorResponseForFeature(w, "http-header user authentication (SSO)")
return
}
// Otherwise, get or create the user and proceed with the authenticated request.
username, err := auth.NormalizeUsername(headerValue)
if err != nil {
log15.Error("Error normalizing username from HTTP auth proxy.", "username", headerValue, "err", err)
http.Error(w, "unable to normalize username", http.StatusInternalServerError)
return
}
userID, safeErrMsg, err := auth.CreateOrUpdateUser(r.Context(), db.NewUser{Username: username}, db.ExternalAccountSpec{
ServiceType: providerType,
// Store headerValue, not normalized username, to prevent two users with distinct
// pre-normalization usernames from being merged into the same normalized username
// (and therefore letting them each impersonate the other).
AccountID: headerValue,
}, db.ExternalAccountData{})
if err != nil {
log15.Error("unable to get/create user from SSO header", "header", authProvider.UsernameHeader, "headerValue", headerValue, "err", err, "userErr", safeErrMsg)
http.Error(w, safeErrMsg, http.StatusInternalServerError)
return
}
r = r.WithContext(actor.WithActor(r.Context(), &actor.Actor{UID: userID}))
next.ServeHTTP(w, r)
})
}

View File

@ -0,0 +1,117 @@
package httpheader
import (
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/enterprise/pkg/license"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/actor"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
)
// SEE ALSO FOR MANUAL TESTING: See the Middleware docstring for information about the testproxy
// helper program, which helps with manual testing of the HTTP auth proxy behavior.
func TestMiddleware(t *testing.T) {
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
return &license.Info{Tags: licensing.EnterpriseTags}, nil
}
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
handler := middleware(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
actor := actor.FromContext(r.Context())
if actor.IsAuthenticated() {
fmt.Fprintf(w, "user %v", actor.UID)
} else {
fmt.Fprint(w, "no user")
}
}))
const headerName = "x-sso-user-header"
conf.Mock(&schema.SiteConfiguration{AuthProviders: []schema.AuthProviders{{HttpHeader: &schema.HTTPHeaderAuthProvider{UsernameHeader: headerName}}}})
defer conf.Mock(nil)
t.Run("not sent", func(t *testing.T) {
rr := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
handler.ServeHTTP(rr, req)
if got, want := rr.Body.String(), "no user"; got != want {
t.Errorf("got %q, want %q", got, want)
}
})
t.Run("not sent, actor present", func(t *testing.T) {
rr := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
req = req.WithContext(actor.WithActor(context.Background(), &actor.Actor{UID: 123}))
handler.ServeHTTP(rr, req)
if got, want := rr.Body.String(), "user 123"; got != want {
t.Errorf("got %q, want %q", got, want)
}
})
t.Run("sent, user", func(t *testing.T) {
rr := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Set(headerName, "alice")
var calledMock bool
auth.SetMockCreateOrUpdateUser(func(u db.NewUser, a db.ExternalAccountSpec) (userID int32, err error) {
calledMock = true
if a.ServiceType == "http-header" && a.ServiceID == "" && a.ClientID == "" && a.AccountID == "alice" {
return 1, nil
}
return 0, fmt.Errorf("account %v not found in mock", a)
})
defer auth.SetMockCreateOrUpdateUser(nil)
handler.ServeHTTP(rr, req)
if got, want := rr.Body.String(), "user 1"; got != want {
t.Errorf("got %q, want %q", got, want)
}
if !calledMock {
t.Error("!calledMock")
}
})
t.Run("sent, actor already set", func(t *testing.T) {
rr := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Set(headerName, "alice")
req = req.WithContext(actor.WithActor(context.Background(), &actor.Actor{UID: 123}))
handler.ServeHTTP(rr, req)
if got, want := rr.Body.String(), "user 123"; got != want {
t.Errorf("got %q, want %q", got, want)
}
})
t.Run("sent, with un-normalized username", func(t *testing.T) {
rr := httptest.NewRecorder()
req, _ := http.NewRequest("GET", "/", nil)
req.Header.Set(headerName, "alice.zhao")
const wantNormalizedUsername = "alice-zhao"
var calledMock bool
auth.SetMockCreateOrUpdateUser(func(u db.NewUser, a db.ExternalAccountSpec) (userID int32, err error) {
calledMock = true
if u.Username != wantNormalizedUsername {
t.Errorf("got %q, want %q", u.Username, wantNormalizedUsername)
}
if a.ServiceType == "http-header" && a.ServiceID == "" && a.ClientID == "" && a.AccountID == "alice.zhao" {
return 1, nil
}
return 0, fmt.Errorf("account %v not found in mock", a)
})
defer auth.SetMockCreateOrUpdateUser(nil)
handler.ServeHTTP(rr, req)
if got, want := rr.Body.String(), "user 1"; got != want {
t.Errorf("got %q, want %q", got, want)
}
if !calledMock {
t.Error("!calledMock")
}
})
}

View File

@ -0,0 +1,30 @@
package httpheader
import (
"context"
"fmt"
"net/textproto"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/schema"
)
type provider struct {
c *schema.HTTPHeaderAuthProvider
}
// ConfigID implements auth.Provider.
func (provider) ConfigID() auth.ProviderConfigID { return auth.ProviderConfigID{Type: providerType} }
// Config implements auth.Provider.
func (p provider) Config() schema.AuthProviders { return schema.AuthProviders{HttpHeader: p.c} }
// Refresh implements auth.Provider.
func (p provider) Refresh(context.Context) error { return nil }
// CachedInfo implements auth.Provider.
func (p provider) CachedInfo() *auth.ProviderInfo {
return &auth.ProviderInfo{
DisplayName: fmt.Sprintf("HTTP authentication proxy (%q header)", textproto.CanonicalMIMEHeaderKey(p.c.UsernameHeader)),
}
}

View File

@ -0,0 +1,47 @@
// The testproxy command runs a simple HTTP proxy that wraps a Sourcegraph server running with the
// http-header auth provider to test the authentication HTTP proxy support.
// +build ignore
package main
import (
"flag"
"log"
"net/http"
"net/http/httputil"
"net/url"
"os"
)
var (
addr = flag.String("addr", ":4080", "HTTP listen address")
urlStr = flag.String("url", "http://localhost:3080", "proxy origin URL (Sourcegraph HTTP/HTTPS URL)") // CI:LOCALHOST_OK
username = flag.String("username", os.Getenv("USER"), "username to report to Sourcegraph")
httpHeader = flag.String("header", "X-Forwarded-User", "name of HTTP header to add to request")
)
func main() {
flag.Parse()
log.SetFlags(0)
url, err := url.Parse(*urlStr)
if err != nil {
log.Fatalf("Error: Invalid -url: %s.", err)
}
if *username == "" {
log.Fatal("Error: No -username specified.")
}
if *httpHeader == "" {
log.Fatal("Error: No -header specified.")
}
log.Printf(`Listening on %s, forwarding requests to %s with added header "%s: %s"`, *addr, url, *httpHeader, *username)
p := httputil.NewSingleHostReverseProxy(url)
log.Fatalf("Server error: %s.", http.ListenAndServe(*addr, &httputil.ReverseProxy{
Director: func(r *http.Request) {
r.Header.Set(*httpHeader, *username)
r.Host = url.Host
p.Director(r)
},
}))
}

View File

@ -0,0 +1,85 @@
// Package auth is imported for side-effects to enable enterprise-only SSO.
package auth
import (
"fmt"
"net/http"
"strings"
"github.com/sourcegraph/enterprise/cmd/frontend/auth/httpheader"
"github.com/sourcegraph/enterprise/cmd/frontend/auth/openidconnect"
"github.com/sourcegraph/enterprise/cmd/frontend/auth/saml"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/app"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/pkg/conf"
log15 "gopkg.in/inconshreveable/log15.v2"
)
func init() {
// Register enterprise auth middleware
auth.RegisterMiddlewares(
openidconnect.Middleware,
saml.Middleware,
httpheader.Middleware,
)
// Register app-level sign-out handler
app.RegisterSSOSignOutHandler(ssoSignOutHandler)
}
func ssoSignOutHandler(w http.ResponseWriter, r *http.Request) (signOutURLs []app.SignOutURL) {
for _, p := range conf.Get().AuthProviders {
var e app.SignOutURL
var err error
switch {
case p.Openidconnect != nil:
e.ProviderDisplayName = p.Openidconnect.DisplayName
e.ProviderServiceType = p.Openidconnect.Type
e.URL, err = openidconnect.SignOut(w, r)
case p.Saml != nil:
e.ProviderDisplayName = p.Saml.DisplayName
e.ProviderServiceType = p.Saml.Type
e.URL, err = saml.SignOut(w, r)
}
if e.URL != "" {
signOutURLs = append(signOutURLs, e)
}
if err != nil {
log15.Error("Error clearing auth provider session data.", "err", err)
}
}
return signOutURLs
}
func init() {
// Warn about usage of auth providers that are not enabled by the license.
graphqlbackend.AlertFuncs = append(graphqlbackend.AlertFuncs, func(args graphqlbackend.AlertFuncArgs) []*graphqlbackend.Alert {
// Only site admins can act on this alert, so only show it to site admins.
if !args.IsSiteAdmin {
return nil
}
if licensing.IsFeatureEnabledLenient(licensing.FeatureExternalAuthProvider) {
return nil
}
var externalAuthProviderTypes []string
for _, p := range conf.Get().AuthProviders {
if p.Builtin == nil {
externalAuthProviderTypes = append(externalAuthProviderTypes, conf.AuthProviderType(p))
}
}
if len(externalAuthProviderTypes) > 0 {
return []*graphqlbackend.Alert{
{
TypeValue: graphqlbackend.AlertTypeError,
MessageValue: fmt.Sprintf("A Sourcegraph license is required for user authentication providers (SSO): %s. [**Get a license.**](/site-admin/license)", strings.Join(externalAuthProviderTypes, ", ")),
},
}
}
return nil
})
}

View File

@ -0,0 +1,94 @@
package openidconnect
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
log15 "gopkg.in/inconshreveable/log15.v2"
)
var mockGetProviderValue *provider
// getProvider looks up the registered openidconnect auth provider with the given ID.
func getProvider(id string) *provider {
if mockGetProviderValue != nil {
return mockGetProviderValue
}
p, _ := auth.GetProviderByConfigID(auth.ProviderConfigID{Type: providerType, ID: id}).(*provider)
return p
}
func handleGetProvider(ctx context.Context, w http.ResponseWriter, id string) (p *provider, handled bool) {
handled = true // safer default
// License check.
if !licensing.IsFeatureEnabledLenient(licensing.FeatureExternalAuthProvider) {
licensing.WriteSubscriptionErrorResponseForFeature(w, "OpenID Connect user authentication (SSO)")
return nil, true
}
p = getProvider(id)
if p == nil {
log15.Error("No OpenID Connect auth provider found with ID.", "id", id)
http.Error(w, "Misconfigured OpenID Connect auth provider.", http.StatusInternalServerError)
return nil, true
}
if p.config.Issuer == "" {
log15.Error("No issuer set for OpenID Connect auth provider (set the openidconnect auth provider's issuer property).", "id", p.ConfigID())
http.Error(w, "Misconfigured OpenID Connect auth provider.", http.StatusInternalServerError)
return nil, true
}
if err := p.Refresh(ctx); err != nil {
log15.Error("Error refreshing OpenID Connect auth provider.", "id", p.ConfigID(), "error", err)
http.Error(w, "Unexpected error refreshing OpenID Connect authentication provider.", http.StatusInternalServerError)
return nil, true
}
return p, false
}
func init() {
conf.ContributeValidator(validateConfig)
}
func validateConfig(c schema.SiteConfiguration) (problems []string) {
var loggedNeedsAppURL bool
for _, p := range c.AuthProviders {
if p.Openidconnect != nil && c.AppURL == "" && !loggedNeedsAppURL {
problems = append(problems, `openidconnect auth provider requires appURL to be set to the external URL of your site (example: https://sourcegraph.example.com)`)
loggedNeedsAppURL = true
}
}
seen := map[schema.OpenIDConnectAuthProvider]int{}
for i, p := range c.AuthProviders {
if p.Openidconnect != nil {
if j, ok := seen[*p.Openidconnect]; ok {
problems = append(problems, fmt.Sprintf("OpenID Connect auth provider at index %d is duplicate of index %d, ignoring", i, j))
} else {
seen[*p.Openidconnect] = i
}
}
}
return problems
}
// providerConfigID produces a semi-stable identifier for an openidconnect auth provider config
// object. It is used to distinguish between multiple auth providers of the same type when in
// multi-step auth flows. Its value is never persisted, and it must be deterministic.
func providerConfigID(pc *schema.OpenIDConnectAuthProvider) string {
data, err := json.Marshal(pc)
if err != nil {
panic(err)
}
b := sha256.Sum256(data)
return base64.RawURLEncoding.EncodeToString(b[:16])
}

View File

@ -0,0 +1,40 @@
package openidconnect
import (
"testing"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
)
func TestValidateCustom(t *testing.T) {
tests := map[string]struct {
input schema.SiteConfiguration
wantProblems []string
}{
"duplicates": {
input: schema.SiteConfiguration{
AppURL: "x",
AuthProviders: []schema.AuthProviders{
{Openidconnect: &schema.OpenIDConnectAuthProvider{Type: "openidconnect", Issuer: "x"}},
{Openidconnect: &schema.OpenIDConnectAuthProvider{Type: "openidconnect", Issuer: "x"}},
},
},
wantProblems: []string{"OpenID Connect auth provider at index 1 is duplicate of index 0"},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
conf.TestValidator(t, test.input, validateConfig, test.wantProblems)
})
}
}
func TestProviderConfigID(t *testing.T) {
p := schema.OpenIDConnectAuthProvider{Issuer: "x"}
id1 := providerConfigID(&p)
id2 := providerConfigID(&p)
if id1 != id2 {
t.Errorf("id1 (%q) != id2 (%q)", id1, id2)
}
}

View File

@ -0,0 +1,84 @@
package openidconnect
import (
"context"
"sync"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
log15 "gopkg.in/inconshreveable/log15.v2"
)
// Start trying to populate the cache of issuer metadata (given the configured OpenID Connect issuer
// URL) immediately upon server startup and site config changes so users don't incur the wait on the
// first auth flow request.
func init() {
providersOfType := func(ps []schema.AuthProviders) []*schema.OpenIDConnectAuthProvider {
var pcs []*schema.OpenIDConnectAuthProvider
for _, p := range ps {
if p.Openidconnect != nil {
pcs = append(pcs, p.Openidconnect)
}
}
return pcs
}
var (
init = true
mu sync.Mutex
cur []*schema.OpenIDConnectAuthProvider
reg = map[schema.OpenIDConnectAuthProvider]auth.Provider{}
)
conf.Watch(func() {
mu.Lock()
defer mu.Unlock()
// Only react when the config changes.
new := providersOfType(conf.Get().AuthProviders)
diff := diffProviderConfig(cur, new)
if len(diff) == 0 {
return
}
if !init {
log15.Info("Reloading changed OpenID Connect authentication provider configuration.")
}
updates := make(map[auth.Provider]bool, len(diff))
for pc, op := range diff {
if old, ok := reg[pc]; ok {
delete(reg, pc)
updates[old] = false
}
if op {
new := &provider{config: pc}
reg[pc] = new
updates[new] = true
go func(p *provider) {
if err := p.Refresh(context.Background()); err != nil {
log15.Error("Error prefetching OpenID Connect service provider metadata.", "error", err)
}
}(new)
}
}
auth.UpdateProviders(updates)
cur = new
})
init = false
}
func diffProviderConfig(old, new []*schema.OpenIDConnectAuthProvider) map[schema.OpenIDConnectAuthProvider]bool {
diff := map[schema.OpenIDConnectAuthProvider]bool{}
for _, oldPC := range old {
diff[*oldPC] = false
}
for _, newPC := range new {
if _, ok := diff[*newPC]; ok {
delete(diff, *newPC)
} else {
diff[*newPC] = true
}
}
return diff
}

View File

@ -0,0 +1,46 @@
package openidconnect
import (
"reflect"
"testing"
"github.com/sourcegraph/sourcegraph/schema"
)
func TestDiffProviderConfig(t *testing.T) {
var (
pc0 = &schema.OpenIDConnectAuthProvider{Issuer: "0"}
pc0c = &schema.OpenIDConnectAuthProvider{Issuer: "0", ClientSecret: "x"}
pc1 = &schema.OpenIDConnectAuthProvider{Issuer: "1"}
)
tests := map[string]struct {
old, new []*schema.OpenIDConnectAuthProvider
want map[schema.OpenIDConnectAuthProvider]bool
}{
"empty": {want: map[schema.OpenIDConnectAuthProvider]bool{}},
"added": {
old: nil,
new: []*schema.OpenIDConnectAuthProvider{pc0, pc1},
want: map[schema.OpenIDConnectAuthProvider]bool{*pc0: true, *pc1: true},
},
"changed": {
old: []*schema.OpenIDConnectAuthProvider{pc0, pc1},
new: []*schema.OpenIDConnectAuthProvider{pc0c, pc1},
want: map[schema.OpenIDConnectAuthProvider]bool{*pc0: false, *pc0c: true},
},
"removed": {
old: []*schema.OpenIDConnectAuthProvider{pc0, pc1},
new: []*schema.OpenIDConnectAuthProvider{pc1},
want: map[schema.OpenIDConnectAuthProvider]bool{*pc0: false},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
diff := diffProviderConfig(test.old, test.new)
if !reflect.DeepEqual(diff, test.want) {
t.Errorf("got != want\n got %+v\nwant %+v", diff, test.want)
}
})
}
}

View File

@ -0,0 +1,317 @@
package openidconnect
import (
"encoding/base64"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/gorilla/csrf"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/session"
"github.com/sourcegraph/sourcegraph/pkg/actor"
"golang.org/x/oauth2"
log15 "gopkg.in/inconshreveable/log15.v2"
oidc "github.com/coreos/go-oidc"
)
const stateCookieName = "sg-oidc-state"
// All OpenID Connect endpoints are under this path prefix.
const authPrefix = auth.AuthURLPrefix + "/openidconnect"
type userClaims struct {
Name string `json:"name"`
GivenName string `json:"given_name"`
FamilyName string `json:"family_name"`
PreferredUsername string `json:"preferred_username"`
Picture string `json:"picture"`
EmailVerified *bool `json:"email_verified"`
}
// Middleware is middleware for OpenID Connect (OIDC) authentication, adding endpoints under the
// auth path prefix ("/.auth") to enable the login flow and requiring login for all other endpoints.
//
// The OIDC spec (http://openid.net/specs/openid-connect-core-1_0.html) describes an authentication protocol
// that involves 3 parties: the Relying Party (e.g., Sourcegraph), the OpenID Provider (e.g., Okta, OneLogin,
// or another SSO provider), and the End User (e.g., a user's web browser).
//
// This middleware implements two things: (1) the OIDC Authorization Code Flow
// (http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth) and (2) Sourcegraph-specific session management
// (outside the scope of the OIDC spec). Upon successful completion of the OIDC login flow, the handler will create
// a new session and session cookie. The expiration of the session is the expiration of the OIDC ID Token.
//
// 🚨 SECURITY
var Middleware = &auth.Middleware{
API: func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleOpenIDConnectAuth(w, r, next, true)
})
},
App: func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
handleOpenIDConnectAuth(w, r, next, false)
})
},
}
// handleOpenIDConnectAuth performs OpenID Connect authentication (if configured) for HTTP requests,
// both API requests and non-API requests.
func handleOpenIDConnectAuth(w http.ResponseWriter, r *http.Request, next http.Handler, isAPIRequest bool) {
// Fixup URL path. We use "/.auth/callback" as the redirect URI for OpenID Connect, but the rest
// of this middleware's handlers expect paths of "/.auth/openidconnect/...", so add the
// "openidconnect" path component. We can't change the redirect URI because it is hardcoded in
// instances' external auth providers.
if r.URL.Path == auth.AuthURLPrefix+"/callback" {
// Rewrite "/.auth/callback" -> "/.auth/openidconnect/callback".
r.URL.Path = authPrefix + "/callback"
}
// Delegate to the OpenID Connect auth handler.
if !isAPIRequest && strings.HasPrefix(r.URL.Path, authPrefix+"/") {
authHandler(w, r)
return
}
// If the actor is authenticated and not performing an OpenID Connect flow, then proceed to
// next.
if actor.FromContext(r.Context()).IsAuthenticated() {
next.ServeHTTP(w, r)
return
}
// If there is only one auth provider configured, the single auth provider is OpenID Connect,
// and it's an app request, redirect to signin immediately. The user wouldn't be able to do
// anything else anyway; there's no point in showing them a signin screen with just a single
// signin option.
if ps := auth.Providers(); len(ps) == 1 && ps[0].Config().Openidconnect != nil && !isAPIRequest {
p, handled := handleGetProvider(r.Context(), w, ps[0].ConfigID().ID)
if handled {
return
}
redirectToAuthRequest(w, r, p, auth.SafeRedirectURL(r.URL.String()))
return
}
next.ServeHTTP(w, r)
}
// mockVerifyIDToken mocks the OIDC ID Token verification step. It should only be set in tests.
var mockVerifyIDToken func(rawIDToken string) *oidc.IDToken
// authHandler handles the OIDC Authentication Code Flow
// (http://openid.net/specs/openid-connect-core-1_0.html#CodeFlowAuth) on the Relying Party's end.
//
// 🚨 SECURITY
func authHandler(w http.ResponseWriter, r *http.Request) {
switch strings.TrimPrefix(r.URL.Path, authPrefix) {
case "/login":
// Endpoint that starts the Authentication Request Code Flow.
p, handled := handleGetProvider(r.Context(), w, r.URL.Query().Get("pc"))
if handled {
return
}
redirectToAuthRequest(w, r, p, r.URL.Query().Get("redirect"))
return
case "/callback":
// Endpoint for the OIDC Authorization Response. See http://openid.net/specs/openid-connect-core-1_0.html#AuthResponse.
ctx := r.Context()
if authError := r.URL.Query().Get("error"); authError != "" {
errorDesc := r.URL.Query().Get("error_description")
log15.Error("OpenID Connect auth provider returned error to callback.", "error", authError, "description", errorDesc)
http.Error(w, fmt.Sprintf("Authentication failed. Try signing in again (and clearing cookies for the current site). The authentication provider reported the following problems.\n\n%s\n\n%s", authError, errorDesc), http.StatusUnauthorized)
return
}
// Validate state parameter to prevent CSRF attacks
stateParam := r.URL.Query().Get("state")
if stateParam == "" {
http.Error(w, "Authentication failed. Try signing in again (and clearing cookies for the current site). No OpenID Connect state query parameter specified.", http.StatusBadRequest)
return
}
stateCookie, err := r.Cookie(stateCookieName)
if err == http.ErrNoCookie {
log15.Error("OpenID Connect auth failed: no state cookie found (possible request forgery).")
http.Error(w, fmt.Sprintf("Authentication failed. Try signing in again (and clearing cookies for the current site). The error was: no OpenID Connect state cookie found (possible request forgery, or more than %s elapsed since you started the authentication process).", stateCookieTimeout), http.StatusBadRequest)
return
} else if err != nil {
log15.Error("OpenID Connect auth failed: could not read state cookie (possible request forgery).", "error", err)
http.Error(w, "Authentication failed. Try signing in again (and clearing cookies for the current site). The error was: invalid OpenID Connect state cookie.", http.StatusInternalServerError)
return
}
if stateCookie.Value != stateParam {
log15.Error("OpenID Connect auth failed: state cookie mismatch (possible request forgery).")
http.Error(w, "Authentication failed. Try signing in again (and clearing cookies for the current site). The error was: OpenID Connect state parameter did not match the expected value (possible request forgery).", http.StatusBadRequest)
return
}
// Decode state param value
var state authnState
if err := state.Decode(stateParam); err != nil {
log15.Error("OpenID Connect auth failed: state parameter was malformed.", "error", err)
http.Error(w, "Authentication failed. OpenID Connect state parameter was malformed.", http.StatusBadRequest)
return
}
// 🚨 SECURITY: TODO(sqs): Do we need to check state.CSRFToken?
p, handled := handleGetProvider(r.Context(), w, state.ProviderID)
if handled {
return
}
verifier := p.oidc.Verifier(&oidc.Config{ClientID: p.config.ClientID})
// Exchange the code for an access token. See http://openid.net/specs/openid-connect-core-1_0.html#TokenRequest.
oauth2Token, err := p.oauth2Config().Exchange(ctx, r.URL.Query().Get("code"))
if err != nil {
log15.Error("OpenID Connect auth failed: failed to obtain access token from OP.", "error", err)
http.Error(w, "Authentication failed. Try signing in again. The error was: unable to obtain access token from issuer.", http.StatusUnauthorized)
return
}
// Extract the ID Token from the Access Token. See http://openid.net/specs/openid-connect-core-1_0.html#TokenResponse.
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
if !ok {
log15.Error("OpenID Connect auth failed: the issuer's authorization response did not contain an ID token.")
http.Error(w, "Authentication failed. Try signing in again. The error was: the issuer's authorization response did not contain an ID token.", http.StatusUnauthorized)
return
}
// Parse and verify ID Token payload. See http://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation.
var idToken *oidc.IDToken
if mockVerifyIDToken != nil {
idToken = mockVerifyIDToken(rawIDToken)
} else {
idToken, err = verifier.Verify(ctx, rawIDToken)
if err != nil {
log15.Error("OpenID Connect auth failed: the ID token verification failed.", "error", err)
http.Error(w, "Authentication failed. Try signing in again. The error was: OpenID Connect ID token could not be verified.", http.StatusUnauthorized)
return
}
}
// Validate the nonce. The Verify method explicitly doesn't handle nonce validation, so we do that here.
// We set the nonce to be the same as the state in the Authentication Request state, so we check for equality
// here.
if idToken.Nonce != stateParam {
log15.Error("OpenID Connect auth failed: nonce is incorrect (possible replay attach).")
http.Error(w, "Authentication failed. Try signing in again (and clearing cookies for the current site). The error was: OpenID Connect nonce is incorrect (possible replay attack).", http.StatusUnauthorized)
return
}
userInfo, err := p.oidc.UserInfo(ctx, oauth2.StaticTokenSource(oauth2Token))
if err != nil {
log15.Error("Failed to get userinfo", "error", err)
http.Error(w, "Failed to get userinfo: "+err.Error(), http.StatusInternalServerError)
return
}
if p.config.RequireEmailDomain != "" && !strings.HasSuffix(userInfo.Email, "@"+p.config.RequireEmailDomain) {
log15.Error("OpenID Connect auth failed: user's email is not from allowed domain.", "userEmail", userInfo.Email, "requireEmailDomain", p.config.RequireEmailDomain)
http.Error(w, fmt.Sprintf("Authentication failed. Only users in %q are allowed.", p.config.RequireEmailDomain), http.StatusUnauthorized)
return
}
var claims userClaims
if err := userInfo.Claims(&claims); err != nil {
log15.Warn("OpenID Connect auth: could not parse userInfo claims.", "error", err)
}
actr, safeErrMsg, err := getOrCreateUser(ctx, p, idToken, userInfo, &claims)
if err != nil {
log15.Error("OpenID Connect auth failed: error looking up OpenID-authenticated user.", "error", err, "userErr", safeErrMsg)
http.Error(w, safeErrMsg, http.StatusInternalServerError)
return
}
var exp time.Duration
// 🚨 SECURITY: TODO(sqs): We *should* uncomment the lines below to make our own sessions
// only last for as long as the OP said the access token is active for. Unfortunately,
// until we support refreshing access tokens in the background
// (https://github.com/sourcegraph/sourcegraph/issues/11340), this provides a bad user
// experience because users need to re-authenticate via OIDC every minute or so
// (assuming their OIDC OP, like many, has a 1-minute access token validity period).
//
// if !idToken.Expiry.IsZero() {
// exp = time.Until(idToken.Expiry)
// }
if err := session.SetActor(w, r, actr, exp); err != nil {
log15.Error("OpenID Connect auth failed: could not initiate session.", "error", err)
http.Error(w, "Authentication failed. Try signing in again (and clearing cookies for the current site). The error was: could not initiate session.", http.StatusInternalServerError)
return
}
data := sessionData{
ID: p.ConfigID(),
AccessToken: oauth2Token.AccessToken,
TokenType: oauth2Token.TokenType,
}
if err := session.SetData(w, r, sessionKey, data); err != nil {
// It's not fatal if this fails. It just means we won't be able to sign the user out of
// the OP.
log15.Warn("Failed to set OpenID Connect session data. The session is still secure, but Sourcegraph will be unable to revoke the user's token or redirect the user to the end-session endpoint after the user signs out of Sourcegraph.", "error", err)
}
// 🚨 SECURITY: Call auth.SafeRedirectURL to avoid an open-redirect vuln.
http.Redirect(w, r, auth.SafeRedirectURL(state.Redirect), http.StatusFound)
default:
http.Error(w, "", http.StatusNotFound)
}
}
// authnState is the state parameter passed to the Authn request and returned in the Authn response callback.
type authnState struct {
CSRFToken string `json:"csrfToken"`
Redirect string `json:"redirect"`
// Allow /.auth/callback to demux callbacks from multiple OpenID Connect OPs.
ProviderID string `json:"p"`
}
// Encode returns the base64-encoded JSON representation of the authn state.
func (s *authnState) Encode() string {
b, _ := json.Marshal(s)
return base64.StdEncoding.EncodeToString(b)
}
// Decode decodes the base64-encoded JSON representation of the authn state into the receiver.
func (s *authnState) Decode(encoded string) error {
b, err := base64.StdEncoding.DecodeString(encoded)
if err != nil {
return err
}
return json.Unmarshal(b, s)
}
const stateCookieTimeout = time.Minute * 15
func redirectToAuthRequest(w http.ResponseWriter, r *http.Request, p *provider, returnToURL string) {
// The state parameter is an opaque value used to maintain state between the original Authentication Request
// and the callback. We do not record any state beyond a CSRF token used to defend against CSRF attacks against the callback.
// We use the CSRF token created by gorilla/csrf that is used for other app endpoints as the OIDC state parameter.
//
// See http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest of the OIDC spec.
state := (&authnState{
CSRFToken: csrf.Token(r),
Redirect: returnToURL,
ProviderID: p.ConfigID().ID,
}).Encode()
http.SetCookie(w, &http.Cookie{
Name: stateCookieName,
Value: state,
Path: auth.AuthURLPrefix + "/", // include the OIDC redirect URI (/.auth/callback not /.auth/openidconnect/callback for BACKCOMPAT)
Expires: time.Now().Add(stateCookieTimeout),
})
// Redirect to the OP's Authorization Endpoint for authentication. The nonce is an optional
// string value used to associate a Client session with an ID Token and to mitigate replay attacks.
// Whereas the state parameter is used in validating the Authentication Request
// callback, the nonce is used in validating the response to the ID Token request.
// We re-use the Authn request state as the nonce.
//
// See http://openid.net/specs/openid-connect-core-1_0.html#AuthRequest of the OIDC spec.
http.Redirect(w, r, p.oauth2Config().AuthCodeURL(state, oidc.Nonce(state)), http.StatusFound)
}

View File

@ -0,0 +1,365 @@
package openidconnect
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"os"
"strings"
"testing"
"time"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/enterprise/pkg/license"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/actor"
"github.com/sourcegraph/sourcegraph/schema"
oidc "github.com/coreos/go-oidc"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/session"
)
// providerJSON is the JSON structure the OIDC provider returns at its discovery endpoing
type providerJSON struct {
Issuer string `json:"issuer"`
AuthURL string `json:"authorization_endpoint"`
TokenURL string `json:"token_endpoint"`
JWKSURL string `json:"jwks_uri"`
UserInfoURL string `json:"userinfo_endpoint"`
}
var (
testOIDCUser = "bob-test-user"
testClientID = "aaaaaaaaaaaaaa"
)
// new OIDCIDServer returns a new running mock OIDC ID Provider service. It is the caller's
// responsibility to call Close().
func newOIDCIDServer(t *testing.T, code string, oidcProvider *schema.OpenIDConnectAuthProvider) (server *httptest.Server, emailPtr *string) {
idBearerToken := "test_id_token_f4bdefbd77f"
s := http.NewServeMux()
s.HandleFunc("/.well-known/openid-configuration", func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
json.NewEncoder(w).Encode(providerJSON{
Issuer: oidcProvider.Issuer,
AuthURL: oidcProvider.Issuer + "/oauth2/v1/authorize",
TokenURL: oidcProvider.Issuer + "/oauth2/v1/token",
UserInfoURL: oidcProvider.Issuer + "/oauth2/v1/userinfo",
})
})
s.HandleFunc("/oauth2/v1/token", func(w http.ResponseWriter, r *http.Request) {
if r.Method != "POST" {
http.Error(w, "unexpected", http.StatusBadRequest)
return
}
b, _ := ioutil.ReadAll(r.Body)
values, _ := url.ParseQuery(string(b))
if values.Get("code") != code {
t.Errorf("got code %q, want %q", values.Get("code"), code)
}
if got, want := values.Get("grant_type"), "authorization_code"; got != want {
t.Errorf("got grant_type %v, want %v", got, want)
}
redirectURI, _ := url.QueryUnescape(values.Get("redirect_uri"))
if want := "http://example.com/.auth/callback"; redirectURI != want {
t.Errorf("got redirect_uri %v, want %v", redirectURI, want)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(fmt.Sprintf(`{
"access_token": "aaaaa",
"token_type": "Bearer",
"expires_in": 3600,
"scope": "openid",
"id_token": %q
}`, idBearerToken)))
})
email := "bob@example.com"
s.HandleFunc("/oauth2/v1/userinfo", func(w http.ResponseWriter, r *http.Request) {
authzHeader := r.Header.Get("Authorization")
authzParts := strings.Split(authzHeader, " ")
if len(authzParts) != 2 {
t.Fatalf("Expected 2 parts to authz header, instead got %d: %q", len(authzParts), authzHeader)
}
if authzParts[0] != "Bearer" {
t.Fatalf("No bearer token found in authz header %q", authzHeader)
}
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(fmt.Sprintf(`{
"sub": %q,
"profile": "This is a profile",
"email": "`+email+`",
"email_verified": true,
"picture": "https://example.com/picture.png"
}`, testOIDCUser)))
})
srv := httptest.NewServer(s)
auth.SetMockCreateOrUpdateUser(func(u db.NewUser, a db.ExternalAccountSpec) (userID int32, err error) {
if a.ServiceType == "openidconnect" && a.ServiceID == oidcProvider.Issuer && a.ClientID == testClientID && a.AccountID == testOIDCUser {
return 123, nil
}
return 0, fmt.Errorf("account %v not found in mock", a)
})
return srv, &email
}
func TestMiddleware(t *testing.T) {
cleanup := session.ResetMockSessionStore(t)
defer cleanup()
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
return &license.Info{Tags: licensing.EnterpriseTags}, nil
}
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
tempdir, err := ioutil.TempDir("", "sourcegraph-oidc-test")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempdir)
mockGetProviderValue = &provider{
config: schema.OpenIDConnectAuthProvider{
ClientID: testClientID,
ClientSecret: "aaaaaaaaaaaaaaaaaaaaaaaaa",
RequireEmailDomain: "example.com",
},
}
defer func() { mockGetProviderValue = nil }()
auth.SetMockProviders([]auth.Provider{mockGetProviderValue})
defer func() { auth.SetMockProviders(nil) }()
oidcIDServer, emailPtr := newOIDCIDServer(t, "THECODE", &mockGetProviderValue.config)
defer oidcIDServer.Close()
defer func() { auth.SetMockCreateOrUpdateUser(nil) }()
mockGetProviderValue.config.Issuer = oidcIDServer.URL
if err := mockGetProviderValue.Refresh(context.Background()); err != nil {
t.Fatal(err)
}
validState := (&authnState{CSRFToken: "THE_CSRF_TOKEN", Redirect: "/redirect", ProviderID: mockGetProviderValue.ConfigID().ID}).Encode()
mockVerifyIDToken = func(rawIDToken string) *oidc.IDToken {
if rawIDToken != "test_id_token_f4bdefbd77f" {
t.Fatalf("unexpected raw ID token: %s", rawIDToken)
}
return &oidc.IDToken{
Issuer: oidcIDServer.URL,
Subject: testOIDCUser,
Expiry: time.Now().Add(time.Hour),
Nonce: validState, // we re-use the state param as the nonce
}
}
const mockUserID = 123
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
authedHandler := http.NewServeMux()
authedHandler.Handle("/.api/", Middleware.API(h))
authedHandler.Handle("/", Middleware.App(h))
doRequest := func(method, urlStr, body string, cookies []*http.Cookie, authed bool) *http.Response {
req := httptest.NewRequest(method, urlStr, bytes.NewBufferString(body))
for _, cookie := range cookies {
req.AddCookie(cookie)
}
if authed {
req = req.WithContext(actor.WithActor(context.Background(), &actor.Actor{UID: mockUserID}))
}
respRecorder := httptest.NewRecorder()
authedHandler.ServeHTTP(respRecorder, req)
return respRecorder.Result()
}
state := func(t *testing.T, urlStr string) (state authnState) {
u, _ := url.Parse(urlStr)
if err := state.Decode(u.Query().Get("nonce")); err != nil {
t.Fatal(err)
}
return state
}
t.Run("unauthenticated homepage visit -> oidc auth flow", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/", "", nil, false)
if want := http.StatusFound; resp.StatusCode != want {
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
}
if got, want := resp.Header.Get("Location"), "/oauth2/v1/authorize?"; !strings.Contains(got, want) {
t.Errorf("got redirect URL %v, want contains %v", got, want)
}
if state, want := state(t, resp.Header.Get("Location")), "/"; state.Redirect != want {
t.Errorf("got redirect destination %q, want %q", state.Redirect, want)
}
})
t.Run("unauthenticated subpage visit -> oidc auth flow", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/page", "", nil, false)
if want := http.StatusFound; resp.StatusCode != want {
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
}
if got, want := resp.Header.Get("Location"), "/oauth2/v1/authorize?"; !strings.Contains(got, want) {
t.Errorf("got redirect URL %v, want contains %v", got, want)
}
if state, want := state(t, resp.Header.Get("Location")), "/page"; state.Redirect != want {
t.Errorf("got redirect destination %q, want %q", state.Redirect, want)
}
})
t.Run("unauthenticated non-existent page visit -> oidc auth flow", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/nonexistent", "", nil, false)
if want := http.StatusFound; resp.StatusCode != want {
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
}
if got, want := resp.Header.Get("Location"), "/oauth2/v1/authorize?"; !strings.Contains(got, want) {
t.Errorf("got redirect URL %v, want contains %v", got, want)
}
if state, want := state(t, resp.Header.Get("Location")), "/nonexistent"; state.Redirect != want {
t.Errorf("got redirect destination %q, want %q", state.Redirect, want)
}
})
t.Run("unauthenticated API request -> pass through", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/.api/foo", "", nil, false)
if want := http.StatusOK; resp.StatusCode != want {
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
}
})
t.Run("login -> oidc auth flow", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/.auth/openidconnect/login?p="+mockGetProviderValue.ConfigID().ID, "", nil, false)
if want := http.StatusFound; resp.StatusCode != want {
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
}
locHeader := resp.Header.Get("Location")
if !strings.HasPrefix(locHeader, mockGetProviderValue.config.Issuer+"/") {
t.Error("did not redirect to OIDC Provider")
}
idpLoginURL, err := url.Parse(locHeader)
if err != nil {
t.Fatal(err)
}
if got, want := idpLoginURL.Query().Get("client_id"), mockGetProviderValue.config.ClientID; got != want {
t.Errorf("got client id %q, want %q", got, want)
}
if got, want := idpLoginURL.Query().Get("redirect_uri"), "http://example.com/.auth/callback"; got != want {
t.Errorf("got redirect_uri %v, want %v", got, want)
}
if got, want := idpLoginURL.Query().Get("response_type"), "code"; got != want {
t.Errorf("got response_type %v, want %v", got, want)
}
if got, want := idpLoginURL.Query().Get("scope"), "openid profile email"; got != want {
t.Errorf("got scope %v, want %v", got, want)
}
})
t.Run("OIDC callback without CSRF token -> error", func(t *testing.T) {
invalidState := (&authnState{CSRFToken: "bad", ProviderID: mockGetProviderValue.ConfigID().ID}).Encode()
resp := doRequest("GET", "http://example.com/.auth/callback?code=THECODE&state="+url.PathEscape(invalidState), "", nil, false)
if want := http.StatusBadRequest; resp.StatusCode != want {
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
}
})
t.Run("OIDC callback with CSRF token -> set auth cookies", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/.auth/callback?code=THECODE&state="+url.PathEscape(validState), "", []*http.Cookie{{Name: stateCookieName, Value: validState}}, false)
if want := http.StatusFound; resp.StatusCode != want {
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
}
if got, want := resp.Header.Get("Location"), "/redirect"; got != want {
t.Errorf("got redirect URL %v, want %v", got, want)
}
})
*emailPtr = "bob@invalid.com" // doesn't match requiredEmailDomain
t.Run("OIDC callback with bad email domain -> error", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/.auth/callback?code=THECODE&state="+url.PathEscape(validState), "", []*http.Cookie{{Name: stateCookieName, Value: validState}}, false)
if want := http.StatusUnauthorized; resp.StatusCode != want {
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
}
})
t.Run("authenticated app request", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/", "", nil, true)
if want := http.StatusOK; resp.StatusCode != want {
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
}
})
t.Run("authenticated API request", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/.api/foo", "", nil, true)
if want := http.StatusOK; resp.StatusCode != want {
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
}
})
}
func TestMiddleware_NoOpenRedirect(t *testing.T) {
cleanup := session.ResetMockSessionStore(t)
defer cleanup()
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
return &license.Info{Tags: licensing.EnterpriseTags}, nil
}
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
tempdir, err := ioutil.TempDir("", "sourcegraph-oidc-test-no-open-redirect")
if err != nil {
t.Fatal(err)
}
defer os.RemoveAll(tempdir)
mockGetProviderValue = &provider{
config: schema.OpenIDConnectAuthProvider{
ClientID: testClientID,
ClientSecret: "aaaaaaaaaaaaaaaaaaaaaaaaa",
},
}
defer func() { mockGetProviderValue = nil }()
auth.SetMockProviders([]auth.Provider{mockGetProviderValue})
defer func() { auth.SetMockProviders(nil) }()
oidcIDServer, _ := newOIDCIDServer(t, "THECODE", &mockGetProviderValue.config)
defer oidcIDServer.Close()
defer func() { auth.SetMockCreateOrUpdateUser(nil) }()
mockGetProviderValue.config.Issuer = oidcIDServer.URL
if err := mockGetProviderValue.Refresh(context.Background()); err != nil {
t.Fatal(err)
}
state := (&authnState{CSRFToken: "THE_CSRF_TOKEN", Redirect: "http://evil.com", ProviderID: mockGetProviderValue.ConfigID().ID}).Encode()
mockVerifyIDToken = func(rawIDToken string) *oidc.IDToken {
if rawIDToken != "test_id_token_f4bdefbd77f" {
t.Fatalf("unexpected raw ID token: %s", rawIDToken)
}
return &oidc.IDToken{
Issuer: oidcIDServer.URL,
Subject: testOIDCUser,
Expiry: time.Now().Add(time.Hour),
Nonce: state, // we re-use the state param as the nonce
}
}
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {})
authedHandler := Middleware.App(h)
doRequest := func(method, urlStr, body string, cookies []*http.Cookie) *http.Response {
req := httptest.NewRequest(method, urlStr, bytes.NewBufferString(body))
for _, cookie := range cookies {
req.AddCookie(cookie)
}
respRecorder := httptest.NewRecorder()
authedHandler.ServeHTTP(respRecorder, req)
return respRecorder.Result()
}
t.Run("OIDC callback with CSRF token -> set auth cookies", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/.auth/callback?code=THECODE&state="+url.PathEscape(state), "", []*http.Cookie{{Name: stateCookieName, Value: state}})
if want := http.StatusFound; resp.StatusCode != want {
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
}
if got, want := resp.Header.Get("Location"), "/"; got != want {
t.Errorf("got redirect URL %v, want %v", got, want)
} // Redirect to "/", NOT "http://evil.com"
})
}

View File

@ -0,0 +1,158 @@
package openidconnect
import (
"context"
"fmt"
"net/http"
"net/url"
"path"
"strings"
"sync"
oidc "github.com/coreos/go-oidc"
"github.com/pkg/errors"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/globals"
"github.com/sourcegraph/sourcegraph/schema"
"golang.org/x/net/context/ctxhttp"
"golang.org/x/oauth2"
)
const providerType = "openidconnect"
type provider struct {
config schema.OpenIDConnectAuthProvider
mu sync.Mutex
oidc *oidcProvider
refreshErr error
}
// ConfigID implements auth.Provider.
func (p *provider) ConfigID() auth.ProviderConfigID {
return auth.ProviderConfigID{
Type: providerType,
ID: providerConfigID(&p.config),
}
}
// Config implements auth.Provider.
func (p *provider) Config() schema.AuthProviders {
return schema.AuthProviders{Openidconnect: &p.config}
}
// Refresh implements auth.Provider.
func (p *provider) Refresh(ctx context.Context) error {
p.mu.Lock()
defer p.mu.Unlock()
p.oidc, p.refreshErr = newProvider(ctx, p.config.Issuer)
return p.refreshErr
}
func (p *provider) getCachedInfoAndError() (*auth.ProviderInfo, error) {
info := auth.ProviderInfo{
ServiceID: p.config.Issuer,
ClientID: p.config.ClientID,
DisplayName: p.config.DisplayName,
AuthenticationURL: (&url.URL{
Path: path.Join(authPrefix, "login"),
RawQuery: (url.Values{"pc": []string{providerConfigID(&p.config)}}).Encode(),
}).String(),
}
if info.DisplayName == "" {
info.DisplayName = "OpenID Connect"
}
p.mu.Lock()
defer p.mu.Unlock()
err := p.refreshErr
if err != nil {
err = errors.WithMessage(err, "failed to initialize OpenID Connect auth provider")
} else if p.oidc == nil {
err = errors.New("OpenID Connect auth provider is not yet initialized")
}
return &info, err
}
// CachedInfo implements auth.Provider.
func (p *provider) CachedInfo() *auth.ProviderInfo {
info, _ := p.getCachedInfoAndError()
return info
}
func (p *provider) oauth2Config() *oauth2.Config {
return &oauth2.Config{
ClientID: p.config.ClientID,
ClientSecret: p.config.ClientSecret,
// It would be nice if this was "/.auth/openidconnect/callback" not "/.auth/callback", but
// many instances have the "/.auth/callback" value hardcoded in their external auth
// provider, so we can't change it easily
RedirectURL: globals.AppURL().ResolveReference(&url.URL{Path: path.Join(auth.AuthURLPrefix, "callback")}).String(),
Endpoint: p.oidc.Endpoint(),
Scopes: []string{oidc.ScopeOpenID, "profile", "email"},
}
}
// oidcProvider is an OpenID Connect oidcProvider with additional claims parsed from the service oidcProvider
// discovery response (beyond what github.com/coreos/go-oidc parses by default).
type oidcProvider struct {
oidc.Provider
providerExtraClaims
}
type providerExtraClaims struct {
// EndSessionEndpoint is the URL of the OP's endpoint that logs the user out of the OP (provided
// in the "end_session_endpoint" field of the OP's service discovery response). See
// https://openid.net/specs/openid-connect-session-1_0.html#OPMetadata.
EndSessionEndpoint string `json:"end_session_endpoint,omitempty"`
// RevocationEndpoint is the URL of the OP's revocation endpoint (provided in the
// "revocation_endpoint" field of the OP's service discovery response). See
// https://openid.net/specs/openid-heart-openid-connect-1_0.html#rfc.section.3.5 and
// https://tools.ietf.org/html/rfc7009.
RevocationEndpoint string `json:"revocation_endpoint,omitempty"`
}
var mockNewProvider func(issuerURL string) (*oidcProvider, error)
func newProvider(ctx context.Context, issuerURL string) (*oidcProvider, error) {
if mockNewProvider != nil {
return mockNewProvider(issuerURL)
}
bp, err := oidc.NewProvider(context.Background(), issuerURL)
if err != nil {
return nil, err
}
p := &oidcProvider{Provider: *bp}
if err := bp.Claims(&p.providerExtraClaims); err != nil {
return nil, err
}
return p, nil
}
// revokeToken implements Token Revocation. See https://tools.ietf.org/html/rfc7009.
func revokeToken(ctx context.Context, p *provider, accessToken, tokenType string) error {
postData := url.Values{}
postData.Set("token", accessToken)
if tokenType != "" {
postData.Set("token_type_hint", tokenType)
}
req, err := http.NewRequest(p.oidc.RevocationEndpoint, "application/x-www-form-urlencoded", strings.NewReader(postData.Encode()))
if err != nil {
return err
}
req.SetBasicAuth(p.config.ClientID, p.config.ClientSecret)
resp, err := ctxhttp.Do(ctx, nil, req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("non-200 HTTP response from token revocation endpoint %s: HTTP %d", p.oidc.RevocationEndpoint, resp.StatusCode)
}
return nil
}

View File

@ -0,0 +1,56 @@
package openidconnect
import (
"fmt"
"net/http"
"github.com/pkg/errors"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/session"
)
const sessionKey = "oidc@0"
type sessionData struct {
ID auth.ProviderConfigID
// Store only the oauth2.Token fields we need, to avoid hitting the ~4096-byte session data
// limit.
AccessToken string
TokenType string
}
// SignOut clears OpenID Connect-related data from the session. If possible, it revokes the token
// from the OP. If there is an end-session endpoint, it returns that for the caller to present to
// the user.
func SignOut(w http.ResponseWriter, r *http.Request) (endSessionEndpoint string, err error) {
defer func() {
if err != nil {
_ = session.SetData(w, r, sessionKey, nil) // clear the bad data
}
}()
var data *sessionData
if err := session.GetData(r, sessionKey, &data); err != nil {
return "", errors.WithMessage(err, "reading OpenID Connect session data")
}
if err := session.SetData(w, r, sessionKey, nil); err != nil {
return "", errors.WithMessage(err, "clearing OpenID Connect session data")
}
if data != nil {
p := getProvider(data.ID.ID)
if p == nil {
return "", fmt.Errorf("unable to revoke token or end session for OpenID Connect because no provider %q exists", data.ID)
}
endSessionEndpoint = p.oidc.EndSessionEndpoint
if p.oidc.RevocationEndpoint != "" {
if err := revokeToken(r.Context(), p, data.AccessToken, data.TokenType); err != nil {
return endSessionEndpoint, errors.WithMessage(err, "revoking OpenID Connect token")
}
}
}
return endSessionEndpoint, nil
}

View File

@ -0,0 +1,73 @@
package openidconnect
import (
"context"
"fmt"
oidc "github.com/coreos/go-oidc"
"github.com/pkg/errors"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/actor"
)
// getOrCreateUser gets or creates a user account based on the OpenID Connect token. It returns the
// authenticated actor if successful; otherwise it returns an friendly error message (safeErrMsg)
// that is safe to display to users, and a non-nil err with lower-level error details.
func getOrCreateUser(ctx context.Context, p *provider, idToken *oidc.IDToken, userInfo *oidc.UserInfo, claims *userClaims) (_ *actor.Actor, safeErrMsg string, err error) {
if userInfo.Email == "" {
return nil, "Only users with an email address may authenticate to Sourcegraph.", errors.New("no email address in claims")
}
if unverifiedEmail := claims.EmailVerified != nil && !*claims.EmailVerified; unverifiedEmail {
// If the OP explicitly reports `"email_verified": false`, then reject the authentication
// attempt. If undefined or true, then it will be allowed.
return nil, fmt.Sprintf("Only users with verified email addresses may authenticate to Sourcegraph. The email address %q is not verified on the external authentication provider.", userInfo.Email), fmt.Errorf("refusing unverified user email address %q", userInfo.Email)
}
pi, err := p.getCachedInfoAndError()
if err != nil {
return nil, "", err
}
login := claims.PreferredUsername
if login == "" {
login = userInfo.Email
}
email := userInfo.Email
var displayName = claims.GivenName
if displayName == "" {
if claims.Name == "" {
displayName = claims.Name
} else {
displayName = login
}
}
login, err = auth.NormalizeUsername(login)
if err != nil {
return nil, fmt.Sprintf("Error normalizing the username %q. See https://about.sourcegraph.com/docs/config/authentication#username-normalization.", login), err
}
var data db.ExternalAccountData
auth.SetExternalAccountData(&data.AccountData, struct {
IDToken *oidc.IDToken `json:"idToken"`
UserInfo *oidc.UserInfo `json:"userInfo"`
UserClaims *userClaims `json:"userClaims"`
}{IDToken: idToken, UserInfo: userInfo, UserClaims: claims})
userID, safeErrMsg, err := auth.CreateOrUpdateUser(ctx, db.NewUser{
Username: login,
Email: email,
EmailIsVerified: email != "", // TODO(sqs): https://github.com/sourcegraph/sourcegraph/issues/10118
DisplayName: displayName,
AvatarURL: claims.Picture,
}, db.ExternalAccountSpec{
ServiceType: providerType,
ServiceID: pi.ServiceID,
ClientID: pi.ClientID,
AccountID: idToken.Subject,
}, data)
if err != nil {
return nil, safeErrMsg, err
}
return actor.FromUser(userID), "", nil
}

View File

@ -0,0 +1,150 @@
package saml
import (
"context"
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"log"
"net/http"
"path"
"strconv"
"strings"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/pkg/env"
"github.com/sourcegraph/sourcegraph/schema"
log15 "gopkg.in/inconshreveable/log15.v2"
)
var mockGetProviderValue *provider
// getProvider looks up the registered saml auth provider with the given ID.
func getProvider(pcID string) *provider {
if mockGetProviderValue != nil {
return mockGetProviderValue
}
p, _ := auth.GetProviderByConfigID(auth.ProviderConfigID{Type: providerType, ID: pcID}).(*provider)
if p != nil {
return p
}
// Special case: if there is only a single SAML auth provider, return it regardless of the pcID.
for _, ap := range auth.Providers() {
if ap.Config().Saml != nil {
if p != nil {
return nil // multiple SAML providers, can't use this special case
}
p = ap.(*provider)
}
}
return p
}
func handleGetProvider(ctx context.Context, w http.ResponseWriter, pcID string) (p *provider, handled bool) {
handled = true // safer default
// License check.
if !licensing.IsFeatureEnabledLenient(licensing.FeatureExternalAuthProvider) {
licensing.WriteSubscriptionErrorResponseForFeature(w, "SAML user authentication (SSO)")
return nil, true
}
p = getProvider(pcID)
if p == nil {
log15.Error("No SAML auth provider found with ID.", "id", pcID)
http.Error(w, "Misconfigured SAML auth provider.", http.StatusInternalServerError)
return nil, true
}
if err := p.Refresh(ctx); err != nil {
log15.Error("Error refreshing SAML auth provider.", "id", p.ConfigID(), "error", err)
http.Error(w, "Unexpected error refreshing SAML authentication provider.", http.StatusInternalServerError)
return nil, true
}
return p, false
}
func init() {
conf.ContributeValidator(validateConfig)
}
func validateConfig(c schema.SiteConfiguration) (problems []string) {
var loggedNeedsAppURL bool
for _, p := range c.AuthProviders {
if p.Saml != nil && c.AppURL == "" && !loggedNeedsAppURL {
problems = append(problems, `saml auth provider requires appURL to be set to the external URL of your site (example: https://sourcegraph.example.com)`)
loggedNeedsAppURL = true
}
}
seen := map[schema.SAMLAuthProvider]int{}
for i, p := range c.AuthProviders {
if p.Saml != nil {
if j, ok := seen[*p.Saml]; ok {
problems = append(problems, fmt.Sprintf("SAML auth provider at index %d is duplicate of index %d, ignoring", i, j))
} else {
seen[*p.Saml] = i
}
}
}
return problems
}
func withConfigDefaults(pc *schema.SAMLAuthProvider) *schema.SAMLAuthProvider {
if pc.ServiceProviderIssuer == "" {
appURL := conf.Get().AppURL
if appURL == "" {
// An empty issuer will be detected as an error later.
return pc
}
// Derive default issuer from appURL.
tmp := *pc
tmp.ServiceProviderIssuer = strings.TrimSuffix(conf.Get().AppURL, "/") + path.Join(authPrefix, "metadata")
return &tmp
}
return pc
}
func getNameIDFormat(pc *schema.SAMLAuthProvider) string {
// Persistent is best because users will reuse their user_external_accounts row instead of (as
// with transient) creating a new one each time they authenticate.
const defaultNameIDFormat = "urn:oasis:names:tc:SAML:2.0:nameid-format:persistent"
if pc.NameIDFormat != "" {
return pc.NameIDFormat
}
return defaultNameIDFormat
}
// providerConfigID produces a semi-stable identifier for a saml auth provider config object. It is
// used to distinguish between multiple auth providers of the same type when in multi-step auth
// flows. Its value is never persisted, and it must be deterministic.
//
// If there is only a single saml auth provider, it returns the empty string because that satisfies
// the requirements above.
func providerConfigID(pc *schema.SAMLAuthProvider, multiple bool) string {
if !multiple {
return ""
}
data, err := json.Marshal(pc)
if err != nil {
panic(err)
}
b := sha256.Sum256(data)
return base64.RawURLEncoding.EncodeToString(b[:16])
}
var traceLogEnabled, _ = strconv.ParseBool(env.Get("INSECURE_SAML_LOG_TRACES", "false", "Log all SAML requests and responses. Only use during testing because the log messages will contain sensitive data."))
func traceLog(description, body string) {
if traceLogEnabled {
const n = 40
log.Printf("%s SAML trace: %s\n%s\n%s", strings.Repeat("=", n), description, body, strings.Repeat("=", n+len(description)+1))
}
}

View File

@ -0,0 +1,40 @@
package saml
import (
"testing"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
)
func TestValidateCustom(t *testing.T) {
tests := map[string]struct {
input schema.SiteConfiguration
wantProblems []string
}{
"duplicates": {
input: schema.SiteConfiguration{
AppURL: "x",
AuthProviders: []schema.AuthProviders{
{Saml: &schema.SAMLAuthProvider{Type: "saml", IdentityProviderMetadataURL: "x"}},
{Saml: &schema.SAMLAuthProvider{Type: "saml", IdentityProviderMetadataURL: "x"}},
},
},
wantProblems: []string{"SAML auth provider at index 1 is duplicate of index 0"},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
conf.TestValidator(t, test.input, validateConfig, test.wantProblems)
})
}
}
func TestProviderConfigID(t *testing.T) {
p := schema.SAMLAuthProvider{ServiceProviderIssuer: "x"}
id1 := providerConfigID(&p, true)
id2 := providerConfigID(&p, true)
if id1 != id2 {
t.Errorf("id1 (%q) != id2 (%q)", id1, id2)
}
}

View File

@ -0,0 +1,84 @@
package saml
import (
"context"
"sync"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
log15 "gopkg.in/inconshreveable/log15.v2"
)
// Start trying to populate the cache of SAML IdP metadata immediately upon server startup and site
// config changes so users don't incur the wait on the first auth flow request.
func init() {
providersOfType := func(ps []schema.AuthProviders) []*schema.SAMLAuthProvider {
var pcs []*schema.SAMLAuthProvider
for _, p := range ps {
if p.Saml != nil {
pcs = append(pcs, withConfigDefaults(p.Saml))
}
}
return pcs
}
var (
init = true
mu sync.Mutex
cur []*schema.SAMLAuthProvider
reg = map[schema.SAMLAuthProvider]auth.Provider{}
)
conf.Watch(func() {
mu.Lock()
defer mu.Unlock()
// Only react when the config changes.
new := providersOfType(conf.Get().AuthProviders)
diff := diffProviderConfig(cur, new)
if len(diff) == 0 {
return
}
if !init {
log15.Info("Reloading changed SAML authentication provider configuration.")
}
multiple := len(new) >= 2
updates := make(map[auth.Provider]bool, len(diff))
for pc, op := range diff {
if old, ok := reg[pc]; ok {
delete(reg, pc)
updates[old] = false
}
if op {
new := &provider{config: pc, multiple: multiple}
reg[pc] = new
updates[new] = true
go func(p *provider) {
if err := p.Refresh(context.Background()); err != nil {
log15.Error("Error prefetching SAML service provider metadata.", "error", err)
}
}(new)
}
}
auth.UpdateProviders(updates)
cur = new
})
init = false
}
func diffProviderConfig(old, new []*schema.SAMLAuthProvider) map[schema.SAMLAuthProvider]bool {
diff := map[schema.SAMLAuthProvider]bool{}
for _, oldPC := range old {
diff[*oldPC] = false
}
for _, newPC := range new {
if _, ok := diff[*newPC]; ok {
delete(diff, *newPC)
} else {
diff[*newPC] = true
}
}
return diff
}

View File

@ -0,0 +1,46 @@
package saml
import (
"reflect"
"testing"
"github.com/sourcegraph/sourcegraph/schema"
)
func TestDiffProviderConfig(t *testing.T) {
var (
pc0 = &schema.SAMLAuthProvider{ServiceProviderIssuer: "0"}
pc0c = &schema.SAMLAuthProvider{ServiceProviderIssuer: "0", ServiceProviderPrivateKey: "x"}
pc1 = &schema.SAMLAuthProvider{ServiceProviderIssuer: "1"}
)
tests := map[string]struct {
old, new []*schema.SAMLAuthProvider
want map[schema.SAMLAuthProvider]bool
}{
"empty": {want: map[schema.SAMLAuthProvider]bool{}},
"added": {
old: nil,
new: []*schema.SAMLAuthProvider{pc0, pc1},
want: map[schema.SAMLAuthProvider]bool{*pc0: true, *pc1: true},
},
"changed": {
old: []*schema.SAMLAuthProvider{pc0, pc1},
new: []*schema.SAMLAuthProvider{pc0c, pc1},
want: map[schema.SAMLAuthProvider]bool{*pc0: false, *pc0c: true},
},
"removed": {
old: []*schema.SAMLAuthProvider{pc0, pc1},
new: []*schema.SAMLAuthProvider{pc1},
want: map[schema.SAMLAuthProvider]bool{*pc0: false},
},
}
for name, test := range tests {
t.Run(name, func(t *testing.T) {
diff := diffProviderConfig(test.old, test.new)
if !reflect.DeepEqual(diff, test.want) {
t.Errorf("got != want\n got %+v\nwant %+v", diff, test.want)
}
})
}
}

View File

@ -0,0 +1,3 @@
// Package SAML provides HTTP middleware that provides the necessary endpoints for a SAML Service
// Provider (SP) to complete the SAML authentication flow to authenticate to the frontend.
package saml

View File

@ -0,0 +1,262 @@
package saml
import (
"encoding/base64"
"encoding/json"
"encoding/xml"
"fmt"
"net/http"
"strings"
"time"
log15 "gopkg.in/inconshreveable/log15.v2"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/session"
"github.com/sourcegraph/sourcegraph/pkg/actor"
)
// All SAML endpoints are under this path prefix.
const authPrefix = auth.AuthURLPrefix + "/saml"
// Middleware is middleware for SAML authentication, adding endpoints under the auth path prefix to
// enable the login flow an requiring login for all other endpoints.
//
// 🚨 SECURITY
var Middleware = &auth.Middleware{
API: func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHandler(w, r, next, true)
})
},
App: func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
authHandler(w, r, next, false)
})
},
}
// authHandler is the new SAML HTTP auth handler.
//
// It uses github.com/russelhaering/gosaml2 and (unlike authHandler1) makes it possible to support
// multiple auth providers with SAML and expose more SAML functionality.
func authHandler(w http.ResponseWriter, r *http.Request, next http.Handler, isAPIRequest bool) {
// Delegate to SAML ACS and metadata endpoint handlers.
if !isAPIRequest && strings.HasPrefix(r.URL.Path, auth.AuthURLPrefix+"/saml/") {
samlSPHandler(w, r)
return
}
// If the actor is authenticated and not performing a SAML operation, then proceed to next.
if actor.FromContext(r.Context()).IsAuthenticated() {
next.ServeHTTP(w, r)
return
}
// If there is only one auth provider configured, the single auth provider is SAML, and it's an
// app request, redirect to signin immediately. The user wouldn't be able to do anything else
// anyway; there's no point in showing them a signin screen with just a single signin option.
if ps := auth.Providers(); len(ps) == 1 && ps[0].Config().Saml != nil && !isAPIRequest {
p, handled := handleGetProvider(r.Context(), w, ps[0].ConfigID().ID)
if handled {
return
}
redirectToAuthURL(w, r, p, auth.SafeRedirectURL(r.URL.String()))
return
}
next.ServeHTTP(w, r)
}
func samlSPHandler(w http.ResponseWriter, r *http.Request) {
requestPath := strings.TrimPrefix(r.URL.Path, authPrefix)
// Handle GET endpoints.
if r.Method == "GET" {
// All of these endpoints expect the provider ID in the URL query.
p, handled := handleGetProvider(r.Context(), w, r.URL.Query().Get("pc"))
if handled {
return
}
switch requestPath {
case "/metadata":
metadata, err := p.samlSP.Metadata()
if err != nil {
log15.Error("Error generating SAML service provider metadata.", "err", err)
http.Error(w, "", http.StatusInternalServerError)
return
}
buf, err := xml.MarshalIndent(metadata, "", " ")
if err != nil {
log15.Error("Error encoding SAML service provider metadata.", "err", err)
http.Error(w, "", http.StatusInternalServerError)
return
}
traceLog(fmt.Sprintf("Service Provider metadata: %s", p.ConfigID().ID), string(buf))
w.Header().Set("Content-Type", "application/samlmetadata+xml; charset=utf-8")
w.Write(buf)
return
case "/login":
// It is safe to use r.Referer() because the redirect-to URL will be checked later,
// before the client is actually instructed to navigate there.
redirectToAuthURL(w, r, p, r.Referer())
return
}
}
if r.Method != "POST" {
http.Error(w, "", http.StatusMethodNotAllowed)
return
}
if err := r.ParseForm(); err != nil {
http.Error(w, "", http.StatusBadRequest)
return
}
// The remaining endpoints all expect the provider ID in the POST data's RelayState.
traceLog("SAML RelayState", r.FormValue("RelayState"))
var relayState relayState
relayState.decode(r.FormValue("RelayState"))
p, handled := handleGetProvider(r.Context(), w, relayState.ProviderID)
if handled {
return
}
// Handle POST endpoints.
switch requestPath {
case "/acs":
info, err := readAuthnResponse(p, r.FormValue("SAMLResponse"))
if err != nil {
log15.Error("Error validating SAML assertions. Set the env var INSECURE_SAML_LOG_TRACES=1 to log all SAML requests and responses.", "err", err)
http.Error(w, "Error validating SAML assertions. Try signing in again. If the problem persists, a site admin must check the configuration.", http.StatusForbidden)
return
}
actor, safeErrMsg, err := getOrCreateUser(r.Context(), info)
if err != nil {
log15.Error("Error looking up SAML-authenticated user.", "err", err, "userErr", safeErrMsg)
http.Error(w, safeErrMsg, http.StatusInternalServerError)
return
}
var exp time.Duration
// 🚨 SECURITY: TODO(sqs): We *should* uncomment the line below to make our own sessions
// only last for as long as the IdP said the authn grant is active for. Unfortunately,
// until we support refreshing SAML authn in the background
// (https://github.com/sourcegraph/sourcegraph/issues/11340), this provides a bad user
// experience because users need to re-authenticate via SAML every minute or so
// (assuming their SAML IdP, like many, has a 1-minute access token validity period).
//
// if info.SessionNotOnOrAfter != nil {
// exp = time.Until(*info.SessionNotOnOrAfter)
// }
if err := session.SetActor(w, r, actor, exp); err != nil {
log15.Error("Error setting SAML-authenticated actor in session.", "err", err)
http.Error(w, "Error starting SAML-authenticated session. Try signing in again.", http.StatusInternalServerError)
return
}
// 🚨 SECURITY: Call auth.SafeRedirectURL to avoid an open-redirect vuln.
http.Redirect(w, r, auth.SafeRedirectURL(relayState.ReturnToURL), http.StatusFound)
case "/logout":
encodedResp := r.FormValue("SAMLResponse")
{
if raw, err := base64.StdEncoding.DecodeString(encodedResp); err == nil {
traceLog(fmt.Sprintf("LogoutResponse: %s", p.ConfigID().ID), string(raw))
}
}
// TODO(sqs): Fully validate the LogoutResponse here (i.e., also validate that the document
// is a valid LogoutResponse). It is possible that this request is being spoofed, but it
// doesn't let an attacker do very much (just log a user out and redirect).
//
// 🚨 SECURITY: If this logout handler starts to do anything more advanced, it probably must
// validate the LogoutResponse to avoid being vulnerable to spoofing.
_, err := p.samlSP.ValidateEncodedResponse(encodedResp)
if err != nil && !strings.HasPrefix(err.Error(), "unable to unmarshal response:") {
log15.Error("Error validating SAML logout response.", "err", err)
http.Error(w, "Error validating SAML logout response.", http.StatusForbidden)
return
}
// If this is an SP-initiated logout, then the actor has already been cleared from the
// session (but there's no harm in clearing it again). If it's an IdP-initiated logout,
// then it hasn't, and we must clear it here.
if err := session.SetActor(w, r, nil, 0); err != nil {
log15.Error("Error clearing actor from session in SAML logout handler.", "err", err)
http.Error(w, "Error signing out of SAML-authenticated session.", http.StatusInternalServerError)
return
}
http.Redirect(w, r, "/", http.StatusFound)
default:
http.Error(w, "", http.StatusNotFound)
}
}
func redirectToAuthURL(w http.ResponseWriter, r *http.Request, p *provider, returnToURL string) {
authURL, err := buildAuthURLRedirect(p, relayState{
ProviderID: p.ConfigID().ID,
ReturnToURL: auth.SafeRedirectURL(returnToURL),
})
if err != nil {
log15.Error("Failed to build SAML auth URL.", "err", err)
http.Error(w, "Unexpected error in SAML authentication provider.", http.StatusInternalServerError)
return
}
http.Redirect(w, r, authURL, http.StatusFound)
}
func buildAuthURLRedirect(p *provider, relayState relayState) (string, error) {
doc, err := p.samlSP.BuildAuthRequestDocument()
if err != nil {
return "", err
}
{
if data, err := doc.WriteToString(); err == nil {
traceLog(fmt.Sprintf("AuthnRequest: %s", p.ConfigID().ID), data)
}
}
return p.samlSP.BuildAuthURLRedirect(relayState.encode(), doc)
}
// relayState represents the decoded RelayState value in both the IdP-initiated and SP-initiated
// login flows.
//
// SAML overloads the term "RelayState".
// * In the SP-initiated login flow, it is an opaque value originated from the SP and reflected
// back in the AuthnResponse. The Sourcegraph SP uses the base64-encoded JSON of this struct as
// the RelayState.
// * In the IdP-initiated login flow, the RelayState can be any arbitrary hint, but in practice
// is the desired post-login redirect URL in plain text.
type relayState struct {
ProviderID string `json:"k"`
ReturnToURL string `json:"r"`
}
// encode returns the base64-encoded JSON representation of the relay state.
func (s *relayState) encode() string {
b, _ := json.Marshal(s)
return base64.StdEncoding.EncodeToString(b)
}
// Decode decodes the base64-encoded JSON representation of the relay state into the receiver.
func (s *relayState) decode(encoded string) {
if strings.HasPrefix(encoded, "http://") || strings.HasPrefix(encoded, "https://") || encoded == "" {
s.ProviderID, s.ReturnToURL = "", encoded
return
}
if b, err := base64.StdEncoding.DecodeString(encoded); err == nil {
if err := json.Unmarshal(b, s); err == nil {
return
}
}
s.ProviderID, s.ReturnToURL = "", ""
}

View File

@ -0,0 +1,409 @@
package saml
import (
"bytes"
"compress/flate"
"context"
"crypto"
"crypto/x509"
"encoding/base64"
"encoding/pem"
"encoding/xml"
"fmt"
"io/ioutil"
"net/http"
"net/http/httptest"
"net/url"
"strings"
"testing"
"time"
"github.com/beevik/etree"
"github.com/crewjam/saml"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/enterprise/pkg/license"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/session"
"github.com/sourcegraph/sourcegraph/pkg/actor"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
"github.com/crewjam/saml/samlidp"
)
const (
testSAMLSPCert = `-----BEGIN CERTIFICATE-----
MIIC+zCCAeOgAwIBAgIQFkK4RCQNFkAFzj8dJHnXJjANBgkqhkiG9w0BAQsFADAS
MRAwDgYDVQQKEwdBY21lIENvMB4XDTE4MTAyMDA4NTQ1NVoXDTI4MTAxNzA4NTQ1
NVowEjEQMA4GA1UEChMHQWNtZSBDbzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCC
AQoCggEBAKMD2RmTnAI+1+s+hiakkdbOXHoEwRoG45yeCV5z8A7TnZtF238kReBN
JSOUTvgrvg5WbfG8ULSariepAI45BH3yYoNOBXe0biVsCB+0h6szeV1+N6y9wj0j
ns/AOOV6ec/GbUZufF+XeJmVX/kRoOthUCEWhCGn/ZCa9VNcr2u/EhCZhvk6JcY9
p/gu2YYJepihYpkrzzHwlC+ye+AfPX0/LiZQLGM8ciiziXden8DqEhskkg5HqnPl
hwscqI6qlYIcUFw5QB3xA738N4/92Uj7Jstf05ESFDf6zbUTn/hSsLXivNHI0G4P
4gsVy5Y5pygrw3b3FuodJbuVtLU9cwMCAwEAAaNNMEswDgYDVR0PAQH/BAQDAgWg
MBMGA1UdJQQMMAoGCCsGAQUFBwMBMAwGA1UdEwEB/wQCMAAwFgYDVR0RBA8wDYIL
ZXhhbXBsZS5jb20wDQYJKoZIhvcNAQELBQADggEBADQ/UgbXlW7zPwWswJSlbgph
yjepD5dJ/My1ByIM2GSSYlvnLGq9tSOwUWZ0fZY/G8WOowNSBQlUVPT7tS11j7Ce
BdrImHWpDleZYyagk08vaU059LJFI4EM8Vzn570h3dxdXkSoIGjqNAfywat591k3
K8llPk2wrQ8fv3KA7tNNmJW+Ee1DHIAui53aFe3sHmp7JN1tE6HlqrSLIDymSd28
tOfJ1Y9kOvUF7DY8pkSVDukO9wsy0X568hfJOz4PQe/1LHJ1YxlomTCkyVV8xtW7
hbnEyvPo2yr/SHbk4Fz1yXP20cBm8vO2DmKlI0kaKGQw1Rybl8NQw+OPdb/V6pM=
-----END CERTIFICATE-----`
testSAMLSPKey = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAowPZGZOcAj7X6z6GJqSR1s5cegTBGgbjnJ4JXnPwDtOdm0Xb
fyRF4E0lI5RO+Cu+DlZt8bxQtJquJ6kAjjkEffJig04Fd7RuJWwIH7SHqzN5XX43
rL3CPSOez8A45Xp5z8ZtRm58X5d4mZVf+RGg62FQIRaEIaf9kJr1U1yva78SEJmG
+Tolxj2n+C7Zhgl6mKFimSvPMfCUL7J74B89fT8uJlAsYzxyKLOJd16fwOoSGySS
Dkeqc+WHCxyojqqVghxQXDlAHfEDvfw3j/3ZSPsmy1/TkRIUN/rNtROf+FKwteK8
0cjQbg/iCxXLljmnKCvDdvcW6h0lu5W0tT1zAwIDAQABAoIBAC/t4LYhbVxHp+p1
zrGr72lN8Wi63x/M6L1SxgRsaCej1pIhvwCp5JWneQT2BSX4jn/er6LEsKH5XL0y
doRahVSWoJpkpTzl4wDDu7u+s6kFkGiJxMrYXDTntTj2FoR6Nzh86gIsWAsvGPln
LvmnUj4CtbGU0jKnFumedgUVmko+QmalghYrkc7dReprGJ6EDWNLGb0ASG9/R7iQ
NOKu17nK9W3yCWJc48SR8y9HWUEUqtKbsshJ6PewNNttsSC3JjeGuiH5fRmXLi6L
wXr2l1AAPGRWbI8djrm7DFLa1s8pfJKkTV0YUDHFNBXny0h+oUGwC+KCQNsfE3t8
GbKqdKkCgYEA0A3dFKuxzm9QbZRcmGwZbHTNfWb7EMlknNq9wqZLgbTK77P1bhXW
l0YP7HuNZnKToMt3UrM5tYFzk28a73p4Ur8+va23lbtwwBeZ5qsZ/vAfI9b2GTj5
AchOI5Xtwf8eWT3OIdHbe4W5hkyb/siPdkRJ1zXfDn8XN0+XIB7NeT0CgYEAyJTn
xjtxMq3Jab5tycZ7ZVC9Y5tpd6phb6KLdLNGNcNKSPvC2EichgLH3Awckpe92HvD
wujlQnlKod42/aPxVcOOnwqN2PMfzPXK91pcOt9QyWFzRLFtvkFB9Dd4vz5uhdWF
CpSBDN87PpFEOuJJApy78e7hfZ5YpxWGK7N24T8CgYEAxtT4+9A6VTc8ffzToTdt
9KCL4dSRDDHr3ZuOzn9umb7WUs6BN3vXYSqr/Sz2rXnCbGEG4Bo4hKX6dmQwMb2x
UCNFKrDiSk6gKnRjuHa8mU+R8wZ0mxY/otxzEL8wQb42msLeRKPyRdI+w4JjctLp
h/UrPGlXitsarNl7bE8Dv2ECgYBXOgog9rCfbVvtlFaCLMJ0qMvziR4wX/PHbFRh
B6U8tBSV8IYnMEyBKqxnUQ0L4tk4T3ouRMGOStjd05jubGEC/uwC1cAh3Hiz1R/S
uYTqRTsImExcTxx+ZDqeTZFA+ZFuuhAFLdeBFYLaDqoxQT6m2CoTZ+K/kiDTaFTU
pFLKWQKBgQCNCeMVMkNwJtMeR/vKUEOPZLKFZihPrORi8F5v0f+qrcg3bKDKd1KI
6kocCulZR3uvEFHUAoNyMNwCZs6YyIK9zEN3/Pb46ThnJNNMXv0CG+W8df/Vbnd0
REijBJT1tjS4dXBULokRuI2640iWll8KX1KQDzqo3l++JRGqoP57Sg==
-----END RSA PRIVATE KEY-----`
)
var (
idpCert = func() *x509.Certificate {
b, _ := pem.Decode([]byte(`-----BEGIN CERTIFICATE-----
MIIC/DCCAeSgAwIBAgIRANzQVAHz24KrifhcM9kaVcYwDQYJKoZIhvcNAQELBQAw
EjEQMA4GA1UEChMHQWNtZSBDbzAeFw0xODEwMjAwODU1NDNaFw0yODEwMTcwODU1
NDNaMBIxEDAOBgNVBAoTB0FjbWUgQ28wggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAw
ggEKAoIBAQCYO2FXgBvsrHyDzFFjUxNuF5OtPYOw+yjGnMR7r1CU93ZaSFN0z9Ux
Qv6rHAmoGwLp+dJjYZ2g/km6/TnONVGjSDh8TxCEH+cP5kIyRN4L3MPW9tsZ36Fd
5yqNbVCrxp3gJKAHmcUHYONzQ6WxxOCEBkvVknysstG8hXhbOcElXrSIyRVPQuEu
TBQSAJahFbQYCKFU93rlO142hgPJDkHibz8PhLgEU7v3Eo23JrOSKNUysXnp5hLT
RhOyQyWNpXA91wwsTwETOD2KlDKDHIcpKEdMSWhRQya6S2z49RVKYpZmATW4Qq2l
hDWWuyG7/IaheuyPum+BF0FFmDKfQY83AgMBAAGjTTBLMA4GA1UdDwEB/wQEAwIF
oDATBgNVHSUEDDAKBggrBgEFBQcDATAMBgNVHRMBAf8EAjAAMBYGA1UdEQQPMA2C
C2V4YW1wbGUuY29tMA0GCSqGSIb3DQEBCwUAA4IBAQBvdi9Gz1qADI0F/oC/7fh/
TjevChAO3XsuGz53yqeM0z49yHJooCV3jzgBrjX82DAhk/nQ0kF+YZummoPG1fqf
rycmsq0yD7Gy/do8sW7XSvpQpkbiBb9yg7rP1eVEfn+vDv0ZS0F3mqbfXl1v7FQ0
PwPtE49OaO7rb3FbLPBtEXocvGjvga8SRmuT3/5oCI46AKldENL6+CKEEWWUIuW8
HeKPQIRYzcqi1dy88nRk44DkCyNxe/h7X/MGt4Mu9HjDH7lDJs0038sdghsX74ET
gN6hZwBR6U2UoJHDytj/+KtSL/XGZSwTgOFyFyMZROcqUPWRwl7Zk3dOKy+3T2Bi
-----END CERTIFICATE-----
`))
c, _ := x509.ParseCertificate(b.Bytes)
return c
}()
idpKey = func() crypto.PrivateKey {
b, _ := pem.Decode([]byte(`-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAmDthV4Ab7Kx8g8xRY1MTbheTrT2DsPsoxpzEe69QlPd2WkhT
dM/VMUL+qxwJqBsC6fnSY2GdoP5Juv05zjVRo0g4fE8QhB/nD+ZCMkTeC9zD1vbb
Gd+hXecqjW1Qq8ad4CSgB5nFB2Djc0OlscTghAZL1ZJ8rLLRvIV4WznBJV60iMkV
T0LhLkwUEgCWoRW0GAihVPd65TteNoYDyQ5B4m8/D4S4BFO79xKNtyazkijVMrF5
6eYS00YTskMljaVwPdcMLE8BEzg9ipQygxyHKShHTEloUUMmukts+PUVSmKWZgE1
uEKtpYQ1lrshu/yGoXrsj7pvgRdBRZgyn0GPNwIDAQABAoIBAQCXS7zM5+fY6vy9
SJ1C59gRvKDqto5hoNy/uCKXAoBF7UPVKri3Ca/Ky9irWqxGRMI6pC1y1BuDW/cP
Pojq5pcCfs6UzUeO6N4OMTxtFYDRrVF+Hc1YA6gu2YazFIfukPFrSTs7Epp9YM/t
SLgu24p/7HoGAxah1P8aLFSX5eiOJ+8t8byYOrKLp3Rn67lC9Y+9LX4X6GHlBMDc
WHYupi3ZA7Q59dXQCJHFNG/hk17AMtB8lFra9rUid8teX8ZJKJQ26hU2O0UMujjM
mFlCdmvc97lJ4LhjrWHv/9yacf90bViHIkL52Yux1jNt/jl3/7CyBwHbau4b0qoZ
QkM4WIihAoGBAMlzsUeJxBCbUyMd/6TiLn60SDn0HMf4ZVdGqqxkhopVmsvRTn+P
wu9YHWFPwXwVL3wdtuBjsEA9nMKWWMQKbQUZhm1Y+AQIVpVNQqesgyLctVoIUBNY
fglvKrs8JuRuwMpE2P/3lXMsxtV9AyCpxxXhya8KqJa2jcMB/Lr+lx+fAoGBAMFz
16yHU+Zo6OOvy7T2xh67btwOrS2kIzhGO6gcK7qabkGGeaKLShnMYEmFGp4EaHTf
OVie+SU0OWF/e5bgFWC+fm6jWyhO0xPRbvs2P+l2KtnT2UBT9IgjhrVUIzp+Vn7t
cjfb32m7km1kZZ48ySP9cH/4/xnT6XEC33PoNwlpAoGAG1t+w7xNyAOP8sDsKrQc
pFBPTq98CRwOhx+tpeOw8bBWaT9vbZtUWbSZqNFv8S3fWPegEjD3ioHTfAl23Iid
7Ydd3hOq+sE3IOdxGdwvotheOG/QkBAAbb+PCgZNMdBolg9reLdisFVwWyWy+wiT
ZMFY5lCIPI9mCQmIDMzuMPkCgYBFJKJxh+z07YpP1wV4KLunQFbfUF+VcJUmB/RK
ocb/azL9OJNBBYf2sJW5sVlSIUE0hJR6mFd0dLYNowMJag46Bdwqrzhlr8bBzplc
MIenahTmxlFgLKG6Bvie1vPAdGd19mhcjrnLkL9FWhz38cHymyMammSTVqqZOe2j
/9usAQKBgQCT//j6XflAr20gb+mNcoJqVxRTFtSsZa23kJnZ3Sfs3R8XXu5NcZEZ
ODI9ZpZ9tg8oK9EB5vqMFTTnyjpar7F2jqFWtUmNju/rGlrQCZx0we+EAW/R2hFP
YGYu4Z+SyXTsv/Ys5VGWuuCJO32RuRBeC4eJCmpyH0mqPhIBZmV4Jw==
-----END RSA PRIVATE KEY-----
`))
k, _ := x509.ParsePKCS1PrivateKey(b.Bytes)
return k
}()
)
// newSAMLIDPServer returns a new running SAML IDP server. It is the caller's
// responsibility to call Close().
func newSAMLIDPServer(t *testing.T) (*httptest.Server, *samlidp.Server) {
h := http.NewServeMux()
srv := httptest.NewServer(h)
srvURL, err := url.Parse(srv.URL)
if err != nil {
t.Fatal(err)
}
idpServer, err := samlidp.New(samlidp.Options{
URL: *srvURL,
Key: idpKey,
Certificate: idpCert,
Store: &samlidp.MemoryStore{},
})
if err != nil {
t.Fatal(err)
}
h.Handle("/", idpServer)
return srv, idpServer
}
func TestMiddleware(t *testing.T) {
idpHTTPServer, idpServer := newSAMLIDPServer(t)
defer idpHTTPServer.Close()
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
return &license.Info{Tags: licensing.EnterpriseTags}, nil
}
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
conf.Mock(&schema.SiteConfiguration{
AppURL: "http://example.com",
ExperimentalFeatures: &schema.ExperimentalFeatures{},
})
defer conf.Mock(nil)
config := withConfigDefaults(&schema.SAMLAuthProvider{
Type: "saml",
IdentityProviderMetadataURL: idpServer.IDP.MetadataURL.String(),
ServiceProviderCertificate: testSAMLSPCert,
ServiceProviderPrivateKey: testSAMLSPKey,
})
mockGetProviderValue = &provider{config: *config}
defer func() { mockGetProviderValue = nil }()
auth.SetMockProviders([]auth.Provider{mockGetProviderValue})
defer func() { auth.SetMockProviders(nil) }()
cleanup := session.ResetMockSessionStore(t)
defer cleanup()
providerID := providerConfigID(&mockGetProviderValue.config, true)
// Mock user
mockedExternalID := "testuser_id"
const mockedUserID = 123
auth.SetMockCreateOrUpdateUser(func(u db.NewUser, a db.ExternalAccountSpec) (userID int32, err error) {
if a.ServiceType == "saml" && a.ServiceID == idpServer.IDP.MetadataURL.String() && a.ClientID == "http://example.com/.auth/saml/metadata" && a.AccountID == mockedExternalID {
return mockedUserID, nil
}
return 0, fmt.Errorf("account %v not found in mock", a)
})
defer func() { auth.SetMockCreateOrUpdateUser(nil) }()
// Set up the test handler.
authedHandler := http.NewServeMux()
authedHandler.Handle("/.api/", Middleware.API(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if uid := actor.FromContext(r.Context()).UID; uid != mockedUserID && uid != 0 {
t.Errorf("got actor UID %d, want %d", uid, mockedUserID)
}
})))
authedHandler.Handle("/", Middleware.App(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
switch r.URL.Path {
case "/":
w.Write([]byte("This is the home"))
case "/page":
w.Write([]byte("This is a page"))
case "/require-authn":
actr := actor.FromContext(r.Context())
if actr.UID == 0 {
t.Errorf("in authn expected-endpoint, no actor was set; expected actor with UID %d", mockedUserID)
} else if actr.UID != mockedUserID {
t.Errorf("in authn expected-endpoint, actor with incorrect UID was set; %d != %d", actr.UID, mockedUserID)
}
w.Write([]byte("Authenticated"))
default:
http.Error(w, "", http.StatusNotFound)
}
})))
// doRequest simulates a request to our authed handler (i.e., the SAML Service Provider).
//
// authed sets an authed actor in the request context to simulate an authenticated request.
doRequest := func(method, urlStr, body string, cookies []*http.Cookie, authed bool, form url.Values) *http.Response {
req := httptest.NewRequest(method, urlStr, bytes.NewBufferString(body))
for _, cookie := range cookies {
req.AddCookie(cookie)
}
if form != nil {
req.PostForm = form
req.Header.Add("Content-Type", "application/x-www-form-urlencoded")
}
if authed {
req = req.WithContext(actor.WithActor(context.Background(), &actor.Actor{UID: mockedUserID}))
}
respRecorder := httptest.NewRecorder()
authedHandler.ServeHTTP(respRecorder, req)
return respRecorder.Result()
}
var (
authnRequest saml.AuthnRequest
authnCookies []*http.Cookie
authnRequestURL string
)
t.Run("unauthenticated homepage visit -> IDP SSO URL", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/", "", nil, false, nil)
if want := http.StatusFound; resp.StatusCode != want {
t.Errorf("got response code %v, want %v", resp.StatusCode, want)
}
locURL, err := url.Parse(resp.Header.Get("Location"))
if err != nil {
t.Fatal(err)
}
if !strings.HasPrefix(locURL.String(), idpServer.IDP.SSOURL.String()) {
t.Error("wrong redirect URL")
}
// save cookies and Authn request
authnCookies = unexpiredCookies(resp)
authnRequestURL = locURL.String()
deflatedSAMLRequest, err := base64.StdEncoding.DecodeString(locURL.Query().Get("SAMLRequest"))
if err != nil {
t.Fatal(err)
}
if err := xml.NewDecoder(flate.NewReader(bytes.NewBuffer(deflatedSAMLRequest))).Decode(&authnRequest); err != nil {
t.Fatal(err)
}
})
t.Run("unauthenticated API visit -> pass through", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/.api/foo", "", nil, false, nil)
if got, want := resp.StatusCode, http.StatusOK; got != want {
t.Errorf("wrong response code: got %v, want %v", got, want)
}
})
var (
loggedInCookies []*http.Cookie
)
t.Run("get SP metadata and register SP with IDP", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/.auth/saml/metadata?pc="+providerID, "", nil, false, nil)
service := samlidp.Service{}
if err := xml.NewDecoder(resp.Body).Decode(&service.Metadata); err != nil {
t.Fatal(err)
}
serviceMetadataBytes, err := xml.Marshal(service.Metadata)
if err != nil {
t.Fatal(err)
}
req, err := http.NewRequest("PUT", idpHTTPServer.URL+"/services/id", bytes.NewBuffer(serviceMetadataBytes))
if err != nil {
t.Fatal(err)
}
resp, err = http.DefaultClient.Do(req)
if err != nil {
t.Fatalf("could not register SP with IDP, error: %s, resp: %v", err, resp)
}
defer resp.Body.Close()
if want := http.StatusNoContent; resp.StatusCode != want {
t.Errorf("got HTTP %d, want %d", resp.StatusCode, want)
}
})
t.Run("get SAML assertion from IDP and post the assertion to the SP ACS URL", func(t *testing.T) {
authnReq, err := http.NewRequest("GET", authnRequestURL, nil)
if err != nil {
t.Fatal(err)
}
idpAuthnReq, err := saml.NewIdpAuthnRequest(&idpServer.IDP, authnReq)
if err != nil {
t.Fatal(err)
}
if err := idpAuthnReq.Validate(); err != nil {
t.Fatal(err)
}
session := saml.Session{
ID: "session-id",
CreateTime: time.Now(),
ExpireTime: time.Now().Add(24 * time.Hour),
Index: "index",
NameID: "testuser_id",
UserName: "testuser_username",
UserEmail: "testuser@email.com",
}
if err := (saml.DefaultAssertionMaker{}).MakeAssertion(idpAuthnReq, &session); err != nil {
t.Fatal(err)
}
if err := idpAuthnReq.MakeResponse(); err != nil {
t.Fatal(err)
}
doc := etree.NewDocument()
doc.SetRoot(idpAuthnReq.ResponseEl)
responseBuf, err := doc.WriteToBytes()
if err != nil {
t.Fatal(err)
}
samlResponse := base64.StdEncoding.EncodeToString(responseBuf)
reqParams := url.Values{}
reqParams.Set("SAMLResponse", samlResponse)
reqParams.Set("RelayState", idpAuthnReq.RelayState)
resp := doRequest("POST", "http://example.com/.auth/saml/acs", "", authnCookies, false, reqParams)
if want := http.StatusFound; resp.StatusCode != want {
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
}
if got, want1, want2 := resp.Header.Get("Location"), "http://example.com/", "/"; got != want1 && got != want2 {
t.Errorf("got redirect location %v, want %v or %v", got, want1, want2)
}
// save the cookies from the login response
loggedInCookies = unexpiredCookies(resp)
})
t.Run("authenticated request to home page", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/", "", loggedInCookies, true, nil)
respBody, _ := ioutil.ReadAll(resp.Body)
if want := http.StatusOK; resp.StatusCode != want {
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
}
if got, want := string(respBody), "This is the home"; got != want {
t.Errorf("got response body %v, want %v", got, want)
}
})
t.Run("authenticated request to sub page", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/page", "", loggedInCookies, true, nil)
respBody, _ := ioutil.ReadAll(resp.Body)
if want := http.StatusOK; resp.StatusCode != want {
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
}
if got, want := string(respBody), "This is a page"; got != want {
t.Errorf("got response body %v, want %v", got, want)
}
})
t.Run("verify actor gets set in request context", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/require-authn", "", loggedInCookies, true, nil)
if want := http.StatusOK; resp.StatusCode != want {
t.Errorf("got status code %v, want %v", resp.StatusCode, want)
}
})
t.Run("verify actor gets set in API request context", func(t *testing.T) {
resp := doRequest("GET", "http://example.com/.api/foo", "", loggedInCookies, true, nil)
if got, want := resp.StatusCode, http.StatusOK; got != want {
t.Errorf("wrong status code: got %v, want %v", got, want)
}
})
}
// unexpiredCookies returns the list of unexpired cookies set by the response
func unexpiredCookies(resp *http.Response) (cookies []*http.Cookie) {
for _, cookie := range resp.Cookies() {
if cookie.RawExpires == "" || cookie.Expires.After(time.Now()) {
cookies = append(cookies, cookie)
}
}
return
}

View File

@ -0,0 +1,281 @@
package saml
import (
"context"
"crypto/tls"
"crypto/x509"
"encoding/base64"
"encoding/xml"
"fmt"
"io/ioutil"
"net/http"
"net/url"
"path"
"sync"
"time"
"github.com/beevik/etree"
"github.com/pkg/errors"
saml2 "github.com/russellhaering/gosaml2"
"github.com/russellhaering/gosaml2/types"
dsig "github.com/russellhaering/goxmldsig"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
"golang.org/x/net/context/ctxhttp"
)
const providerType = "saml"
type provider struct {
config schema.SAMLAuthProvider
multiple bool // whether there are multiple SAML auth providers
mu sync.Mutex
samlSP *saml2.SAMLServiceProvider
refreshErr error
}
// ConfigID implements auth.Provider.
func (p *provider) ConfigID() auth.ProviderConfigID {
return auth.ProviderConfigID{
Type: providerType,
ID: providerConfigID(&p.config, p.multiple),
}
}
// Config implements auth.Provider.
func (p *provider) Config() schema.AuthProviders {
return schema.AuthProviders{Saml: &p.config}
}
// Refresh implements auth.Provider.
func (p *provider) Refresh(ctx context.Context) error {
p.mu.Lock()
defer p.mu.Unlock()
p.samlSP, p.refreshErr = getServiceProvider(ctx, &p.config)
return p.refreshErr
}
func providerIDQuery(pc *schema.SAMLAuthProvider, multiple bool) url.Values {
if multiple {
return url.Values{"pc": []string{providerConfigID(pc, multiple)}}
}
return url.Values{}
}
func (p *provider) getCachedInfoAndError() (*auth.ProviderInfo, error) {
info := auth.ProviderInfo{
DisplayName: p.config.DisplayName,
AuthenticationURL: (&url.URL{
Path: path.Join(auth.AuthURLPrefix, "saml", "login"),
RawQuery: providerIDQuery(&p.config, p.multiple).Encode(),
}).String(),
}
if info.DisplayName == "" {
info.DisplayName = "SAML"
}
p.mu.Lock()
defer p.mu.Unlock()
err := p.refreshErr
if err != nil {
err = errors.WithMessage(err, "failed to initialize SAML Service Provider")
} else if p.samlSP == nil {
err = errors.New("SAML Service Provider is not yet initialized")
}
if p.samlSP != nil {
info.ServiceID = p.samlSP.IdentityProviderIssuer
info.ClientID = p.samlSP.ServiceProviderIssuer
}
return &info, err
}
// CachedInfo implements auth.Provider.
func (p *provider) CachedInfo() *auth.ProviderInfo {
info, _ := p.getCachedInfoAndError()
return info
}
func getServiceProvider(ctx context.Context, pc *schema.SAMLAuthProvider) (*saml2.SAMLServiceProvider, error) {
c, err := readProviderConfig(pc)
if err != nil {
return nil, err
}
idpMetadata, err := readIdentityProviderMetadata(ctx, c)
if err != nil {
return nil, err
}
{
if c.identityProviderMetadataURL != nil {
traceLog(fmt.Sprintf("Identity Provider metadata: %s", c.identityProviderMetadataURL), string(idpMetadata))
}
}
metadata, err := unmarshalEntityDescriptor(idpMetadata)
if err != nil {
return nil, errors.WithMessage(err, "parsing SAML Identity Provider metadata")
}
sp := saml2.SAMLServiceProvider{
IdentityProviderSSOURL: metadata.IDPSSODescriptor.SingleSignOnServices[0].Location,
IdentityProviderIssuer: metadata.EntityID,
NameIdFormat: getNameIDFormat(pc),
SkipSignatureValidation: pc.InsecureSkipAssertionSignatureValidation,
ValidateEncryptionCert: true,
AllowMissingAttributes: true,
}
idpCertStore := &dsig.MemoryX509CertificateStore{Roots: []*x509.Certificate{}}
for _, kd := range metadata.IDPSSODescriptor.KeyDescriptors {
for i, xcert := range kd.KeyInfo.X509Data.X509Certificates {
if xcert.Data == "" {
return nil, fmt.Errorf("SAML Identity Provider metadata certificate %d is empty", i)
}
certData, err := base64.StdEncoding.DecodeString(xcert.Data)
if err != nil {
return nil, errors.WithMessage(err, fmt.Sprintf("decoding SAML Identity Provider metadata certificate %d", i))
}
idpCert, err := x509.ParseCertificate(certData)
if err != nil {
return nil, errors.WithMessage(err, fmt.Sprintf("parsing SAML Identity Provider metadata certificate %d X.509 data", i))
}
idpCertStore.Roots = append(idpCertStore.Roots, idpCert)
}
}
sp.IDPCertificateStore = idpCertStore
// The SP's signing and encryption keys.
if c.keyPair != nil {
sp.SPKeyStore = dsig.TLSCertKeyStore(*c.keyPair)
sp.SignAuthnRequests = pc.SignRequests == nil || *pc.SignRequests
} else {
// If the SP private key isn't specified, then the IdP must not care to validate.
if pc.SignRequests != nil && *pc.SignRequests {
return nil, errors.New("signRequests is true for SAML Service Provider but no private key and cert are given")
}
}
// pc.Issuer's default of ${appURL}/.auth/saml/metadata already applied (in withConfigDefaults).
sp.ServiceProviderIssuer = pc.ServiceProviderIssuer
if pc.ServiceProviderIssuer == "" {
return nil, errors.New("invalid SAML Service Provider configuration: issuer is empty (and default issuer could not be derived from empty appURL)")
}
appURL, err := url.Parse(conf.Get().AppURL)
if err != nil {
return nil, errors.WithMessage(err, "parsing app URL for SAML Service Provider")
}
sp.AssertionConsumerServiceURL = appURL.ResolveReference(&url.URL{Path: path.Join(authPrefix, "acs")}).String()
sp.AudienceURI = sp.ServiceProviderIssuer
return &sp, nil
}
// entitiesDescriptor represents the SAML EntitiesDescriptor object.
type entitiesDescriptor struct {
XMLName xml.Name `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntitiesDescriptor"`
ID *string `xml:",attr,omitempty"`
ValidUntil *time.Time `xml:"validUntil,attr,omitempty"`
CacheDuration *time.Duration `xml:"cacheDuration,attr,omitempty"`
Name *string `xml:",attr,omitempty"`
Signature *etree.Element
EntitiesDescriptors []entitiesDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntitiesDescriptor"`
EntityDescriptors []types.EntityDescriptor `xml:"urn:oasis:names:tc:SAML:2.0:metadata EntityDescriptor"`
}
// unmarshalEntityDescriptor unmarshals from an XML root <EntityDescriptor> or <EntitiesDescriptor>
// element. If the latter, it returns the first <EntityDescriptor> child that has an
// IDPSSODescriptor.
//
// Taken from github.com/crewjam/saml.
func unmarshalEntityDescriptor(data []byte) (*types.EntityDescriptor, error) {
var entity *types.EntityDescriptor
if err := xml.Unmarshal(data, &entity); err != nil {
// This comparison is ugly, but it is how the error is generated in encoding/xml.
if err.Error() != "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
return nil, err
}
var entities *entitiesDescriptor
if err := xml.Unmarshal(data, &entities); err != nil {
return nil, err
}
for i, e := range entities.EntityDescriptors {
if e.IDPSSODescriptor != nil {
entity = &entities.EntityDescriptors[i]
break
}
}
if entity == nil {
return nil, errors.New("no entity found with IDPSSODescriptor")
}
}
return entity, nil
}
type providerConfig struct {
keyPair *tls.Certificate
// Exactly 1 of these is set:
identityProviderMetadataURL *url.URL
identityProviderMetadata []byte
}
func readProviderConfig(pc *schema.SAMLAuthProvider) (*providerConfig, error) {
var c providerConfig
if pc.ServiceProviderCertificate != "" && pc.ServiceProviderPrivateKey != "" {
keyPair, err := tls.X509KeyPair([]byte(pc.ServiceProviderCertificate), []byte(pc.ServiceProviderPrivateKey))
if err != nil {
return nil, err
}
keyPair.Leaf, err = x509.ParseCertificate(keyPair.Certificate[0])
if err != nil {
return nil, err
}
c.keyPair = &keyPair
}
// Allow specifying either URL to SAML Identity Provider metadata XML file, or the XML
// file contents directly.
switch {
case pc.IdentityProviderMetadataURL != "" && pc.IdentityProviderMetadata != "":
return nil, errors.New("invalid SAML configuration: set either identityProviderMetadataURL or identityProviderMetadata, not both")
case pc.IdentityProviderMetadataURL != "":
var err error
c.identityProviderMetadataURL, err = url.Parse(pc.IdentityProviderMetadataURL)
if err != nil {
return nil, errors.Wrap(err, "parsing SAML Identity Provider metadata URL")
}
case pc.IdentityProviderMetadata != "":
c.identityProviderMetadata = []byte(pc.IdentityProviderMetadata)
default:
return nil, errors.New("invalid SAML configuration: must provide the SAML metadata, using either identityProviderMetadataURL (URL where XML file is available) or identityProviderMetadata (XML file contents)")
}
return &c, nil
}
func readIdentityProviderMetadata(ctx context.Context, c *providerConfig) ([]byte, error) {
if c.identityProviderMetadata != nil {
return []byte(c.identityProviderMetadata), nil
}
resp, err := ctxhttp.Get(ctx, nil, c.identityProviderMetadataURL.String())
if err != nil {
return nil, errors.WithMessage(err, "fetching SAML Identity Provider metadata")
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
return nil, fmt.Errorf("non-200 HTTP response for SAML Identity Provider metadata URL: %s", c.identityProviderMetadataURL)
}
data, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, errors.WithMessage(err, "reading SAML Identity Provider metadata")
}
return data, nil
}

View File

@ -0,0 +1 @@
package saml

View File

@ -0,0 +1,83 @@
package saml
import (
"fmt"
"net/http"
"github.com/beevik/etree"
"github.com/pkg/errors"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
)
// SignOut returns the URL where the user can initiate a logout from the SAML IdentityProvider, if
// it has a SingleLogoutService.
func SignOut(w http.ResponseWriter, r *http.Request) (logoutURL string, err error) {
// TODO(sqs): Only supports a single SAML auth provider.
pc, multiple := getFirstProviderConfig()
if pc == nil {
return "", nil
}
p := getProvider(providerConfigID(pc, multiple))
if p == nil {
return "", nil
}
doc, err := newLogoutRequest(p)
if err != nil {
return "", errors.WithMessage(err, "creating SAML LogoutRequest")
}
{
if data, err := doc.WriteToString(); err == nil {
traceLog(fmt.Sprintf("LogoutRequest: %s", p.ConfigID().ID), string(data))
}
}
return p.samlSP.BuildAuthURLRedirect("/", doc)
}
// getFirstProviderConfig returns the SAML auth provider config. At most 1 can be specified in site
// config; if there is more than 1, it returns multiple == true (which the caller should handle by
// returning an error and refusing to proceed with auth).
func getFirstProviderConfig() (pc *schema.SAMLAuthProvider, multiple bool) {
for _, p := range conf.Get().AuthProviders {
if p.Saml != nil {
if pc != nil {
return pc, true // multiple SAML auth providers
}
pc = withConfigDefaults(p.Saml)
}
}
return pc, false
}
func newLogoutRequest(p *provider) (*etree.Document, error) {
p.mu.Lock()
defer p.mu.Unlock()
if p.samlSP == nil {
return nil, errors.New("unable to create SAML LogoutRequest because provider is not yet initialized")
}
// Start with the doc for AuthnRequest and change a few things to make it into a LogoutRequest
// doc. This saves us from needing to duplicate a bunch of code.
doc, err := p.samlSP.BuildAuthRequestDocumentNoSig()
if err != nil {
return nil, err
}
root := doc.Root()
root.Tag = "LogoutRequest"
// TODO(sqs): This assumes SSO URL == SLO URL (i.e., the same endpoint is used for signin and
// logout). To fix this, use `root.SelectAttr("Destination").Value = "..."`.
if t := root.FindElement("//samlp:NameIDPolicy"); t != nil {
root.RemoveChild(t)
}
if p.samlSP.SignAuthnRequests {
signed, err := p.samlSP.SignAuthnRequest(root)
if err != nil {
return nil, err
}
doc.SetRoot(signed)
}
return doc, nil
}

View File

@ -0,0 +1,124 @@
package saml
import (
"context"
"encoding/base64"
"fmt"
"strings"
"github.com/pkg/errors"
saml2 "github.com/russellhaering/gosaml2"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/auth"
"github.com/sourcegraph/sourcegraph/pkg/actor"
)
type authnResponseInfo struct {
spec db.ExternalAccountSpec
email, displayName string
unnormalizedUsername string
accountData interface{}
}
func readAuthnResponse(p *provider, encodedResp string) (*authnResponseInfo, error) {
{
if raw, err := base64.StdEncoding.DecodeString(encodedResp); err == nil {
traceLog(fmt.Sprintf("AuthnResponse: %s", p.ConfigID().ID), string(raw))
}
}
assertions, err := p.samlSP.RetrieveAssertionInfo(encodedResp)
if err != nil {
return nil, errors.WithMessage(err, "reading AuthnResponse assertions")
}
if wi := assertions.WarningInfo; wi.InvalidTime || wi.NotInAudience {
return nil, fmt.Errorf("invalid SAML AuthnResponse: %+v", wi)
}
pi, err := p.getCachedInfoAndError()
if err != nil {
return nil, err
}
firstNonempty := func(ss ...string) string {
for _, s := range ss {
if s := strings.TrimSpace(s); s != "" {
return s
}
}
return ""
}
attr := samlAssertionValues(assertions.Values)
email := firstNonempty(attr.Get("email"), attr.Get("emailaddress"))
if email == "" && mightBeEmail(assertions.NameID) {
email = assertions.NameID
}
if pn := attr.Get("eduPersonPrincipalName"); email == "" && mightBeEmail(pn) {
email = pn
}
info := authnResponseInfo{
spec: db.ExternalAccountSpec{
ServiceType: providerType,
ServiceID: pi.ServiceID,
ClientID: pi.ClientID,
AccountID: assertions.NameID,
},
email: email,
unnormalizedUsername: firstNonempty(attr.Get("login"), attr.Get("uid"), email),
displayName: firstNonempty(attr.Get("displayName"), attr.Get("givenName")+" "+attr.Get("surname")),
accountData: assertions,
}
if assertions.NameID == "" {
return nil, errors.New("the SAML response did not contain a valid NameID")
}
if info.email == "" {
return nil, errors.New("the SAML response did not contain an email attribute")
}
if info.unnormalizedUsername == "" {
return nil, errors.New("the SAML response did not contain a username attribute")
}
return &info, nil
}
// getOrCreateUser gets or creates a user account based on the SAML claims. It returns the
// authenticated actor if successful; otherwise it returns an friendly error message (safeErrMsg)
// that is safe to display to users, and a non-nil err with lower-level error details.
func getOrCreateUser(ctx context.Context, info *authnResponseInfo) (_ *actor.Actor, safeErrMsg string, err error) {
var data db.ExternalAccountData
auth.SetExternalAccountData(&data.AccountData, info.accountData)
username, err := auth.NormalizeUsername(info.unnormalizedUsername)
if err != nil {
return nil, fmt.Sprintf("Error normalizing the username %q. See https://about.sourcegraph.com/docs/config/authentication#username-normalization.", info.unnormalizedUsername), err
}
userID, safeErrMsg, err := auth.CreateOrUpdateUser(ctx, db.NewUser{
Username: username,
Email: info.email,
EmailIsVerified: info.email != "", // TODO(sqs): https://github.com/sourcegraph/sourcegraph/issues/10118
DisplayName: info.displayName,
// SAML has no standard way of providing an avatar URL.
},
info.spec,
data,
)
if err != nil {
return nil, safeErrMsg, err
}
return actor.FromUser(userID), "", nil
}
func mightBeEmail(s string) bool {
return strings.Count(s, "@") == 1
}
type samlAssertionValues saml2.Values
func (v samlAssertionValues) Get(key string) string {
for _, a := range v {
if a.Name == key || a.FriendlyName == key {
return a.Values[0].Value
}
}
return ""
}

File diff suppressed because one or more lines are too long

View File

@ -0,0 +1,7 @@
package db
import dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
func init() {
dbtesting.DBNameSuffix = "enterprisedb"
}

View File

@ -0,0 +1,527 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/lib/pq"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
otlog "github.com/opentracing/opentracing-go/log"
"github.com/pkg/errors"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
"github.com/sourcegraph/sourcegraph/cmd/frontend/types"
"github.com/sourcegraph/sourcegraph/pkg/api"
"github.com/sourcegraph/sourcegraph/pkg/inventory"
"github.com/sourcegraph/sourcegraph/xlang"
"github.com/sourcegraph/sourcegraph/xlang/lspext"
)
// globalDeps provides access to the `global_dep` table. Each row in
// the table represents a dependency relationship from a repository to
// a package-manager-level package.
//
// * The language column is the programming language in which the
// dependency occurs (the language of the repository and the package
// manager package)
// * The dep_data column contains JSON describing the package manager package.
// Typically, this includes a name and version field.
// * The repo_id column identifies the repository.
// * The hints column contains JSON that contains additional hints that can
// be used to optimized requests related to the dependency (e.g., which
// directory in a repository contains the dependency).
//
// For a detailed overview of the schema, see schema.txt.
type globalDeps struct{}
func (g *globalDeps) TotalRefs(ctx context.Context, repo *types.Repo, langs []*inventory.Lang) (int, error) {
var sum int
for _, lang := range langs {
switch lang.Name {
case inventory.LangGo:
for _, expandedSources := range repoURIToGoPathPrefixes(repo.URI) {
refs, err := g.doTotalRefsGo(ctx, expandedSources)
if err != nil {
return 0, errors.Wrap(err, "doTotalRefsGo")
}
sum += refs
}
case inventory.LangJava:
refs, err := g.doTotalRefs(ctx, repo.ID, "java")
if err != nil {
return 0, errors.Wrap(err, "doTotalRefs")
}
sum += refs
}
}
return sum, nil
}
// ListTotalRefs is like TotalRefs, except it returns a list of repo IDs
// instead of just the length of that list. Obviously, this is less efficient
// if you just need the count, however.
func (g *globalDeps) ListTotalRefs(ctx context.Context, repo *types.Repo, langs []*inventory.Lang) ([]api.RepoID, error) {
var repos []api.RepoID
for _, lang := range langs {
switch lang.Name {
case inventory.LangGo:
for _, expandedSources := range repoURIToGoPathPrefixes(repo.URI) {
refs, err := g.doListTotalRefsGo(ctx, expandedSources)
if err != nil {
return nil, errors.Wrap(err, "doListTotalRefsGo")
}
repos = append(repos, refs...)
}
case inventory.LangJava:
refs, err := g.doListTotalRefs(ctx, repo.ID, "java")
if err != nil {
return nil, errors.Wrap(err, "doListTotalRefs")
}
repos = append(repos, refs...)
}
}
return repos, nil
}
// repoURIToGoPathPrefixes translates a repository URI like
// github.com/kubernetes/kubernetes into its _prefix_ matching Go import paths
// (e.g. k8s.io/kubernetes). In the case of the standard library,
// github.com/golang/go returns all of the Go stdlib package paths. If the
// repository URI is not special cased, []string{repoURI} is simply returned.
//
// TODO(slimsag): In the future, when the pkgs index includes Go repositories,
// use that instead of this manual mapping hack.
func repoURIToGoPathPrefixes(repoURI api.RepoURI) []string {
manualMapping := map[api.RepoURI][]string{
// stdlib hack: by returning an empty string (NOT no strings) we end up
// with an SQL query like `AND dep_data->>'package' LIKE '%';` which
// matches all Go repositories effectively. We do this for the stdlib
// because all Go repositories will import the stdlib anyway.
"github.com/golang/go": {""},
// google.golang.org
"github.com/grpc/grpc-go": {"google.golang.org/grpc"},
"github.com/google/google-api-go-client": {"google.golang.org/api"},
"github.com/golang/appengine": {"google.golang.org/appengine"},
// go4.org
"github.com/camlistore/go4": {"go4.org"},
// At special request of a user, since we don't support custom import
// paths generically here yet. See https://github.com/sourcegraph/sourcegraph/issues/12488
"github.com/goadesign/goa": {"github.com/goadesign/goa", "goa.design/goa"},
}
if v, ok := manualMapping[repoURI]; ok {
return v
}
switch {
case strings.HasPrefix(string(repoURI), "github.com/azul3d"): // azul3d.org
split := strings.Split(string(repoURI), "/")
if len(split) >= 3 {
return []string{"azul3d.org/" + split[2]}
}
case strings.HasPrefix(string(repoURI), "github.com/dskinner"): // dasa.cc
split := strings.Split(string(repoURI), "/")
if len(split) >= 3 {
return []string{"dasa.cc/" + split[2]}
}
case strings.HasPrefix(string(repoURI), "github.com/kubernetes"): // k8s.io
split := strings.Split(string(repoURI), "/")
if len(split) >= 3 {
return []string{"k8s.io/" + split[2]}
}
case strings.HasPrefix(string(repoURI), "github.com/uber-go"): // go.uber.org
split := strings.Split(string(repoURI), "/")
if len(split) >= 3 {
// Some repos use non-canonical import paths.
return []string{
string(repoURI),
"go.uber.org/" + split[2],
}
}
case strings.HasPrefix(string(repoURI), "github.com/dominikh"): // honnef.co
split := strings.Split(string(repoURI), "/")
if len(split) >= 3 {
return []string{"honnef.co/" + strings.Replace(split[2], "-", "/", -1)}
}
case strings.HasPrefix(string(repoURI), "github.com/golang") && repoURI != "github.com/golang/go": // golang.org/x
split := strings.Split(string(repoURI), "/")
if len(split) >= 3 {
return []string{"golang.org/x/" + split[2]}
}
case strings.HasPrefix(string(repoURI), "github.com"): // gopkg.in
split := strings.Split(string(repoURI), "/")
if len(split) >= 3 && strings.HasPrefix(split[1], "go-") {
// Four possibilities
return []string{
string(repoURI), // github.com/go-foo/foo
"gopkg.in/" + strings.TrimPrefix(split[1], "go-"), // gopkg.in/foo
"labix.org/v1/" + strings.TrimPrefix(split[1], "go-"), // labix.org/v1/foo
"labix.org/v2/" + strings.TrimPrefix(split[1], "go-"), // labix.org/v2/foo
}
} else if len(split) >= 3 {
// Two possibilities
return []string{
string(repoURI), // github.com/foo/bar
"gopkg.in/" + split[1] + "/" + split[2], // gopkg.in/foo/bar
}
}
}
return []string{string(repoURI)}
}
// doTotalRefs is the generic implementation of total references, using the `pkgs` table.
func (g *globalDeps) doTotalRefs(ctx context.Context, repo api.RepoID, lang string) (sum int, err error) {
// Get packages contained in the repo
packages, err := (&pkgs{}).ListPackages(ctx, &api.ListPackagesOp{Lang: lang, Limit: 500, RepoID: repo})
if err != nil {
return 0, errors.Wrap(err, "ListPackages")
}
if len(packages) == 0 {
return 0, nil
}
// Find number of repos that depend on that set of packages
var args []interface{}
arg := func(a interface{}) string {
args = append(args, a)
return fmt.Sprintf("$%d", len(args))
}
var pkgClauses []string
for _, pkg := range packages {
pkgID, ok := xlang.PackageIdentifier(pkg.Pkg, lang)
if !ok {
return 0, errors.Wrap(err, "PackageIdentifier")
}
containmentQuery, err := json.Marshal(pkgID)
if err != nil {
return 0, errors.Wrap(err, "Marshal")
}
pkgClauses = append(pkgClauses, `dep_data @> `+arg(string(containmentQuery)))
}
whereSQL := `(language=` + arg(lang) + `) AND ((` + strings.Join(pkgClauses, ") OR (") + `))`
sql := `SELECT count(distinct(repo_id))
FROM global_dep
WHERE ` + whereSQL
rows, err := dbconn.Global.QueryContext(ctx, sql, args...)
if err != nil {
return 0, errors.Wrap(err, "Query")
}
defer rows.Close()
var count int
for rows.Next() {
if err := rows.Scan(&count); err != nil {
return 0, errors.Wrap(err, "Scan")
}
}
return count, nil
}
// doListTotalRefs is the generic implementation of list total references,
// using the `pkgs` table.
func (g *globalDeps) doListTotalRefs(ctx context.Context, repo api.RepoID, lang string) ([]api.RepoID, error) {
// Get packages contained in the repo
packages, err := (&pkgs{}).ListPackages(ctx, &api.ListPackagesOp{Lang: lang, Limit: 500, RepoID: repo})
if err != nil {
return nil, errors.Wrap(err, "ListPackages")
}
if len(packages) == 0 {
return nil, nil
}
// Find all repos that depend on that set of packages
var args []interface{}
arg := func(a interface{}) string {
args = append(args, a)
return fmt.Sprintf("$%d", len(args))
}
var pkgClauses []string
for _, pkg := range packages {
pkgID, ok := xlang.PackageIdentifier(pkg.Pkg, lang)
if !ok {
return nil, errors.Wrap(err, "PackageIdentifier")
}
containmentQuery, err := json.Marshal(pkgID)
if err != nil {
return nil, errors.Wrap(err, "Marshal")
}
pkgClauses = append(pkgClauses, `dep_data @> `+arg(string(containmentQuery)))
}
whereSQL := `(language=` + arg(lang) + `) AND ((` + strings.Join(pkgClauses, ") OR (") + `))`
sql := `SELECT distinct(repo_id)
FROM global_dep
WHERE ` + whereSQL
rows, err := dbconn.Global.QueryContext(ctx, sql, args...)
if err != nil {
return nil, errors.Wrap(err, "Query")
}
defer rows.Close()
var repos []api.RepoID
for rows.Next() {
var repo api.RepoID
if err := rows.Scan(&repo); err != nil {
return nil, errors.Wrap(err, "Scan")
}
repos = append(repos, repo)
}
return repos, nil
}
// doTotalRefsGo is the Go-specific implementation of total references, since we can extract package metadata directly
// from Go repository URLs, without going through the `pkgs` table.
func (g *globalDeps) doTotalRefsGo(ctx context.Context, source string) (int, error) {
// Because global_dep only stores Go package paths, not repository URIs, we
// use a simple heuristic here by using `LIKE <repo>%`. This will work for
// GitHub package paths (e.g. `github.com/a/b%` matches `github.com/a/b/c`)
// but not custom import paths etc.
rows, err := dbconn.Global.QueryContext(ctx, `SELECT COUNT(DISTINCT repo_id)
FROM global_dep
WHERE language='go'
AND dep_data->>'depth' = '0'
AND ( -- in C locale, this is equivalent to matching "$1/*", but matches much faster
(dep_data->>'package' COLLATE "C" < $1 || '0' COLLATE "C" AND dep_data->>'package' COLLATE "C" > $1 || '/' COLLATE "C")
OR (dep_data->>'package' COLLATE "C" = $1)
);
`, source)
if err != nil {
return 0, errors.Wrap(err, "Query")
}
defer rows.Close()
var count int
for rows.Next() {
err := rows.Scan(&count)
if err != nil {
return 0, errors.Wrap(err, "Scan")
}
}
return count, nil
}
// doListTotalRefsGo is the Go-specific implementation of list total
// references, since we can extract package metadata directly from Go
// repository URLs, without going through the `pkgs` table.
func (g *globalDeps) doListTotalRefsGo(ctx context.Context, source string) ([]api.RepoID, error) {
// Because global_dep only stores Go package paths, not repository URIs, we
// use a simple heuristic here by using `LIKE <repo>%`. This will work for
// GitHub package paths (e.g. `github.com/a/b%` matches `github.com/a/b/c`)
// but not custom import paths etc.
rows, err := dbconn.Global.QueryContext(ctx, `SELECT DISTINCT repo_id
FROM global_dep
WHERE language='go'
AND dep_data->>'depth' = '0'
AND dep_data->>'package' LIKE $1;
`, source+"%")
if err != nil {
return nil, errors.Wrap(err, "Query")
}
defer rows.Close()
var repos []api.RepoID
for rows.Next() {
var repo api.RepoID
err := rows.Scan(&repo)
if err != nil {
return nil, errors.Wrap(err, "Scan")
}
repos = append(repos, repo)
}
return repos, nil
}
func (g *globalDeps) UpdateIndexForLanguage(ctx context.Context, language string, repo api.RepoID, deps []lspext.DependencyReference) (err error) {
err = db.Transaction(ctx, dbconn.Global, func(tx *sql.Tx) error {
// Update the table.
err = g.update(ctx, tx, language, deps, repo)
if err != nil {
return errors.Wrap(err, "update global_dep")
}
return nil
})
if err != nil {
return errors.Wrap(err, "executing transaction")
}
return nil
}
func (g *globalDeps) Dependencies(ctx context.Context, op db.DependenciesOptions) (refs []*api.DependencyReference, err error) {
if db.Mocks.GlobalDeps.Dependencies != nil {
return db.Mocks.GlobalDeps.Dependencies(ctx, op)
}
span, ctx := opentracing.StartSpanFromContext(ctx, "db.Dependencies")
defer func() {
if err != nil {
ext.Error.Set(span, true)
span.SetTag("err", err.Error())
}
span.Finish()
}()
span.SetTag("Language", op.Language)
span.SetTag("DepData", op.DepData)
var args []interface{}
arg := func(a interface{}) string {
args = append(args, a)
return fmt.Sprintf("$%d", len(args))
}
var whereConds []string
if op.Language != "" {
whereConds = append(whereConds, `gd.language=`+arg(op.Language))
}
if op.DepData != nil {
containmentQuery, err := json.Marshal(op.DepData)
if err != nil {
return nil, errors.New("marshaling op.DepData query")
}
whereConds = append(whereConds, `dep_data @> `+arg(string(containmentQuery)))
}
if op.Repo != 0 {
whereConds = append(whereConds, `repo_id = `+arg(op.Repo))
}
selectSQL := `SELECT gd.language, dep_data, repo_id, hints`
fromSQL := `FROM global_dep AS gd INNER JOIN repo AS r ON gd.repo_id=r.id`
whereSQL := ""
if len(whereConds) > 0 {
whereSQL = `WHERE ` + strings.Join(whereConds, " AND ")
}
limitSQL := ""
if op.Limit != 0 {
limitSQL = `LIMIT ` + arg(op.Limit)
}
sql := fmt.Sprintf("%s %s %s %s", selectSQL, fromSQL, whereSQL, limitSQL)
rows, err := dbconn.Global.QueryContext(ctx, sql, args...)
if err != nil {
return nil, errors.Wrap(err, "query")
}
defer rows.Close()
for rows.Next() {
var (
language, depData, hints string
repo api.RepoID
)
if err := rows.Scan(&language, &depData, &repo, &hints); err != nil {
return nil, errors.Wrap(err, "Scan")
}
r := &api.DependencyReference{
RepoID: repo,
Language: language,
}
if err := json.Unmarshal([]byte(depData), &r.DepData); err != nil {
return nil, errors.Wrap(err, "unmarshaling xdependencies metadata from sql scan")
}
if err := json.Unmarshal([]byte(hints), &r.Hints); err != nil {
return nil, errors.Wrap(err, "unmarshaling xdependencies hints from sql scan")
}
refs = append(refs, r)
}
if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "rows error")
}
return refs, nil
}
// updateGlobalDep updates the global_dep table.
func (g *globalDeps) update(ctx context.Context, tx *sql.Tx, language string, deps []lspext.DependencyReference, indexRepo api.RepoID) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "updateGlobalDep "+language)
defer func() {
if err != nil {
ext.Error.Set(span, true)
span.SetTag("err", err.Error())
}
span.Finish()
}()
span.SetTag("deps", len(deps))
// First, create a temporary table.
_, err = tx.ExecContext(ctx, `CREATE TEMPORARY TABLE new_global_dep (
language text NOT NULL,
dep_data jsonb NOT NULL,
repo_id integer NOT NULL,
hints jsonb
) ON COMMIT DROP;`)
if err != nil {
return errors.Wrap(err, "create temp table")
}
span.LogFields(otlog.String("event", "created temp table"))
// Copy the new deps into the temporary table.
copy, err := tx.Prepare(pq.CopyIn("new_global_dep",
"language",
"dep_data",
"repo_id",
"hints",
))
if err != nil {
return errors.Wrap(err, "prepare copy in")
}
defer copy.Close()
span.LogFields(otlog.String("event", "prepared copy in"))
for _, r := range deps {
data, err := json.Marshal(r.Attributes)
if err != nil {
return errors.Wrap(err, "marshaling xdependency metadata to JSON")
}
hintsData, err := json.Marshal(r.Hints)
if err != nil {
return errors.Wrap(err, "marshaling xdependency hints to JSON")
}
if _, err := copy.Exec(
language, // language
string(data), // dep_data
indexRepo, // repo_id
string(hintsData), // hints
); err != nil {
return errors.Wrap(err, "executing ref copy")
}
}
span.LogFields(otlog.String("event", "executed all dep copy"))
if _, err := copy.Exec(); err != nil {
return errors.Wrap(err, "executing copy")
}
span.LogFields(otlog.String("event", "executed copy"))
if _, err := tx.ExecContext(ctx, `DELETE FROM global_dep WHERE language=$1 AND repo_id=$2`, language, indexRepo); err != nil {
return errors.Wrap(err, "executing table deletion")
}
span.LogFields(otlog.String("event", "executed table deletion"))
// Insert from temporary table into the real table.
_, err = tx.ExecContext(ctx, `INSERT INTO global_dep(
language,
dep_data,
repo_id,
hints
) SELECT d.language,
d.dep_data,
d.repo_id,
d.hints
FROM new_global_dep d;`)
if err != nil {
return errors.Wrap(err, "executing final insertion from temp table")
}
span.LogFields(otlog.String("event", "executed final insertion from temp table"))
return nil
}
func (g *globalDeps) Delete(ctx context.Context, repo api.RepoID) error {
_, err := dbconn.Global.ExecContext(ctx, `DELETE FROM global_dep WHERE repo_id=$1`, repo)
return err
}

View File

@ -0,0 +1,345 @@
package db
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"sort"
"testing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
"github.com/sourcegraph/sourcegraph/pkg/api"
"github.com/sourcegraph/sourcegraph/xlang/lspext"
)
func TestGlobalDeps_TotalRefsExpansion(t *testing.T) {
tests := map[api.RepoURI][]string{
// azul3d.org
"github.com/azul3d/engine": {"azul3d.org/engine"},
// dasa.cc
"github.com/dskinner/ztext": {"dasa.cc/ztext"},
// k8s.io
"github.com/kubernetes/kubernetes": {"k8s.io/kubernetes"},
"github.com/kubernetes/apimachinery": {"k8s.io/apimachinery"},
"github.com/kubernetes/client-go": {"k8s.io/client-go"},
"github.com/kubernetes/heapster": {"k8s.io/heapster"},
// golang.org/x
"github.com/golang/net": {"golang.org/x/net"},
"github.com/golang/tools": {"golang.org/x/tools"},
"github.com/golang/oauth2": {"golang.org/x/oauth2"},
"github.com/golang/crypto": {"golang.org/x/crypto"},
"github.com/golang/sys": {"golang.org/x/sys"},
"github.com/golang/text": {"golang.org/x/text"},
"github.com/golang/image": {"golang.org/x/image"},
"github.com/golang/mobile": {"golang.org/x/mobile"},
// google.golang.org
"github.com/grpc/grpc-go": {"google.golang.org/grpc"},
"github.com/google/google-api-go-client": {"google.golang.org/api"},
"github.com/golang/appengine": {"google.golang.org/appengine"},
// go.uber.org
"github.com/uber-go/yarpc": {"github.com/uber-go/yarpc", "go.uber.org/yarpc"},
"github.com/uber-go/thriftrw": {"github.com/uber-go/thriftrw", "go.uber.org/thriftrw"},
"github.com/uber-go/zap": {"github.com/uber-go/zap", "go.uber.org/zap"},
"github.com/uber-go/atomic": {"github.com/uber-go/atomic", "go.uber.org/atomic"},
"github.com/uber-go/fx": {"github.com/uber-go/fx", "go.uber.org/fx"},
// go4.org
"github.com/camlistore/go4": {"go4.org"},
// honnef.co
"github.com/dominikh/go-staticcheck": {"honnef.co/go/staticcheck"},
"github.com/dominikh/go-js-dom": {"honnef.co/go/js/dom"},
"github.com/dominikh/go-ssa": {"honnef.co/go/ssa"},
// gopkg.in
"github.com/go-mgo/mgo": {"github.com/go-mgo/mgo", "gopkg.in/mgo", "labix.org/v1/mgo", "labix.org/v2/mgo"},
"github.com/go-yaml/yaml": {"github.com/go-yaml/yaml", "gopkg.in/yaml", "labix.org/v1/yaml", "labix.org/v2/yaml"},
"github.com/fatih/set": {"github.com/fatih/set", "gopkg.in/fatih/set"},
"github.com/juju/environschema": {"github.com/juju/environschema", "gopkg.in/juju/environschema"},
}
for input, want := range tests {
got := repoURIToGoPathPrefixes(input)
if !reflect.DeepEqual(got, want) {
t.Errorf("got %q want %q", got, want)
}
}
}
func TestGlobalDeps_update_delete(t *testing.T) {
if testing.Short() {
t.Skip()
}
ctx := dbtesting.TestContext(t)
if err := db.Repos.Upsert(ctx, api.InsertRepoOp{URI: "myrepo", Description: "", Fork: false, Enabled: true}); err != nil {
t.Fatal(err)
}
rp, err := db.Repos.GetByURI(ctx, "myrepo")
if err != nil {
t.Fatal(err)
}
repo := rp.ID
inputRefs := []lspext.DependencyReference{{
Attributes: map[string]interface{}{"name": "dep1", "vendor": true},
}}
if err := GlobalDeps.UpdateIndexForLanguage(ctx, "go", repo, inputRefs); err != nil {
t.Fatal(err)
}
t.Log("update")
wantRefs := []*api.DependencyReference{{
Language: "go",
DepData: map[string]interface{}{"name": "dep1", "vendor": true},
RepoID: repo,
}}
gotRefs, err := GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
Language: "go",
DepData: map[string]interface{}{"name": "dep1"},
Limit: 20,
})
if err != nil {
t.Fatal(err)
}
sort.Sort(sortDepRefs(wantRefs))
sort.Sort(sortDepRefs(gotRefs))
if !reflect.DeepEqual(gotRefs, wantRefs) {
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
}
t.Log("delete other")
if err := GlobalDeps.Delete(ctx, 345345345); err != nil {
t.Fatal(err)
}
gotRefs, err = GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
Language: "go",
DepData: map[string]interface{}{"name": "dep1"},
Limit: 20,
})
if err != nil {
t.Fatal(err)
}
sort.Sort(sortDepRefs(wantRefs))
sort.Sort(sortDepRefs(gotRefs))
if !reflect.DeepEqual(gotRefs, wantRefs) {
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
}
t.Log("delete")
if err := GlobalDeps.Delete(ctx, repo); err != nil {
t.Fatal(err)
}
gotRefs, err = GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
Language: "go",
DepData: map[string]interface{}{"name": "dep1"},
Limit: 20,
})
if err != nil {
t.Fatal(err)
}
if len(gotRefs) > 0 {
t.Errorf("expected no matching refs, got %+v", gotRefs)
}
}
func TestGlobalDeps_RefreshIndex(t *testing.T) {
if testing.Short() {
t.Skip()
}
ctx := dbtesting.TestContext(t)
if err := db.Repos.Upsert(ctx, api.InsertRepoOp{URI: "myrepo", Description: "", Fork: false, Enabled: true}); err != nil {
t.Fatal(err)
}
repo, err := db.Repos.GetByURI(ctx, "myrepo")
if err != nil {
t.Fatal(err)
}
if err := GlobalDeps.UpdateIndexForLanguage(ctx, "go", repo.ID, []lspext.DependencyReference{{
Attributes: map[string]interface{}{
"name": "github.com/gorilla/dep",
"vendor": true,
},
}}); err != nil {
t.Fatal(err)
}
wantRefs := []*api.DependencyReference{{
Language: "go",
DepData: map[string]interface{}{"name": "github.com/gorilla/dep", "vendor": true},
RepoID: repo.ID,
}}
gotRefs, err := GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
Language: "go",
DepData: map[string]interface{}{"name": "github.com/gorilla/dep"},
Limit: 20,
})
if err != nil {
t.Fatal(err)
}
sort.Sort(sortDepRefs(wantRefs))
sort.Sort(sortDepRefs(gotRefs))
if !reflect.DeepEqual(gotRefs, wantRefs) {
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
}
}
func TestGlobalDeps_Dependencies(t *testing.T) {
if testing.Short() {
t.Skip()
}
ctx := dbtesting.TestContext(t)
repos := make([]api.RepoID, 5)
for i := 0; i < 5; i++ {
uri := api.RepoURI(fmt.Sprintf("myrepo-%d", i))
if err := db.Repos.Upsert(ctx, api.InsertRepoOp{URI: uri, Description: "", Fork: false, Enabled: true}); err != nil {
t.Fatal(err)
}
rp, err := db.Repos.GetByURI(ctx, uri)
if err != nil {
t.Fatal(err)
}
repos[i] = rp.ID
}
inputRefs := map[api.RepoID][]lspext.DependencyReference{
repos[0]: {{Attributes: map[string]interface{}{"name": "github.com/gorilla/dep2", "vendor": true}}},
repos[1]: {{Attributes: map[string]interface{}{"name": "github.com/gorilla/dep3", "vendor": true}}},
repos[2]: {{Attributes: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true}}},
repos[3]: {{Attributes: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true}}},
repos[4]: {{Attributes: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true}}},
}
for rp, deps := range inputRefs {
err := GlobalDeps.UpdateIndexForLanguage(ctx, "go", rp, deps)
if err != nil {
t.Fatal(err)
}
}
{ // Test case 1
wantRefs := []*api.DependencyReference{{
Language: "go",
DepData: map[string]interface{}{"name": "github.com/gorilla/dep2", "vendor": true},
RepoID: repos[0],
}}
gotRefs, err := GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
Language: "go",
DepData: map[string]interface{}{"name": "github.com/gorilla/dep2"},
Limit: 20,
})
if err != nil {
t.Fatal(err)
}
sort.Sort(sortDepRefs(wantRefs))
sort.Sort(sortDepRefs(gotRefs))
if !reflect.DeepEqual(gotRefs, wantRefs) {
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
}
}
{ // Test case 2
wantRefs := []*api.DependencyReference{{
Language: "go",
DepData: map[string]interface{}{"name": "github.com/gorilla/dep3", "vendor": true},
RepoID: repos[1],
}}
gotRefs, err := GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
Language: "go",
DepData: map[string]interface{}{"name": "github.com/gorilla/dep3"},
Limit: 20,
})
if err != nil {
t.Fatal(err)
}
sort.Sort(sortDepRefs(wantRefs))
sort.Sort(sortDepRefs(gotRefs))
if !reflect.DeepEqual(gotRefs, wantRefs) {
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
}
}
{ // Test case 3
wantRefs := []*api.DependencyReference{{
Language: "go",
DepData: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true},
RepoID: repos[2],
}, {
Language: "go",
DepData: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true},
RepoID: repos[3],
},
{
Language: "go",
DepData: map[string]interface{}{"name": "github.com/gorilla/dep4", "vendor": true},
RepoID: repos[4],
},
}
gotRefs, err := GlobalDeps.Dependencies(ctx, db.DependenciesOptions{
Language: "go",
DepData: map[string]interface{}{"name": "github.com/gorilla/dep4"},
Limit: 20,
})
if err != nil {
t.Fatal(err)
}
sort.Sort(sortDepRefs(wantRefs))
sort.Sort(sortDepRefs(gotRefs))
if !reflect.DeepEqual(gotRefs, wantRefs) {
t.Errorf("got %+v, expected %+v", gotRefs, wantRefs)
}
}
}
type sortDepRefs []*api.DependencyReference
func (s sortDepRefs) Len() int { return len(s) }
func (s sortDepRefs) Swap(a, b int) { s[a], s[b] = s[b], s[a] }
func (s sortDepRefs) Less(a, b int) bool {
if s[a].RepoID != s[b].RepoID {
return s[a].RepoID < s[b].RepoID
}
if !reflect.DeepEqual(s[a].DepData, s[b].DepData) {
return stringMapLess(s[a].DepData, s[b].DepData)
}
return stringMapLess(s[a].Hints, s[b].Hints)
}
func stringMapLess(a, b map[string]interface{}) bool {
if len(a) != len(b) {
return len(a) < len(b)
}
ak := make([]string, 0, len(a))
for k := range a {
ak = append(ak, k)
}
bk := make([]string, 0, len(b))
for k := range b {
bk = append(bk, k)
}
sort.Strings(ak)
sort.Strings(bk)
for i := range ak {
if ak[i] != bk[i] {
return ak[i] < bk[i]
}
// This does not consistentlbk order the output, but in the
// cases we use this it will since it is just a simple value
// like bool or string
av, _ := json.Marshal(a[ak[i]])
bv, _ := json.Marshal(b[bk[i]])
if bytes.Equal(av, bv) {
return string(av) < string(bv)
}
}
return false
}

View File

@ -0,0 +1,212 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/lib/pq"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
otlog "github.com/opentracing/opentracing-go/log"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
"github.com/sourcegraph/sourcegraph/pkg/api"
"github.com/sourcegraph/sourcegraph/xlang/lspext"
)
// pkgs provides access to the `pkgs` table.
//
// For a detailed overview of the schema, see schema.txt.
type pkgs struct{}
func (p *pkgs) UpdateIndexForLanguage(ctx context.Context, language string, repo api.RepoID, pks []lspext.PackageInformation) (err error) {
err = db.Transaction(ctx, dbconn.Global, func(tx *sql.Tx) error {
// Update the pkgs table.
err = p.update(ctx, tx, repo, language, pks)
if err != nil {
return errors.Wrap(err, "pkgs.update")
}
return nil
})
if err != nil {
return errors.Wrap(err, "executing transaction")
}
return nil
}
type dbQueryer interface {
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
}
func (p *pkgs) update(ctx context.Context, tx *sql.Tx, indexRepo api.RepoID, language string, pks []lspext.PackageInformation) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "pkgs.update "+language)
defer func() {
if err != nil {
ext.Error.Set(span, true)
span.SetTag("err", err.Error())
}
span.Finish()
}()
span.SetTag("pkgs", len(pks))
// First, create a temporary table.
_, err = tx.ExecContext(ctx, `CREATE TEMPORARY TABLE new_pkgs (
pkg jsonb NOT NULL,
language text NOT NULL,
repo_id integer NOT NULL
) ON COMMIT DROP;`)
if err != nil {
return errors.Wrap(err, "create temp table")
}
span.LogFields(otlog.String("event", "created temp table"))
// Copy the new deps into the temporary table.
copy, err := tx.Prepare(pq.CopyIn("new_pkgs",
"repo_id",
"language",
"pkg",
))
if err != nil {
return errors.Wrap(err, "prepare copy in")
}
defer copy.Close()
span.LogFields(otlog.String("event", "prepared copy in"))
for _, r := range pks {
pkgData, err := json.Marshal(r.Package)
if err != nil {
return errors.Wrap(err, "marshaling package metadata to JSON")
}
if _, err := copy.Exec(
indexRepo, // repo_id
language, // language
string(pkgData), // pkg
); err != nil {
return errors.Wrap(err, "executing pkg copy")
}
}
span.LogFields(otlog.String("event", "executed all pkg copy"))
if _, err := copy.Exec(); err != nil {
return errors.Wrap(err, "executing copy")
}
span.LogFields(otlog.String("event", "executed copy"))
if _, err := tx.ExecContext(ctx, `DELETE FROM pkgs WHERE language=$1 AND repo_id=$2`, language, indexRepo); err != nil {
return errors.Wrap(err, "executing pkgs deletion")
}
span.LogFields(otlog.String("event", "executed pkgs deletion"))
// Insert from temporary table into the real table.
_, err = tx.ExecContext(ctx, `INSERT INTO pkgs(
repo_id,
language,
pkg
)
SELECT p.repo_id,
p.language,
p.pkg
FROM new_pkgs p;
`)
if err != nil {
return errors.Wrap(err, "executing final insertion from temp table")
}
span.LogFields(otlog.String("event", "executed final insertion from temp table"))
return nil
}
func (p *pkgs) ListPackages(ctx context.Context, op *api.ListPackagesOp) (pks []*api.PackageInfo, err error) {
if db.Mocks.Pkgs.ListPackages != nil {
return db.Mocks.Pkgs.ListPackages(ctx, op)
}
span, ctx := opentracing.StartSpanFromContext(ctx, "pkgs.ListPackages")
defer func() {
if err != nil {
ext.Error.Set(span, true)
span.SetTag("err", err.Error())
}
span.Finish()
}()
span.SetTag("Lang", op.Lang)
span.SetTag("PkgQuery", op.PkgQuery)
var args []interface{}
arg := func(a interface{}) string {
args = append(args, a)
return fmt.Sprintf("$%d", len(args))
}
var whereClauses []string
if op.PkgQuery != nil {
containmentQuery, err := json.Marshal(op.PkgQuery)
if err != nil {
return nil, errors.New("marshaling op.PkgQuery")
}
whereClauses = append(whereClauses, `pkgs.pkg @> `+arg(string(containmentQuery)))
}
if op.RepoID != 0 {
whereClauses = append(whereClauses, `repo_id=`+arg(op.RepoID))
}
if op.Lang != "" {
whereClauses = append(whereClauses, `pkgs.language=`+arg(op.Lang))
}
if len(whereClauses) == 0 {
return nil, fmt.Errorf("no filtering options specified, must specify at least one")
}
whereSQL := "(" + strings.Join(whereClauses, ") AND (") + ")"
sql := `
SELECT pkgs.*
FROM pkgs INNER JOIN repo ON pkgs.repo_id=repo.id
WHERE ` + whereSQL + `
ORDER BY repo.created_at ASC NULLS LAST, pkgs.repo_id ASC
LIMIT ` + arg(op.Limit)
rows, err := dbconn.Global.QueryContext(ctx, sql, args...)
if err != nil {
return nil, errors.Wrap(err, "query")
}
defer rows.Close()
var rawPkgs []*api.PackageInfo
for rows.Next() {
var (
pkg, lang string
repo api.RepoID
)
if err := rows.Scan(&repo, &lang, &pkg); err != nil {
return nil, errors.Wrap(err, "Scan")
}
r := api.PackageInfo{
RepoID: repo,
Lang: lang,
// NOTE: Dependency info (in api.PackageInfo's Dependencies field) is not set
// here because it is stored separately in the global_dep table in a way that
// is slow and difficult to get in this code path. Currently callers that use
// DB-persisted package info do not need the dependency info, so this is
// acceptable.
}
if err := json.Unmarshal([]byte(pkg), &r.Pkg); err != nil {
return nil, errors.Wrap(err, "unmarshaling xdependencies metadata from sql scan")
}
rawPkgs = append(rawPkgs, &r)
}
if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "rows error")
}
return rawPkgs, nil
}
func (p *pkgs) Delete(ctx context.Context, repo api.RepoID) error {
if db.Mocks.Pkgs.Delete != nil {
return db.Mocks.Pkgs.Delete(ctx, repo)
}
_, err := dbconn.Global.ExecContext(ctx, `DELETE FROM pkgs WHERE repo_id=$1`, repo)
return err
}

View File

@ -0,0 +1,303 @@
package db
import (
"context"
"database/sql"
"encoding/json"
"reflect"
"strconv"
"testing"
"time"
"github.com/pkg/errors"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
"github.com/sourcegraph/sourcegraph/pkg/api"
"github.com/sourcegraph/sourcegraph/xlang/lspext"
)
func TestPkgs_update_delete(t *testing.T) {
if testing.Short() {
t.Skip()
}
ctx := dbtesting.TestContext(t)
if err := db.Repos.Upsert(ctx, api.InsertRepoOp{URI: "myrepo", Description: "", Fork: false, Enabled: true}); err != nil {
t.Fatal(err)
}
rp, err := db.Repos.GetByURI(ctx, "myrepo")
if err != nil {
t.Fatal(err)
}
pks := []lspext.PackageInformation{{
Package: map[string]interface{}{"name": "pkg"},
Dependencies: []lspext.DependencyReference{{
Attributes: map[string]interface{}{"name": "dep1"},
}},
}}
t.Log("update")
if err := db.Transaction(ctx, dbconn.Global, func(tx *sql.Tx) error {
if err := Pkgs.update(ctx, tx, rp.ID, "go", pks); err != nil {
return err
}
return nil
}); err != nil {
t.Fatal(err)
}
expPkgs := []*api.PackageInfo{{
RepoID: rp.ID,
Lang: "go",
Pkg: map[string]interface{}{"name": "pkg"},
}}
gotPkgs, err := Pkgs.getAll(ctx, dbconn.Global)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(gotPkgs, expPkgs) {
t.Errorf("got %+v, expected %+v", gotPkgs, expPkgs)
}
t.Log("delete nothing")
if err := Pkgs.Delete(ctx, 0); err != nil {
t.Fatal(err)
}
gotPkgs, err = Pkgs.getAll(ctx, dbconn.Global)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(gotPkgs, expPkgs) {
t.Errorf("got %+v, expected %+v", gotPkgs, expPkgs)
}
t.Log("delete")
if err := Pkgs.Delete(ctx, 1); err != nil {
t.Fatal(err)
}
gotPkgs, err = Pkgs.getAll(ctx, dbconn.Global)
if err != nil {
t.Fatal(err)
}
if len(gotPkgs) > 0 {
t.Errorf("expected all pkgs corresponding to repo %d deleted, but got %+v", rp.ID, gotPkgs)
}
}
func TestPkgs_RefreshIndex(t *testing.T) {
if testing.Short() {
t.Skip()
}
ctx := dbtesting.TestContext(t)
if err := db.Repos.Upsert(ctx, api.InsertRepoOp{URI: "myrepo", Description: "", Fork: false, Enabled: true}); err != nil {
t.Fatal(err)
}
rp, err := db.Repos.GetByURI(ctx, "myrepo")
if err != nil {
t.Fatal(err)
}
if err := Pkgs.UpdateIndexForLanguage(ctx, "typescript", rp.ID, []lspext.PackageInformation{
{
Package: map[string]interface{}{
"name": "tspkg",
"version": "2.2.2",
},
Dependencies: []lspext.DependencyReference{},
},
}); err != nil {
t.Fatal(err)
}
expPkgs := []*api.PackageInfo{{
RepoID: rp.ID,
Lang: "typescript",
Pkg: map[string]interface{}{
"name": "tspkg",
"version": "2.2.2",
},
}}
gotPkgs, err := Pkgs.getAll(ctx, dbconn.Global)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(gotPkgs, expPkgs) {
t.Errorf("got %+v, expected %+v", gotPkgs, expPkgs)
}
}
func TestPkgs_ListPackages(t *testing.T) {
if testing.Short() {
t.Skip()
}
ctx := dbtesting.TestContext(t)
repoToPkgs := map[api.RepoID][]lspext.PackageInformation{
1: {{
Package: map[string]interface{}{"name": "pkg1", "version": "1.1.1"},
Dependencies: []lspext.DependencyReference{{
Attributes: map[string]interface{}{"name": "pkg1-dep", "version": "1.1.2"},
}},
}},
2: {{
Package: map[string]interface{}{"name": "pkg2", "version": "2.2.1"},
Dependencies: []lspext.DependencyReference{{
Attributes: map[string]interface{}{"name": "pkg2-dep", "version": "2.2.2"},
}},
}},
3: {{Package: map[string]interface{}{"name": "pkg3", "version": "3.3.1"}}},
4: {{Package: map[string]interface{}{"name": "pkg3", "version": "3.3.1"}}},
5: {{Package: map[string]interface{}{"name": "pkg3", "version": "3.3.1"}}},
}
createdAt := time.Now()
for repo, pks := range repoToPkgs {
if err := db.Transaction(ctx, dbconn.Global, func(tx *sql.Tx) error {
if _, err := tx.ExecContext(ctx, `INSERT INTO repo(id, uri, created_at) VALUES ($1, $2, $3)`, repo, strconv.Itoa(int(repo)), createdAt); err != nil {
return err
}
if err := Pkgs.update(ctx, tx, repo, "go", pks); err != nil {
return err
}
return nil
}); err != nil {
t.Fatal(err)
}
}
{ // Test case 1
expPkgInfo := []*api.PackageInfo{{
RepoID: 1,
Lang: "go",
Pkg: map[string]interface{}{"name": "pkg1", "version": "1.1.1"},
}}
op := &api.ListPackagesOp{
Lang: "go",
PkgQuery: map[string]interface{}{"name": "pkg1"},
Limit: 10,
}
gotPkgInfo, err := Pkgs.ListPackages(ctx, op)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(gotPkgInfo, expPkgInfo) {
t.Errorf("got %+v, expected %+v", gotPkgInfo, expPkgInfo)
}
}
{ // Test case 2
expPkgInfo := []*api.PackageInfo{{
RepoID: 1,
Lang: "go",
Pkg: map[string]interface{}{"name": "pkg1", "version": "1.1.1"},
}}
op := &api.ListPackagesOp{
Lang: "go",
PkgQuery: map[string]interface{}{"name": "pkg1", "version": "1.1.1"},
Limit: 10,
}
gotPkgInfo, err := Pkgs.ListPackages(ctx, op)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(gotPkgInfo, expPkgInfo) {
t.Errorf("got %+v, expected %+v", gotPkgInfo, expPkgInfo)
}
}
{ // Test case 3
var expPkgInfo []*api.PackageInfo
op := &api.ListPackagesOp{
Lang: "go",
PkgQuery: map[string]interface{}{"name": "pkg1", "version": "2"},
Limit: 10,
}
gotPkgInfo, err := Pkgs.ListPackages(ctx, op)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(gotPkgInfo, expPkgInfo) {
t.Errorf("got %+v, expected %+v", gotPkgInfo, expPkgInfo)
}
}
{ // Test case 4
expPkgInfo := []*api.PackageInfo{{
RepoID: 3,
Lang: "go",
Pkg: map[string]interface{}{"name": "pkg3", "version": "3.3.1"},
}, {
RepoID: 4,
Lang: "go",
Pkg: map[string]interface{}{"name": "pkg3", "version": "3.3.1"},
},
{
RepoID: 5,
Lang: "go",
Pkg: map[string]interface{}{"name": "pkg3", "version": "3.3.1"},
},
}
op := &api.ListPackagesOp{
Lang: "go",
PkgQuery: map[string]interface{}{"name": "pkg3"},
Limit: 10,
}
gotPkgInfo, err := Pkgs.ListPackages(ctx, op)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(gotPkgInfo, expPkgInfo) {
t.Errorf("got %+v, expected %+v", gotPkgInfo, expPkgInfo)
}
}
{ // Test case 5, filter by repo ID
expPkgInfo := []*api.PackageInfo{{
RepoID: 3,
Lang: "go",
Pkg: map[string]interface{}{"name": "pkg3", "version": "3.3.1"},
}}
op := &api.ListPackagesOp{
Lang: "go",
RepoID: 3,
Limit: 10,
}
gotPkgInfo, err := Pkgs.ListPackages(ctx, op)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(gotPkgInfo, expPkgInfo) {
t.Errorf("got %+v, expected %+v", gotPkgInfo, expPkgInfo)
}
}
}
func (p *pkgs) getAll(ctx context.Context, db dbQueryer) (packages []*api.PackageInfo, err error) {
rows, err := db.QueryContext(ctx, "SELECT * FROM pkgs ORDER BY language ASC")
if err != nil {
return nil, errors.Wrap(err, "query")
}
defer rows.Close()
for rows.Next() {
var (
repo api.RepoID
language string
pkg string
)
if err := rows.Scan(&repo, &language, &pkg); err != nil {
return nil, errors.Wrap(err, "Scan")
}
p := api.PackageInfo{
RepoID: repo,
Lang: language,
}
if err := json.Unmarshal([]byte(pkg), &p.Pkg); err != nil {
return nil, errors.Wrap(err, "unmarshaling package metadata from SQL scan")
}
packages = append(packages, &p)
}
if err := rows.Err(); err != nil {
return nil, errors.Wrap(err, "rows error")
}
return packages, nil
}

View File

@ -0,0 +1,15 @@
package db
import (
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
)
var (
Pkgs = &pkgs{}
GlobalDeps = &globalDeps{}
)
func init() {
db.Pkgs = Pkgs
db.GlobalDeps = GlobalDeps
}

View File

@ -0,0 +1,13 @@
// +build !dist
package assets
import (
"net/http"
"github.com/sourcegraph/sourcegraph/cmd/frontend/assets"
)
func init() {
assets.Assets = http.Dir("./ui/assets")
}

View File

@ -0,0 +1,9 @@
// +build dist
package assets
import "github.com/sourcegraph/sourcegraph/cmd/frontend/assets"
func init() {
assets.Assets = DistAssets
}

View File

@ -0,0 +1,22 @@
// +build generate
package main
import (
"log"
"net/http"
"github.com/shurcooL/vfsgen"
)
func main() {
dir := "../../../../ui/assets/"
err := vfsgen.Generate(http.Dir(dir), vfsgen.Options{
PackageName: "assets",
BuildTags: "dist",
VariableName: "DistAssets",
})
if err != nil {
log.Fatalln(err)
}
}

View File

@ -0,0 +1,3 @@
// Package assets contains static assets for the enterprise front-end Web app. It should be imported
// for side-effects.
package assets

View File

@ -0,0 +1,3 @@
//go:generate go run assets_generate.go
package assets

View File

@ -0,0 +1,165 @@
package billing
import (
"context"
"fmt"
"sync"
"time"
multierror "github.com/hashicorp/go-multierror"
"github.com/pkg/errors"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
stripe "github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/customer"
log15 "gopkg.in/inconshreveable/log15.v2"
)
// GetOrAssignUserCustomerID returns the billing customer ID associated with the user. If no billing
// customer ID exists for the user, a new one is created and saved on the user's DB record.
func GetOrAssignUserCustomerID(ctx context.Context, userID int32) (_ string, err error) {
// Wrap this operation in a transaction so we never have stored 2 auto-created billing customer
// IDs for the same user.
tx, err := dbconn.Global.BeginTx(ctx, nil)
if err != nil {
return "", err
}
defer func() {
if err != nil {
rollErr := tx.Rollback()
if rollErr != nil {
err = multierror.Append(err, rollErr)
}
return
}
err = tx.Commit()
}()
custID, err := dbBilling{}.getUserBillingCustomerID(ctx, tx, userID)
if err != nil {
return "", err
}
if custID == nil {
// There is no billing customer ID for this user yet, so we must make one. This is not racy
// w.r.t. the DB because we are still in a DB transaction. It is still possible for a race
// condition to result in 2 billing customers being created, but only one of them would ever
// be stored in our DB.
newCustID, err := createCustomerID(ctx, userID)
if err != nil {
return "", errors.WithMessage(err, fmt.Sprintf("auto-creating customer ID for user ID %d", userID))
}
// If we fail after here, then try to clean up the customer ID.
defer func() {
if err != nil {
ctx, cancel := context.WithTimeout(ctx, 3*time.Second) // don't wait too long
defer cancel()
if err := deleteCustomerID(ctx, newCustID); err != nil {
log15.Error("During cleanup of failed auto-creation of billing customer ID for user, failed to delete billing customer ID.", "userID", userID, "newCustomerID", newCustID, "err", err)
}
}
}()
if err := (dbBilling{}).setUserBillingCustomerID(ctx, tx, userID, &newCustID); err != nil {
return "", err
}
custID = &newCustID
}
return *custID, nil
}
var (
dummyCustomerMu sync.Mutex
dummyCustomerID string
)
// GetDummyCustomerID returns the customer ID for a dummy customer that must be used only for
// pricing out invoices not associated with any particular customer. There is one dummy customer in
// the billing system that is used for all such operations (because the billing system requires
// providing a customer ID but we don't want to use any actual customer's ID).
//
// The first time this func is called, it looks up the ID of the existing dummy customer in the
// billing system and returns that if one exists (to avoid creating many dummy customer records). If
// the dummy customer doesn't exist yet, it is automatically created.
func GetDummyCustomerID(ctx context.Context) (string, error) {
dummyCustomerMu.Lock()
defer dummyCustomerMu.Unlock()
if dummyCustomerID == "" {
// Look up dummy customer.
const dummyCustomerEmail = "dummy-customer@example.com"
listParams := &stripe.CustomerListParams{
ListParams: stripe.ListParams{Context: ctx},
}
listParams.Filters.AddFilter("email", "", dummyCustomerEmail)
listParams.Limit = stripe.Int64(1)
customers := customer.List(listParams)
if err := customers.Err(); err != nil {
return "", err
}
if customers.Next() {
dummyCustomerID = customers.Customer().ID
} else {
// No dummy customer exists yet, so create it. Future calls to GetDummyCustomerID will reuse the dummy customer.
params := &stripe.CustomerParams{
Params: stripe.Params{Context: ctx},
Email: stripe.String(dummyCustomerEmail),
Description: stripe.String("DUMMY (only used for generating quotes for unauthenticated viewers)"),
}
cust, err := customer.New(params)
if err != nil {
return "", err
}
dummyCustomerID = cust.ID
}
}
return dummyCustomerID, nil
}
var mockCreateCustomerID func(userID int32) (string, error)
// createCustomerID creates a customer record on the billing system and returns the customer ID of
// the new record.
func createCustomerID(ctx context.Context, userID int32) (string, error) {
if mockCreateCustomerID != nil {
return mockCreateCustomerID(userID)
}
user, err := graphqlbackend.UserByIDInt32(ctx, userID)
if err != nil {
return "", err
}
custParams := &stripe.CustomerParams{
Params: stripe.Params{Context: ctx},
Description: stripe.String(fmt.Sprintf("%s (%d)", user.Username(), user.SourcegraphID())),
}
// Use the user's first verified email (if any).
emails, err := user.Emails(ctx)
if err != nil {
return "", err
}
for _, email := range emails {
if email.Verified() {
custParams.Email = stripe.String(email.Email())
break
}
}
// Create the billing customer.
cust, err := customer.New(custParams)
if err != nil {
return "", err
}
return cust.ID, nil
}
// deleteCustomerID deletes the customer record on the billing system.
func deleteCustomerID(ctx context.Context, customerID string) error {
// For simplicity of tests, just noop if the mockCreateCustomerID is set.
if mockCreateCustomerID != nil {
return nil
}
_, err := customer.Del(customerID, &stripe.CustomerParams{Params: stripe.Params{Context: ctx}})
return err
}

View File

@ -0,0 +1,53 @@
package billing
import (
"context"
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
stripe "github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/customer"
)
func init() {
graphqlbackend.UserURLForSiteAdminBilling = func(ctx context.Context, userID int32) (*string, error) {
// 🚨 SECURITY: Only site admins may view the billing URL, because it may contain sensitive
// data or identifiers.
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil, err
}
custID, err := dbBilling{}.getUserBillingCustomerID(ctx, nil, userID)
if err != nil {
return nil, err
}
if custID != nil {
u := CustomerURL(*custID)
return &u, nil
}
return nil, nil
}
}
func (BillingResolver) SetUserBilling(ctx context.Context, args *graphqlbackend.SetUserBillingArgs) (*graphqlbackend.EmptyResponse, error) {
// 🚨 SECURITY: Only site admins may set a user's billing info.
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil, err
}
userID, err := graphqlbackend.UnmarshalUserID(args.User)
if err != nil {
return nil, err
}
// Ensure the billing customer ID refers to a valid customer in the billing system.
if args.BillingCustomerID != nil {
if _, err := customer.Get(*args.BillingCustomerID, &stripe.CustomerParams{Params: stripe.Params{Context: ctx}}); err != nil {
return nil, err
}
}
if err := (dbBilling{}).setUserBillingCustomerID(ctx, nil, userID, args.BillingCustomerID); err != nil {
return nil, err
}
return &graphqlbackend.EmptyResponse{}, nil
}

View File

@ -0,0 +1,45 @@
package billing
import (
"fmt"
"testing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
)
func TestGetOrAssignUserCustomerID(t *testing.T) {
ctx := dbtesting.TestContext(t)
c := 0
mockCreateCustomerID = func(userID int32) (string, error) {
c++
return fmt.Sprintf("cust%d", c), nil
}
defer func() { mockCreateCustomerID = nil }()
u, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
if err != nil {
t.Fatal(err)
}
t.Run("assigns and retrieves", func(t *testing.T) {
custID1, err := GetOrAssignUserCustomerID(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
custID2, err := GetOrAssignUserCustomerID(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
if custID2 != custID1 {
t.Errorf("got custID %q, want %q", custID2, custID2)
}
})
t.Run("fails on nonexistent users", func(t *testing.T) {
if _, err := GetOrAssignUserCustomerID(ctx, 123 /* no such user */); err == nil {
t.Fatal("err == nil")
}
})
}

View File

@ -0,0 +1,65 @@
package billing
import (
"context"
"database/sql"
"github.com/keegancsmith/sqlf"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
)
// dbBilling provides billing-related database operations.
type dbBilling struct{}
// getUserBillingCustomerID gets the billing customer ID (if any) for a user.
//
// If a transaction tx is provided, the query is executed using the transaction. If tx is nil, the
// global DB handle is used.
func (dbBilling) getUserBillingCustomerID(ctx context.Context, tx *sql.Tx, userID int32) (billingCustomerID *string, err error) {
var dbh dbh
if tx != nil {
dbh = tx
} else {
dbh = dbconn.Global
}
query := sqlf.Sprintf("SELECT billing_customer_id FROM users WHERE id=%d AND deleted_at IS NULL", userID)
err = dbh.QueryRowContext(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...).Scan(&billingCustomerID)
if err == sql.ErrNoRows {
return nil, db.NewUserNotFoundError(userID)
}
return billingCustomerID, err
}
// setUserBillingCustomerID sets or unsets the billing customer ID for a user.
//
// If a transaction tx is provided, the query is executed using the transaction. If tx is nil, the
// global DB handle is used.
func (dbBilling) setUserBillingCustomerID(ctx context.Context, tx *sql.Tx, userID int32, billingCustomerID *string) error {
var dbh dbh
if tx != nil {
dbh = tx
} else {
dbh = dbconn.Global
}
query := sqlf.Sprintf("UPDATE users SET billing_customer_id=%s WHERE id=%d AND deleted_at IS NULL", billingCustomerID, userID)
res, err := dbh.ExecContext(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...)
if err != nil {
return err
}
nrows, err := res.RowsAffected()
if err != nil {
return err
}
if nrows == 0 {
return db.NewUserNotFoundError(userID)
}
return nil
}
type dbh interface {
QueryRowContext(context.Context, string, ...interface{}) *sql.Row
ExecContext(context.Context, string, ...interface{}) (sql.Result, error)
}

View File

@ -0,0 +1,66 @@
package billing
import (
"testing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
"github.com/sourcegraph/sourcegraph/pkg/errcode"
)
func init() {
dbtesting.DBNameSuffix = "billing"
}
func TestDBUsersBillingCustomerID(t *testing.T) {
ctx := dbtesting.TestContext(t)
t.Run("existing user", func(t *testing.T) {
u, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
if err != nil {
t.Fatal(err)
}
if custID, err := (dbBilling{}).getUserBillingCustomerID(ctx, nil, u.ID); err != nil {
t.Fatal(err)
} else if custID != nil {
t.Errorf("got %q, want nil", *custID)
}
t.Run("set to non-nil", func(t *testing.T) {
if err := (dbBilling{}).setUserBillingCustomerID(ctx, nil, u.ID, strptr("x")); err != nil {
t.Fatal(err)
}
if custID, err := (dbBilling{}).getUserBillingCustomerID(ctx, nil, u.ID); err != nil {
t.Fatal(err)
} else if want := "x"; custID == nil || *custID != want {
t.Errorf("got %v, want %q", custID, want)
}
})
t.Run("set to nil", func(t *testing.T) {
if err := (dbBilling{}).setUserBillingCustomerID(ctx, nil, u.ID, nil); err != nil {
t.Fatal(err)
}
if custID, err := (dbBilling{}).getUserBillingCustomerID(ctx, nil, u.ID); err != nil {
t.Fatal(err)
} else if custID != nil {
t.Errorf("got %q, want nil", *custID)
}
})
})
t.Run("nonexistent user", func(t *testing.T) {
if _, err := (dbBilling{}).getUserBillingCustomerID(ctx, nil, 123 /* doesn't exist */); !errcode.IsNotFound(err) {
t.Errorf("got %v, want errcode.IsNotFound(err) == true", err)
}
if err := (dbBilling{}).setUserBillingCustomerID(ctx, nil, 123 /* doesn't exist */, strptr("x")); !errcode.IsNotFound(err) {
t.Errorf("got %v, want errcode.IsNotFound(err) == true", err)
}
})
}
func strptr(s string) *string {
return &s
}

View File

@ -0,0 +1,2 @@
// Package billing handles subscription billing on Sourcegraph.com (via Stripe).
package billing

View File

@ -0,0 +1,97 @@
package billing
import (
"fmt"
"strconv"
"strings"
"time"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
stripe "github.com/stripe/stripe-go"
)
// productSubscriptionEvent implements the GraphQL type ProductSubscriptionEvent.
type productSubscriptionEvent struct {
v *stripe.Event
}
// ToProductSubscriptionEvent returns a resolver for the GraphQL type ProductSubscriptionEvent from
// the given billing event.
//
// The okToShowUser return value reports whether the event should be shown to the user. It is false
// for internal events (e.g., an invoice being marked uncollectible).
func ToProductSubscriptionEvent(event *stripe.Event) (gqlEvent graphqlbackend.ProductSubscriptionEvent, okToShowUser bool) {
_, _, okToShowUser = getProductSubscriptionEventInfo(event)
return &productSubscriptionEvent{v: event}, okToShowUser
}
// getProductSubscriptionEventInfo returns a nice title and description for the event. See
// ToProductSubscriptionEvent for information about the okToShowUser return value.
func getProductSubscriptionEventInfo(v *stripe.Event) (title, description string, okToShowUser bool) {
switch v.Type {
case "charge.succeeded":
title = "Charge succeeded"
okToShowUser = true
case "invoice.created":
title = "Invoice created"
okToShowUser = true
case "invoice.payment_succeeded":
title = "Invoice payment succeeded"
description = fmt.Sprintf("An invoice payment of %s succeeded.", usdCentsToString(v.GetObjectValue("amount_paid")))
okToShowUser = true
case "invoice.payment_failed":
title = "Invoice payment failed"
description = fmt.Sprintf("An invoice payment of %s failed.", usdCentsToString(v.GetObjectValue("amount_paid")))
okToShowUser = true
case "invoice.sent":
title = "Invoice email sent"
okToShowUser = true
case "invoice.updated":
title = "Invoice updated"
okToShowUser = true
default:
title = v.Type
}
return title, description, okToShowUser
}
func usdCentsToString(s string) string {
// TODO(sqs): use a real currency lib
usdCents, err := strconv.ParseFloat(s, 64)
if err != nil {
return "unknown amount"
}
return fmt.Sprintf("$%.2f", usdCents/100)
}
func (r *productSubscriptionEvent) ID() string { return r.v.ID }
func (r *productSubscriptionEvent) Date() string {
return time.Unix(r.v.Created, 0).Format(time.RFC3339)
}
func (r *productSubscriptionEvent) Title() string {
title, _, _ := getProductSubscriptionEventInfo(r.v)
return title
}
func (r *productSubscriptionEvent) Description() *string {
_, description, _ := getProductSubscriptionEventInfo(r.v)
if description == "" {
return nil
}
return &description
}
func (r *productSubscriptionEvent) URL() *string {
var u string
if strings.HasPrefix(r.v.Type, "invoice.") {
u = r.v.GetObjectValue("hosted_invoice_url")
}
if u == "" {
return nil
}
return &u
}

View File

@ -0,0 +1,4 @@
package billing
// BillingResolver implements the GraphQL Query and Mutation fields related to billing.
type BillingResolver struct{}

View File

@ -0,0 +1,38 @@
package billing
import (
"context"
"fmt"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/enterprise/pkg/license"
stripe "github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/plan"
)
// InfoForProductPlan returns the license key tags and min quantity that should be used for the
// given product plan.
//
// License key tags indicate which product variant (e.g., Enterprise vs. Enterprise Starter), so
// they are stored on the billing system in the metadata of the product plans.
func InfoForProductPlan(ctx context.Context, planID string) (licenseTags []string, minQuantity *int32, err error) {
params := &stripe.PlanParams{Params: stripe.Params{Context: ctx}}
params.AddExpand("product")
plan, err := plan.Get(planID, params)
if err != nil {
return nil, nil, err
}
var tags []string
switch {
case plan.Product.Metadata["licenseTags"] != "":
tags = license.ParseTagsInput(plan.Product.Metadata["licenseTags"])
case plan.Product.Name == "Enterprise Starter":
tags = licensing.EnterpriseStarterTags
case plan.Product.Name == "Enterprise":
tags = licensing.EnterpriseTags
default:
return nil, nil, fmt.Errorf("unable to determine license tags for plan %q (nickname %q)", planID, plan.Nickname)
}
return tags, ProductPlanMinQuantity(plan), nil
}

View File

@ -0,0 +1,119 @@
package billing
import (
"context"
"fmt"
"sort"
"strconv"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
stripe "github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/plan"
)
// productPlan implements the GraphQL type ProductPlan.
type productPlan struct {
billingPlanID string
productPlanID string
name string
pricePerUserPerYear int32
minQuantity *int32
tiersMode string
planTiers []graphqlbackend.PlanTier
}
// planTier implements the GraphQL type PlanTier.
type planTier struct {
unitAmount int64
upTo int64
}
func (r *productPlan) ProductPlanID() string { return r.productPlanID }
func (r *productPlan) BillingPlanID() string { return r.billingPlanID }
func (r *productPlan) Name() string { return r.name }
func (r *productPlan) NameWithBrand() string { return "Sourcegraph " + r.name }
func (r *productPlan) PricePerUserPerYear() int32 { return r.pricePerUserPerYear }
func (r *productPlan) MinQuantity() *int32 { return r.minQuantity }
func (r *productPlan) TiersMode() string { return r.tiersMode }
func (r *productPlan) PlanTiers() []graphqlbackend.PlanTier {
if r.planTiers == nil {
return nil
}
return r.planTiers
}
func (r *planTier) UnitAmount() int32 { return int32(r.unitAmount) }
func (r *planTier) UpTo() int32 { return int32(r.upTo) }
// ToProductPlan returns a resolver for the GraphQL type ProductPlan from the given billing plan.
func ToProductPlan(plan *stripe.Plan) (graphqlbackend.ProductPlan, error) {
// Sanity check.
if plan.Product.Name == "" {
return nil, fmt.Errorf("unexpected empty product name for plan %q", plan.ID)
}
if plan.Currency != stripe.CurrencyUSD {
return nil, fmt.Errorf("unexpected currency %q for plan %q", plan.Currency, plan.ID)
}
if plan.Interval != stripe.PlanIntervalYear {
return nil, fmt.Errorf("unexpected plan interval %q for plan %q", plan.Interval, plan.ID)
}
if plan.IntervalCount != 1 {
return nil, fmt.Errorf("unexpected plan interval count %d for plan %q", plan.IntervalCount, plan.ID)
}
var tiers []graphqlbackend.PlanTier
for _, tier := range plan.Tiers {
tiers = append(tiers, &planTier{
unitAmount: tier.UnitAmount,
upTo: tier.UpTo,
})
}
return &productPlan{
productPlanID: plan.Product.ID,
billingPlanID: plan.ID,
name: plan.Product.Name,
pricePerUserPerYear: int32(plan.Amount),
minQuantity: ProductPlanMinQuantity(plan),
planTiers: tiers,
tiersMode: plan.TiersMode,
}, nil
}
// ProductPlanMinQuantity returns the plan's product's minQuantity metadata value, or nil if there
// is none.
func ProductPlanMinQuantity(plan *stripe.Plan) *int32 {
if v, err := strconv.Atoi(plan.Product.Metadata["minQuantity"]); err == nil {
tmp := int32(v)
return &tmp
}
return nil
}
// ProductPlans implements the GraphQL field Query.dotcom.productPlans.
func (BillingResolver) ProductPlans(ctx context.Context) ([]graphqlbackend.ProductPlan, error) {
params := &stripe.PlanListParams{
ListParams: stripe.ListParams{Context: ctx},
Active: stripe.Bool(true),
}
params.AddExpand("data.product")
plans := plan.List(params)
var gqlPlans []graphqlbackend.ProductPlan
for plans.Next() {
gqlPlan, err := ToProductPlan(plans.Plan())
if err != nil {
return nil, err
}
gqlPlans = append(gqlPlans, gqlPlan)
}
if err := plans.Err(); err != nil {
return nil, err
}
// Sort cheapest first (a reasonable assumption).
sort.Slice(gqlPlans, func(i, j int) bool {
return gqlPlans[i].PricePerUserPerYear() < gqlPlans[j].PricePerUserPerYear()
})
return gqlPlans, nil
}

View File

@ -0,0 +1,55 @@
package billing
import (
"log"
"strings"
"github.com/sourcegraph/sourcegraph/cmd/frontend/external/app"
"github.com/sourcegraph/sourcegraph/pkg/env"
stripe "github.com/stripe/stripe-go"
)
var (
stripeSecretKey = env.Get("STRIPE_SECRET_KEY", "", "billing: Stripe API secret key")
stripePublishableKey = env.Get("STRIPE_PUBLISHABLE_KEY", "", "billing: Stripe API publishable key")
stripeWebhookSecret = env.Get("STRIPE_WEBHOOK_SECRET", "", "billing: Stripe webhook secret")
)
func init() {
// Sanity-check the Stripe keys (to help prevent mistakes where they got switched and the secret
// key is published).
if stripeSecretKey != "" && !strings.HasPrefix(stripeSecretKey, "sk_") {
log.Fatal(`Invalid STRIPE_SECRET_KEY (must begin with "sk_").`)
}
if stripePublishableKey != "" && !strings.HasPrefix(stripePublishableKey, "pk_") {
log.Fatal(`Invalid STRIPE_PUBLISHABLE_KEY (must begin with "pk_").`)
}
if (stripeSecretKey != "") != (stripePublishableKey != "") {
log.Fatalf("Either zero or both of STRIPE_SECRET_KEY (set=%v) and STRIPE_PUBLISHABLE_KEY (set=%v) must be set.", stripeSecretKey != "", stripePublishableKey != "")
}
stripe.Key = stripeSecretKey
app.SetBillingPublishableKey(stripePublishableKey)
}
func isTest() bool {
return strings.Contains(stripe.Key, "_test_")
}
func baseURL() string {
u := "https://dashboard.stripe.com"
if isTest() {
u += "/test"
}
return u
}
// CustomerURL returns the URL to the customer with the given ID on the billing system.
func CustomerURL(id string) string {
return baseURL() + "/customers/" + id
}
// SubscriptionURL returns the URL to the subscription with the given ID on the billing system.
func SubscriptionURL(id string) string {
return baseURL() + "/subscriptions/" + id
}

View File

@ -0,0 +1,31 @@
package billing
import (
"errors"
"fmt"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
stripe "github.com/stripe/stripe-go"
)
// ToSubscriptionItemsParams converts a value of GraphQL type ProductSubscriptionInput into a
// subscription item parameter for the billing system.
func ToSubscriptionItemsParams(input graphqlbackend.ProductSubscriptionInput) *stripe.SubscriptionItemsParams {
return &stripe.SubscriptionItemsParams{
Plan: stripe.String(input.BillingPlanID),
Quantity: stripe.Int64(int64(input.UserCount)),
}
}
// GetSubscriptionItemIDToReplace returns the ID of the billing subscription item (used when
// updating the subscription or previewing an invoice to do so). It also performs a good set of
// sanity checks on the subscription that should be performed whenever the subscription is updated.
func GetSubscriptionItemIDToReplace(billingSub *stripe.Subscription, billingCustomerID string) (string, error) {
if billingSub.Customer.ID != billingCustomerID {
return "", errors.New("product subscription's billing customer does not match the provided account parameter")
}
if len(billingSub.Items.Data) != 1 {
return "", fmt.Errorf("product subscription has unexpected number of invoice items (got %d, want 1)", len(billingSub.Items.Data))
}
return billingSub.Items.Data[0].ID, nil
}

View File

@ -0,0 +1,56 @@
package billing
import (
"context"
"encoding/json"
"io/ioutil"
"net/http"
stripe "github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/webhook"
log15 "gopkg.in/inconshreveable/log15.v2"
)
// handleWebhook handles HTTP requests containing webhook payloads about billing-related events from
// the billing system.
func handleWebhook(w http.ResponseWriter, r *http.Request) {
body, err := ioutil.ReadAll(r.Body)
if err != nil {
http.Error(w, err.Error(), http.StatusServiceUnavailable)
return
}
// Check the signature to verify the HTTP request came from the billing system.
event, err := webhook.ConstructEvent(body, r.Header.Get("Stripe-Signature"), stripeWebhookSecret)
if err != nil {
// Parse out some of the event for logging.
var event struct {
ID string `json:"id"`
Type string `json:"type"`
}
_ = json.Unmarshal(body, &event)
log15.Error("Billing webhook received request with invalid signature.", "idUnverified", event.ID, "typeUnverified", event.Type, "err", err)
http.Error(w, "billing event signature is invalid", http.StatusBadRequest)
return
}
log15.Info("Billing webhook received event.", "id", event.ID, "type", event.Type)
if err := handleEvent(r.Context(), event); err != nil {
log15.Error("Billing webhook failed to handle event.", "id", event.ID, "type", event.Type, "err", err)
http.Error(w, "billing event handler error", http.StatusInternalServerError)
return
}
w.WriteHeader(http.StatusOK)
}
// handleEvent handles a billing event (received via webhook).
//
// TODO(sqs): implement this so we can create invoices instead of only being able to accept
// immediate payment.
func handleEvent(ctx context.Context, event stripe.Event) error {
switch event.Type {
case "invoice.payment_succeeded":
// noop
}
return nil
}

View File

@ -0,0 +1,7 @@
package productsubscription
import dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
func init() {
dbtesting.DBNameSuffix = "productsubscription"
}

View File

@ -0,0 +1,2 @@
// Package productsubscription handles product subscriptions and licensing.
package productsubscription

View File

@ -0,0 +1,5 @@
package productsubscription
// ProductSubscriptionLicensingResolver implements the GraphQL Query and Mutation fields related to product
// subscriptions and licensing.
type ProductSubscriptionLicensingResolver struct{}

View File

@ -0,0 +1,190 @@
package productsubscription
import (
"context"
"time"
"github.com/pkg/errors"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/dotcom/billing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
stripe "github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/invoice"
"github.com/stripe/stripe-go/plan"
"github.com/stripe/stripe-go/sub"
)
type productSubscriptionPreviewInvoice struct {
price int32
amountDue int32
prorationDate *int64
before, after *productSubscriptionInvoiceItem
}
func (r *productSubscriptionPreviewInvoice) Price() int32 { return r.price }
func (r *productSubscriptionPreviewInvoice) AmountDue() int32 { return r.amountDue }
func (r *productSubscriptionPreviewInvoice) ProrationDate() *string {
if v := r.prorationDate; v != nil {
s := time.Unix(*v, 0).Format(time.RFC3339)
return &s
}
return nil
}
func (r *productSubscriptionPreviewInvoice) IsDowngradeRequiringManualIntervention() bool {
return r.before != nil && isDowngradeRequiringManualIntervention(r.before.userCount, r.before.plan.Amount, r.after.userCount, r.after.plan.Amount)
}
func isDowngradeRequiringManualIntervention(beforeUserCount int32, beforePlanPrice int64, afterUserCount int32, afterPlanPrice int64) bool {
return afterUserCount < beforeUserCount || afterPlanPrice < beforePlanPrice
}
func (r *productSubscriptionPreviewInvoice) BeforeInvoiceItem() graphqlbackend.ProductSubscriptionInvoiceItem {
if r.before == nil {
return nil // untyped nil is necessary for graphql-go
}
return r.before
}
func (r *productSubscriptionPreviewInvoice) AfterInvoiceItem() graphqlbackend.ProductSubscriptionInvoiceItem {
return r.after
}
func (ProductSubscriptionLicensingResolver) PreviewProductSubscriptionInvoice(ctx context.Context, args *graphqlbackend.PreviewProductSubscriptionInvoiceArgs) (graphqlbackend.ProductSubscriptionPreviewInvoice, error) {
// Support previewing an invoice with or without a customer ID.
var custID string
var accountUserID *int32
if args.Account != nil {
// There is a customer ID given.
accountUser, err := graphqlbackend.UserByID(ctx, *args.Account)
if err != nil {
return nil, err
}
tmp := accountUser.SourcegraphID()
accountUserID = &tmp
custID, err = billing.GetOrAssignUserCustomerID(ctx, *accountUserID)
if err != nil {
return nil, err
}
// 🚨 SECURITY: Users may only preview invoices for their own product subscriptions. Site admins
// may preview invoices for all product subscriptions.
if err := backend.CheckSiteAdminOrSameUser(ctx, *accountUserID); err != nil {
return nil, err
}
} else {
// Support previewing an invoice without a customer ID, for unauthenticated viewers who just want
// to see the price.
if args.SubscriptionToUpdate != nil {
return nil, errors.New("missing account ID argument (must be the owner of the subscriptionToUpdate)")
}
var err error
custID, err = billing.GetDummyCustomerID(ctx)
if err != nil {
return nil, err
}
}
// Get the "before" subscription invoice item.
planParams := &stripe.PlanParams{Params: stripe.Params{Context: ctx}}
planParams.AddExpand("product")
plan, err := plan.Get(args.ProductSubscription.BillingPlanID, planParams)
if err != nil {
return nil, err
}
if minQuantity := billing.ProductPlanMinQuantity(plan); minQuantity != nil && args.ProductSubscription.UserCount < *minQuantity {
args.ProductSubscription.UserCount = *minQuantity
}
result := productSubscriptionPreviewInvoice{
after: &productSubscriptionInvoiceItem{
plan: plan,
userCount: args.ProductSubscription.UserCount,
// The expiresAt field will be set below, not here, because its value depends on whether
// this is a new vs. updated subscription.
},
}
params := &stripe.InvoiceParams{
Params: stripe.Params{Context: ctx},
Customer: stripe.String(custID),
SubscriptionItems: []*stripe.SubscriptionItemsParams{billing.ToSubscriptionItemsParams(args.ProductSubscription)},
}
if args.SubscriptionToUpdate != nil {
// Update a subscription.
//
// When updating an existing subscription, craft the params to replace the existing subscription
// item (otherwise the invoice would include both the existing and updated subscription items).
subToUpdate, err := productSubscriptionByID(ctx, *args.SubscriptionToUpdate)
if err != nil {
return nil, err
}
// 🚨 SECURITY: Only site admins and the subscription's account owner may preview invoices
// for product subscriptions.
if err := backend.CheckSiteAdminOrSameUser(ctx, subToUpdate.v.UserID); err != nil {
return nil, err
}
// 🚨 SECURITY: Ensure that the subscription is owned by the account (i.e., that the
// parameters are internally consistent). These checks are redundant for site admins, but
// it's good to be robust against bugs.
if subToUpdate.v.UserID != *accountUserID {
return nil, errors.New("product subscription's account owner does not match the provided account parameter")
}
if subToUpdate.v.BillingSubscriptionID == nil {
return nil, errors.New("unable to get preview invoice for product subscription that has no associated billing information")
}
subParams := &stripe.SubscriptionParams{Params: stripe.Params{Context: ctx}}
subParams.AddExpand("plan.product")
billingSubToUpdate, err := sub.Get(*subToUpdate.v.BillingSubscriptionID, subParams)
if err != nil {
return nil, err
}
params.SubscriptionProrationDate = stripe.Int64(time.Now().Unix())
params.Subscription = stripe.String(*subToUpdate.v.BillingSubscriptionID)
params.SubscriptionProrate = stripe.Bool(true)
idToReplace, err := billing.GetSubscriptionItemIDToReplace(billingSubToUpdate, custID)
if err != nil {
return nil, err
}
params.SubscriptionItems[0].ID = stripe.String(idToReplace)
result.prorationDate = params.SubscriptionProrationDate
result.before = &productSubscriptionInvoiceItem{
plan: billingSubToUpdate.Plan,
userCount: int32(billingSubToUpdate.Quantity),
expiresAt: time.Unix(billingSubToUpdate.CurrentPeriodEnd, 0),
}
}
// Get the preview invoice.
invoice, err := invoice.GetNext(params)
if err != nil {
return nil, err
}
// Calculate the price and expiration.
for _, invoiceItem := range invoice.Lines.Data {
// When updating an existing subscription, only include invoice items that are affected by
// the update (== whose proration date is the same as the one we set on the update params).
if result.prorationDate != nil && invoiceItem.Period.Start != *result.prorationDate {
continue
}
result.price += int32(invoiceItem.Amount)
// Set the period end to the farthest ahead future invoice item's end date.
periodEnd := time.Unix(invoiceItem.Period.End, 0)
if periodEnd.After(result.after.expiresAt) {
result.after.expiresAt = periodEnd
}
}
// When there is no change (no new invoice lines), set the "after" state to expire at the same
// as the before state to indicate there is no change in the expiration either.
if result.after.expiresAt.IsZero() && result.before != nil {
result.after.expiresAt = result.before.expiresAt
}
return &result, nil
}

View File

@ -0,0 +1,145 @@
package productsubscription
import (
"context"
"time"
"github.com/google/uuid"
"github.com/keegancsmith/sqlf"
"github.com/pkg/errors"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
)
// dbLicense describes an product license row in the product_licenses DB table.
type dbLicense struct {
ID string // UUID
ProductSubscriptionID string // UUID
LicenseKey string
CreatedAt time.Time
}
// errLicenseNotFound occurs when a database operation expects a specific Sourcegraph
// license to exist but it does not exist.
var errLicenseNotFound = errors.New("product license not found")
// dbLicenses exposes product licenses in the product_licenses DB table.
type dbLicenses struct{}
// Create creates a new product license entry given a license key.
func (dbLicenses) Create(ctx context.Context, subscriptionID, licenseKey string) (id string, err error) {
if mocks.licenses.Create != nil {
return mocks.licenses.Create(subscriptionID, licenseKey)
}
uuid, err := uuid.NewRandom()
if err != nil {
return "", err
}
if err := dbconn.Global.QueryRowContext(ctx, `
INSERT INTO product_licenses(id, product_subscription_id, license_key) VALUES($1, $2, $3) RETURNING id
`,
uuid, subscriptionID, licenseKey,
).Scan(&id); err != nil {
return "", err
}
return id, nil
}
// GetByID retrieves the product license (if any) given its ID.
//
// 🚨 SECURITY: The caller must ensure that the actor is permitted to view this product license.
func (s dbLicenses) GetByID(ctx context.Context, id string) (*dbLicense, error) {
if mocks.licenses.GetByID != nil {
return mocks.licenses.GetByID(id)
}
results, err := s.list(ctx, []*sqlf.Query{sqlf.Sprintf("id=%s", id)}, nil)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, errLicenseNotFound
}
return results[0], nil
}
// GetByID retrieves the product license (if any) given its license key.
func (s dbLicenses) GetByLicenseKey(ctx context.Context, licenseKey string) (*dbLicense, error) {
if mocks.licenses.GetByLicenseKey != nil {
return mocks.licenses.GetByLicenseKey(licenseKey)
}
results, err := s.list(ctx, []*sqlf.Query{sqlf.Sprintf("license_key=%s", licenseKey)}, nil)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, errLicenseNotFound
}
return results[0], nil
}
// dbLicensesListOptions contains options for listing product licenses.
type dbLicensesListOptions struct {
LicenseKeySubstring string
ProductSubscriptionID string // only list product licenses for this subscription (by UUID)
*db.LimitOffset
}
func (o dbLicensesListOptions) sqlConditions() []*sqlf.Query {
conds := []*sqlf.Query{sqlf.Sprintf("TRUE")}
if o.LicenseKeySubstring != "" {
conds = append(conds, sqlf.Sprintf("license_key LIKE %s", "%"+o.LicenseKeySubstring+"%"))
}
if o.ProductSubscriptionID != "" {
conds = append(conds, sqlf.Sprintf("product_subscription_id=%s", o.ProductSubscriptionID))
}
return conds
}
// List lists all product licenses that satisfy the options.
func (s dbLicenses) List(ctx context.Context, opt dbLicensesListOptions) ([]*dbLicense, error) {
return s.list(ctx, opt.sqlConditions(), opt.LimitOffset)
}
func (dbLicenses) list(ctx context.Context, conds []*sqlf.Query, limitOffset *db.LimitOffset) ([]*dbLicense, error) {
q := sqlf.Sprintf(`
SELECT id, product_subscription_id, license_key, created_at FROM product_licenses
WHERE (%s)
ORDER BY created_at DESC
%s`,
sqlf.Join(conds, ") AND ("),
limitOffset.SQL(),
)
rows, err := dbconn.Global.QueryContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...)
if err != nil {
return nil, err
}
defer rows.Close()
var results []*dbLicense
for rows.Next() {
var v dbLicense
if err := rows.Scan(&v.ID, &v.ProductSubscriptionID, &v.LicenseKey, &v.CreatedAt); err != nil {
return nil, err
}
results = append(results, &v)
}
return results, nil
}
// Count counts all product licenses that satisfy the options (ignoring limit and offset).
func (dbLicenses) Count(ctx context.Context, opt dbLicensesListOptions) (int, error) {
q := sqlf.Sprintf("SELECT COUNT(*) FROM product_licenses WHERE (%s)", sqlf.Join(opt.sqlConditions(), ") AND ("))
var count int
if err := dbconn.Global.QueryRowContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...).Scan(&count); err != nil {
return 0, err
}
return count, nil
}
type mockLicenses struct {
Create func(subscriptionID, licenseKey string) (id string, err error)
GetByID func(id string) (*dbLicense, error)
GetByLicenseKey func(licenseKey string) (*dbLicense, error)
}

View File

@ -0,0 +1,127 @@
package productsubscription
import (
"testing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
)
func TestProductLicenses_Create(t *testing.T) {
ctx := dbtesting.TestContext(t)
u, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
if err != nil {
t.Fatal(err)
}
ps0, err := dbSubscriptions{}.Create(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
pl0, err := dbLicenses{}.Create(ctx, ps0, "k")
if err != nil {
t.Fatal(err)
}
got, err := dbLicenses{}.GetByID(ctx, pl0)
if err != nil {
t.Fatal(err)
}
if want := pl0; got.ID != want {
t.Errorf("got %v, want %v", got.ID, want)
}
if want := ps0; got.ProductSubscriptionID != want {
t.Errorf("got %v, want %v", got.ProductSubscriptionID, want)
}
if want := "k"; got.LicenseKey != want {
t.Errorf("got %q, want %q", got.LicenseKey, want)
}
ts, err := dbLicenses{}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: ps0})
if err != nil {
t.Fatal(err)
}
if want := 1; len(ts) != want {
t.Errorf("got %d product licenses, want %d", len(ts), want)
}
ts, err = dbLicenses{}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: "69da12d5-323c-4e42-9d44-cc7951639bca" /* invalid */})
if err != nil {
t.Fatal(err)
}
if want := 0; len(ts) != want {
t.Errorf("got %d product licenses, want %d", len(ts), want)
}
}
func TestProductLicenses_List(t *testing.T) {
if testing.Short() {
t.Skip()
}
ctx := dbtesting.TestContext(t)
u1, err := db.Users.Create(ctx, db.NewUser{Username: "u1"})
if err != nil {
t.Fatal(err)
}
ps0, err := dbSubscriptions{}.Create(ctx, u1.ID)
if err != nil {
t.Fatal(err)
}
ps1, err := dbSubscriptions{}.Create(ctx, u1.ID)
if err != nil {
t.Fatal(err)
}
_, err = dbLicenses{}.Create(ctx, ps0, "k")
if err != nil {
t.Fatal(err)
}
_, err = dbLicenses{}.Create(ctx, ps0, "n1")
if err != nil {
t.Fatal(err)
}
{
// List all product licenses.
ts, err := dbLicenses{}.List(ctx, dbLicensesListOptions{})
if err != nil {
t.Fatal(err)
}
if want := 2; len(ts) != want {
t.Errorf("got %d product licenses, want %d", len(ts), want)
}
count, err := dbLicenses{}.Count(ctx, dbLicensesListOptions{})
if err != nil {
t.Fatal(err)
}
if want := 2; count != want {
t.Errorf("got %d, want %d", count, want)
}
}
{
// List ps0's product licenses.
ts, err := dbLicenses{}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: ps0})
if err != nil {
t.Fatal(err)
}
if want := 2; len(ts) != want {
t.Errorf("got %d product licenses, want %d", len(ts), want)
}
}
{
// List ps1's product licenses.
ts, err := dbLicenses{}.List(ctx, dbLicensesListOptions{ProductSubscriptionID: ps1})
if err != nil {
t.Fatal(err)
}
if want := 0; len(ts) != want {
t.Errorf("got %d product licenses, want %d", len(ts), want)
}
}
}

View File

@ -0,0 +1,201 @@
package productsubscription
import (
"context"
"sync"
"time"
graphql "github.com/graph-gophers/graphql-go"
"github.com/graph-gophers/graphql-go/relay"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/enterprise/pkg/license"
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend/graphqlutil"
)
func init() {
graphqlbackend.ProductLicenseByID = func(ctx context.Context, id graphql.ID) (graphqlbackend.ProductLicense, error) {
return productLicenseByID(ctx, id)
}
}
// productLicense implements the GraphQL type ProductLicense.
type productLicense struct {
v *dbLicense
}
// productLicenseByID looks up and returns the ProductLicense with the given GraphQL ID. If no such
// ProductLicense exists, it returns a non-nil error.
func productLicenseByID(ctx context.Context, id graphql.ID) (*productLicense, error) {
idInt32, err := unmarshalProductLicenseID(id)
if err != nil {
return nil, err
}
return productLicenseByDBID(ctx, idInt32)
}
// productLicenseByDBID looks up and returns the ProductLicense with the given database ID. If no
// such ProductLicense exists, it returns a non-nil error.
func productLicenseByDBID(ctx context.Context, id string) (*productLicense, error) {
v, err := dbLicenses{}.GetByID(ctx, id)
if err != nil {
return nil, err
}
// 🚨 SECURITY: Only site admins and the license's subscription's account's user may view a
// product license.
sub, err := productSubscriptionByDBID(ctx, v.ProductSubscriptionID)
if err != nil {
return nil, err
}
if err := backend.CheckSiteAdminOrSameUser(ctx, sub.v.UserID); err != nil {
return nil, err
}
return &productLicense{v: v}, nil
}
func (r *productLicense) ID() graphql.ID {
return marshalProductLicenseID(r.v.ID)
}
func marshalProductLicenseID(id string) graphql.ID {
return relay.MarshalID("ProductLicense", id)
}
func unmarshalProductLicenseID(id graphql.ID) (productLicenseID string, err error) {
err = relay.UnmarshalSpec(id, &productLicenseID)
return
}
func (r *productLicense) Subscription(ctx context.Context) (graphqlbackend.ProductSubscription, error) {
return productSubscriptionByDBID(ctx, r.v.ProductSubscriptionID)
}
func (r *productLicense) Info() (*graphqlbackend.ProductLicenseInfo, error) {
// Call this instead of licensing.ParseProductLicenseKey so that license info can be read from
// license keys generated using the test license generation private key.
info, err := licensing.ParseProductLicenseKeyWithBuiltinOrGenerationKey(r.v.LicenseKey)
if err != nil {
return nil, err
}
return &graphqlbackend.ProductLicenseInfo{
TagsValue: info.Tags,
UserCountValue: info.UserCount,
ExpiresAtValue: info.ExpiresAt,
}, nil
}
func (r *productLicense) LicenseKey() string { return r.v.LicenseKey }
func (r *productLicense) CreatedAt() string {
return r.v.CreatedAt.Format(time.RFC3339)
}
func generateProductLicenseForSubscription(ctx context.Context, subscriptionID string, input *graphqlbackend.ProductLicenseInput) (id string, err error) {
licenseKey, err := licensing.GenerateProductLicenseKey(license.Info{
Tags: input.Tags,
UserCount: uint(input.UserCount),
ExpiresAt: time.Unix(int64(input.ExpiresAt), 0),
})
if err != nil {
return "", err
}
return dbLicenses{}.Create(ctx, subscriptionID, licenseKey)
}
func (ProductSubscriptionLicensingResolver) GenerateProductLicenseForSubscription(ctx context.Context, args *graphqlbackend.GenerateProductLicenseForSubscriptionArgs) (graphqlbackend.ProductLicense, error) {
// 🚨 SECURITY: Only site admins may generate product licenses.
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil, err
}
sub, err := productSubscriptionByID(ctx, args.ProductSubscriptionID)
if err != nil {
return nil, err
}
id, err := generateProductLicenseForSubscription(ctx, sub.v.ID, args.License)
if err != nil {
return nil, err
}
return productLicenseByDBID(ctx, id)
}
func (ProductSubscriptionLicensingResolver) ProductLicenses(ctx context.Context, args *graphqlbackend.ProductLicensesArgs) (graphqlbackend.ProductLicenseConnection, error) {
// 🚨 SECURITY: Only site admins may list product licenses.
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil, err
}
var sub *productSubscription
if args.ProductSubscriptionID != nil {
var err error
sub, err = productSubscriptionByID(ctx, *args.ProductSubscriptionID)
if err != nil {
return nil, err
}
}
var opt dbLicensesListOptions
if sub != nil {
opt.ProductSubscriptionID = sub.v.ID
}
if args.LicenseKeySubstring != nil {
opt.LicenseKeySubstring = *args.LicenseKeySubstring
}
args.ConnectionArgs.Set(&opt.LimitOffset)
return &productLicenseConnection{opt: opt}, nil
}
// productLicenseConnection implements the GraphQL type ProductLicenseConnection.
//
// 🚨 SECURITY: When instantiating a productLicenseConnection value, the caller MUST
// check permissions.
type productLicenseConnection struct {
opt dbLicensesListOptions
// cache results because they are used by multiple fields
once sync.Once
results []*dbLicense
err error
}
func (r *productLicenseConnection) compute(ctx context.Context) ([]*dbLicense, error) {
r.once.Do(func() {
opt2 := r.opt
if opt2.LimitOffset != nil {
tmp := *opt2.LimitOffset
opt2.LimitOffset = &tmp
opt2.Limit++ // so we can detect if there is a next page
}
r.results, r.err = dbLicenses{}.List(ctx, opt2)
})
return r.results, r.err
}
func (r *productLicenseConnection) Nodes(ctx context.Context) ([]graphqlbackend.ProductLicense, error) {
results, err := r.compute(ctx)
if err != nil {
return nil, err
}
var l []graphqlbackend.ProductLicense
for _, result := range results {
l = append(l, &productLicense{v: result})
}
return l, nil
}
func (r *productLicenseConnection) TotalCount(ctx context.Context) (int32, error) {
count, err := dbLicenses{}.Count(ctx, r.opt)
return int32(count), err
}
func (r *productLicenseConnection) PageInfo(ctx context.Context) (*graphqlutil.PageInfo, error) {
results, err := r.compute(ctx)
if err != nil {
return nil, err
}
return graphqlutil.HasNextPage(r.opt.LimitOffset != nil && len(results) > r.opt.Limit), nil
}

View File

@ -0,0 +1,12 @@
package productsubscription
func resetMocks() {
mocks = dbMocks{}
}
type dbMocks struct {
subscriptions mockSubscriptions
licenses mockLicenses
}
var mocks dbMocks

View File

@ -0,0 +1,49 @@
package productsubscription
import (
"context"
"time"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/dotcom/billing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
stripe "github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/sub"
)
func (r *productSubscription) InvoiceItem(ctx context.Context) (graphqlbackend.ProductSubscriptionInvoiceItem, error) {
if r.v.BillingSubscriptionID == nil {
return nil, nil
}
params := &stripe.SubscriptionParams{Params: stripe.Params{Context: ctx}}
params.AddExpand("plan.product")
billingSub, err := sub.Get(*r.v.BillingSubscriptionID, params)
if err != nil {
return nil, err
}
return &productSubscriptionInvoiceItem{
plan: billingSub.Plan,
userCount: int32(billingSub.Quantity),
expiresAt: time.Unix(billingSub.CurrentPeriodEnd, 0),
}, nil
}
type productSubscriptionInvoiceItem struct {
plan *stripe.Plan
userCount int32
expiresAt time.Time
}
var _ graphqlbackend.ProductSubscriptionInvoiceItem = &productSubscriptionInvoiceItem{}
func (r *productSubscriptionInvoiceItem) Plan() (graphqlbackend.ProductPlan, error) {
return billing.ToProductPlan(r.plan)
}
func (r *productSubscriptionInvoiceItem) UserCount() int32 {
return r.userCount
}
func (r *productSubscriptionInvoiceItem) ExpiresAt() string {
return r.expiresAt.Format(time.RFC3339)
}

View File

@ -0,0 +1,186 @@
package productsubscription
import (
"context"
"database/sql"
"time"
"github.com/google/uuid"
"github.com/keegancsmith/sqlf"
"github.com/pkg/errors"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
)
// dbSubscription describes an product subscription row in the product_subscriptions DB
// table.
type dbSubscription struct {
ID string // UUID
UserID int32
BillingSubscriptionID *string // this subscription's ID in the billing system
CreatedAt time.Time
ArchivedAt *time.Time
}
// errSubscriptionNotFound occurs when a database operation expects a specific Sourcegraph
// license to exist but it does not exist.
var errSubscriptionNotFound = errors.New("product subscription not found")
// dbSubscriptions exposes product subscriptions in the product_subscriptions DB table.
type dbSubscriptions struct{}
// Create creates a new product subscription entry given a license key.
func (dbSubscriptions) Create(ctx context.Context, userID int32) (id string, err error) {
if mocks.subscriptions.Create != nil {
return mocks.subscriptions.Create(userID)
}
uuid, err := uuid.NewRandom()
if err != nil {
return "", err
}
if err := dbconn.Global.QueryRowContext(ctx, `
INSERT INTO product_subscriptions(id, user_id) VALUES($1, $2) RETURNING id
`,
uuid, userID,
).Scan(&id); err != nil {
return "", err
}
return id, nil
}
// GetByID retrieves the product subscription (if any) given its ID.
//
// 🚨 SECURITY: The caller must ensure that the actor is permitted to view this product subscription.
func (s dbSubscriptions) GetByID(ctx context.Context, id string) (*dbSubscription, error) {
if mocks.subscriptions.GetByID != nil {
return mocks.subscriptions.GetByID(id)
}
results, err := s.list(ctx, []*sqlf.Query{sqlf.Sprintf("id=%s", id)}, nil)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, errSubscriptionNotFound
}
return results[0], nil
}
// dbSubscriptionsListOptions contains options for listing product subscriptions.
type dbSubscriptionsListOptions struct {
UserID int32 // only list product subscriptions for this user
IncludeArchived bool
*db.LimitOffset
}
func (o dbSubscriptionsListOptions) sqlConditions() []*sqlf.Query {
conds := []*sqlf.Query{sqlf.Sprintf("TRUE")}
if o.UserID != 0 {
conds = append(conds, sqlf.Sprintf("user_id=%d", o.UserID))
}
if !o.IncludeArchived {
conds = append(conds, sqlf.Sprintf("archived_at IS NULL"))
}
return conds
}
// List lists all product subscriptions that satisfy the options.
func (s dbSubscriptions) List(ctx context.Context, opt dbSubscriptionsListOptions) ([]*dbSubscription, error) {
return s.list(ctx, opt.sqlConditions(), opt.LimitOffset)
}
func (dbSubscriptions) list(ctx context.Context, conds []*sqlf.Query, limitOffset *db.LimitOffset) ([]*dbSubscription, error) {
q := sqlf.Sprintf(`
SELECT id, user_id, billing_subscription_id, created_at, archived_at FROM product_subscriptions
WHERE (%s)
ORDER BY archived_at DESC NULLS FIRST, created_at DESC
%s`,
sqlf.Join(conds, ") AND ("),
limitOffset.SQL(),
)
rows, err := dbconn.Global.QueryContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...)
if err != nil {
return nil, err
}
defer rows.Close()
var results []*dbSubscription
for rows.Next() {
var v dbSubscription
if err := rows.Scan(&v.ID, &v.UserID, &v.BillingSubscriptionID, &v.CreatedAt, &v.ArchivedAt); err != nil {
return nil, err
}
results = append(results, &v)
}
return results, nil
}
// Count counts all product subscriptions that satisfy the options (ignoring limit and offset).
func (dbSubscriptions) Count(ctx context.Context, opt dbSubscriptionsListOptions) (int, error) {
q := sqlf.Sprintf("SELECT COUNT(*) FROM product_subscriptions WHERE (%s)", sqlf.Join(opt.sqlConditions(), ") AND ("))
var count int
if err := dbconn.Global.QueryRowContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...).Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// dbSubscriptionsUpdate represents an update to a product subscription in the database. Each field
// represents an update to the corresponding database field if the Go value is non-nil. If the Go
// value is nil, the field remains unchanged in the database.
type dbSubscriptionUpdate struct {
billingSubscriptionID *sql.NullString
}
// Update updates a product subscription.
func (dbSubscriptions) Update(ctx context.Context, id string, update dbSubscriptionUpdate) error {
fieldUpdates := []*sqlf.Query{
sqlf.Sprintf("updated_at=now()"), // always update updated_at timestamp
}
if v := update.billingSubscriptionID; v != nil {
fieldUpdates = append(fieldUpdates, sqlf.Sprintf("billing_subscription_id=%s", *v))
}
query := sqlf.Sprintf("UPDATE product_subscriptions SET %s WHERE id=%s", sqlf.Join(fieldUpdates, ", "), id)
res, err := dbconn.Global.ExecContext(ctx, query.Query(sqlf.PostgresBindVar), query.Args()...)
if err != nil {
return err
}
nrows, err := res.RowsAffected()
if err != nil {
return err
}
if nrows == 0 {
return errSubscriptionNotFound
}
return nil
}
// Archive marks a product subscription as archived given its ID.
//
// 🚨 SECURITY: The caller must ensure that the actor is permitted to archive the token.
func (dbSubscriptions) Archive(ctx context.Context, id string) error {
if mocks.subscriptions.Archive != nil {
return mocks.subscriptions.Archive(id)
}
q := sqlf.Sprintf("UPDATE product_subscriptions SET archived_at=now(), updated_at=now() WHERE id=%s AND archived_at IS NULL", id)
res, err := dbconn.Global.ExecContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...)
if err != nil {
return err
}
nrows, err := res.RowsAffected()
if err != nil {
return err
}
if nrows == 0 {
return errSubscriptionNotFound
}
return nil
}
type mockSubscriptions struct {
Create func(userID int32) (id string, err error)
GetByID func(id string) (*dbSubscription, error)
Archive func(id string) error
}

View File

@ -0,0 +1,176 @@
package productsubscription
import (
"database/sql"
"testing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
)
func TestProductSubscriptions_Create(t *testing.T) {
ctx := dbtesting.TestContext(t)
u, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
if err != nil {
t.Fatal(err)
}
sub0, err := dbSubscriptions{}.Create(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
got, err := dbSubscriptions{}.GetByID(ctx, sub0)
if err != nil {
t.Fatal(err)
}
if want := sub0; got.ID != want {
t.Errorf("got %v, want %v", got.ID, want)
}
if want := u.ID; got.UserID != want {
t.Errorf("got %v, want %v", got.UserID, want)
}
if got.BillingSubscriptionID != nil {
t.Errorf("got %v, want nil", got.BillingSubscriptionID)
}
ts, err := dbSubscriptions{}.List(ctx, dbSubscriptionsListOptions{UserID: u.ID})
if err != nil {
t.Fatal(err)
}
if want := 1; len(ts) != want {
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
}
ts, err = dbSubscriptions{}.List(ctx, dbSubscriptionsListOptions{UserID: 123 /* invalid */})
if err != nil {
t.Fatal(err)
}
if want := 0; len(ts) != want {
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
}
}
func TestProductSubscriptions_List(t *testing.T) {
if testing.Short() {
t.Skip()
}
ctx := dbtesting.TestContext(t)
u1, err := db.Users.Create(ctx, db.NewUser{Username: "u1"})
if err != nil {
t.Fatal(err)
}
u2, err := db.Users.Create(ctx, db.NewUser{Username: "u2"})
if err != nil {
t.Fatal(err)
}
_, err = dbSubscriptions{}.Create(ctx, u1.ID)
if err != nil {
t.Fatal(err)
}
_, err = dbSubscriptions{}.Create(ctx, u1.ID)
if err != nil {
t.Fatal(err)
}
{
// List all product subscriptions.
ts, err := dbSubscriptions{}.List(ctx, dbSubscriptionsListOptions{})
if err != nil {
t.Fatal(err)
}
if want := 2; len(ts) != want {
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
}
count, err := dbSubscriptions{}.Count(ctx, dbSubscriptionsListOptions{})
if err != nil {
t.Fatal(err)
}
if want := 2; count != want {
t.Errorf("got %d, want %d", count, want)
}
}
{
// List u1's product subscriptions.
ts, err := dbSubscriptions{}.List(ctx, dbSubscriptionsListOptions{UserID: u1.ID})
if err != nil {
t.Fatal(err)
}
if want := 2; len(ts) != want {
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
}
}
{
// List u2's product subscriptions.
ts, err := dbSubscriptions{}.List(ctx, dbSubscriptionsListOptions{UserID: u2.ID})
if err != nil {
t.Fatal(err)
}
if want := 0; len(ts) != want {
t.Errorf("got %d product subscriptions, want %d", len(ts), want)
}
}
}
func TestProductSubscriptions_Update(t *testing.T) {
ctx := dbtesting.TestContext(t)
u, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
if err != nil {
t.Fatal(err)
}
sub0, err := dbSubscriptions{}.Create(ctx, u.ID)
if err != nil {
t.Fatal(err)
}
if got, err := (dbSubscriptions{}).GetByID(ctx, sub0); err != nil {
t.Fatal(err)
} else if got.BillingSubscriptionID != nil {
t.Errorf("got %q, want nil", *got.BillingSubscriptionID)
}
// Set non-null value.
if err := (dbSubscriptions{}).Update(ctx, sub0, dbSubscriptionUpdate{
billingSubscriptionID: &sql.NullString{
String: "x",
Valid: true,
},
}); err != nil {
t.Fatal(err)
}
if got, err := (dbSubscriptions{}).GetByID(ctx, sub0); err != nil {
t.Fatal(err)
} else if want := "x"; got.BillingSubscriptionID == nil || *got.BillingSubscriptionID != want {
t.Errorf("got %v, want %q", got.BillingSubscriptionID, want)
}
// Update no fields.
if err := (dbSubscriptions{}).Update(ctx, sub0, dbSubscriptionUpdate{billingSubscriptionID: nil}); err != nil {
t.Fatal(err)
}
if got, err := (dbSubscriptions{}).GetByID(ctx, sub0); err != nil {
t.Fatal(err)
} else if want := "x"; got.BillingSubscriptionID == nil || *got.BillingSubscriptionID != want {
t.Errorf("got %v, want %q", got.BillingSubscriptionID, want)
}
// Set null value.
if err := (dbSubscriptions{}).Update(ctx, sub0, dbSubscriptionUpdate{
billingSubscriptionID: &sql.NullString{Valid: false},
}); err != nil {
t.Fatal(err)
}
if got, err := (dbSubscriptions{}).GetByID(ctx, sub0); err != nil {
t.Fatal(err)
} else if got.BillingSubscriptionID != nil {
t.Errorf("got %q, want nil", *got.BillingSubscriptionID)
}
}
func strptr(s string) *string { return &s }

View File

@ -0,0 +1,520 @@
package productsubscription
import (
"context"
"database/sql"
"fmt"
"strings"
"sync"
"time"
graphql "github.com/graph-gophers/graphql-go"
"github.com/graph-gophers/graphql-go/relay"
"github.com/pkg/errors"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/dotcom/billing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
db_ "github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend/graphqlutil"
stripe "github.com/stripe/stripe-go"
"github.com/stripe/stripe-go/customer"
"github.com/stripe/stripe-go/event"
"github.com/stripe/stripe-go/invoice"
"github.com/stripe/stripe-go/plan"
"github.com/stripe/stripe-go/sub"
)
func init() {
graphqlbackend.ProductSubscriptionByID = func(ctx context.Context, id graphql.ID) (graphqlbackend.ProductSubscription, error) {
return productSubscriptionByID(ctx, id)
}
}
// productSubscription implements the GraphQL type ProductSubscription.
type productSubscription struct {
v *dbSubscription
}
// productSubscriptionByID looks up and returns the ProductSubscription with the given GraphQL
// ID. If no such ProductSubscription exists, it returns a non-nil error.
func productSubscriptionByID(ctx context.Context, id graphql.ID) (*productSubscription, error) {
idString, err := unmarshalProductSubscriptionID(id)
if err != nil {
return nil, err
}
return productSubscriptionByDBID(ctx, idString)
}
// productSubscriptionByDBID looks up and returns the ProductSubscription with the given database
// ID. If no such ProductSubscription exists, it returns a non-nil error.
func productSubscriptionByDBID(ctx context.Context, id string) (*productSubscription, error) {
v, err := dbSubscriptions{}.GetByID(ctx, id)
if err != nil {
return nil, err
}
// 🚨 SECURITY: Only site admins and the subscription account's user may view a product subscription.
if err := backend.CheckSiteAdminOrSameUser(ctx, v.UserID); err != nil {
return nil, err
}
return &productSubscription{v: v}, nil
}
func (r *productSubscription) ID() graphql.ID {
return marshalProductSubscriptionID(r.v.ID)
}
func marshalProductSubscriptionID(id string) graphql.ID {
return relay.MarshalID("ProductSubscription", id)
}
func unmarshalProductSubscriptionID(id graphql.ID) (productSubscriptionID string, err error) {
err = relay.UnmarshalSpec(id, &productSubscriptionID)
return
}
func (r *productSubscription) UUID() string {
return r.v.ID
}
func (r *productSubscription) Name() string {
return fmt.Sprintf("L-%s", strings.ToUpper(strings.Replace(r.v.ID, "-", "", -1)[:10]))
}
func (r *productSubscription) Account(ctx context.Context) (*graphqlbackend.UserResolver, error) {
return graphqlbackend.UserByIDInt32(ctx, r.v.UserID)
}
func (r *productSubscription) Events(ctx context.Context) ([]graphqlbackend.ProductSubscriptionEvent, error) {
if r.v.BillingSubscriptionID == nil {
return []graphqlbackend.ProductSubscriptionEvent{}, nil
}
// List all events related to this subscription. The related_object parameter is an undocumented
// Stripe API.
params := &stripe.EventListParams{
ListParams: stripe.ListParams{Context: ctx},
}
params.Filters.AddFilter("related_object", "", *r.v.BillingSubscriptionID)
events := event.List(params)
var gqlEvents []graphqlbackend.ProductSubscriptionEvent
for events.Next() {
gqlEvent, okToShowUser := billing.ToProductSubscriptionEvent(events.Event())
if okToShowUser {
gqlEvents = append(gqlEvents, gqlEvent)
}
}
if err := events.Err(); err != nil {
return nil, err
}
return gqlEvents, nil
}
func (r *productSubscription) ActiveLicense(ctx context.Context) (graphqlbackend.ProductLicense, error) {
// Return newest license.
licenses, err := dbLicenses{}.List(ctx, dbLicensesListOptions{
ProductSubscriptionID: r.v.ID,
LimitOffset: &db_.LimitOffset{Limit: 1},
})
if err != nil {
return nil, err
}
if len(licenses) == 0 {
return nil, nil
}
return &productLicense{v: licenses[0]}, nil
}
func (r *productSubscription) ProductLicenses(ctx context.Context, args *graphqlutil.ConnectionArgs) (graphqlbackend.ProductLicenseConnection, error) {
// 🚨 SECURITY: Only site admins may list historical product licenses (to reduce confusion
// around old license reuse). Other viewers should use ProductSubscription.activeLicense.
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil, err
}
opt := dbLicensesListOptions{ProductSubscriptionID: r.v.ID}
args.Set(&opt.LimitOffset)
return &productLicenseConnection{opt: opt}, nil
}
func (r *productSubscription) CreatedAt() string {
return r.v.CreatedAt.Format(time.RFC3339)
}
func (r *productSubscription) IsArchived() bool { return r.v.ArchivedAt != nil }
func (r *productSubscription) URL(ctx context.Context) (string, error) {
accountUser, err := r.Account(ctx)
if err != nil {
return "", err
}
return accountUser.URL() + "/subscriptions/" + string(r.v.ID), nil
}
func (r *productSubscription) URLForSiteAdmin(ctx context.Context) *string {
// 🚨 SECURITY: Only site admins may see this URL. Currently it does not contain any sensitive
// info, but there is no need to show it to non-site admins.
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil
}
u := fmt.Sprintf("/site-admin/dotcom/product/subscriptions/%s", r.v.ID)
return &u
}
func (r *productSubscription) URLForSiteAdminBilling(ctx context.Context) (*string, error) {
// 🚨 SECURITY: Only site admins may see this URL, which might contain the subscription's billing ID.
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil, err
}
if id := r.v.BillingSubscriptionID; id != nil {
u := billing.SubscriptionURL(*id)
return &u, nil
}
return nil, nil
}
func (ProductSubscriptionLicensingResolver) CreateProductSubscription(ctx context.Context, args *graphqlbackend.CreateProductSubscriptionArgs) (graphqlbackend.ProductSubscription, error) {
// 🚨 SECURITY: Only site admins may create product subscriptions.
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil, err
}
user, err := graphqlbackend.UserByID(ctx, args.AccountID)
if err != nil {
return nil, err
}
id, err := dbSubscriptions{}.Create(ctx, user.SourcegraphID())
if err != nil {
return nil, err
}
return productSubscriptionByDBID(ctx, id)
}
func (ProductSubscriptionLicensingResolver) SetProductSubscriptionBilling(ctx context.Context, args *graphqlbackend.SetProductSubscriptionBillingArgs) (*graphqlbackend.EmptyResponse, error) {
// 🚨 SECURITY: Only site admins may update product subscriptions.
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil, err
}
// Ensure the args refer to valid subscriptions in the database and in the billing system.
dbSub, err := productSubscriptionByID(ctx, args.ID)
if err != nil {
return nil, err
}
if args.BillingSubscriptionID != nil {
if _, err := sub.Get(*args.BillingSubscriptionID, &stripe.SubscriptionParams{Params: stripe.Params{Context: ctx}}); err != nil {
return nil, err
}
}
stringValue := func(s *string) string {
if s == nil {
return ""
}
return *s
}
if err := (dbSubscriptions{}).Update(ctx, dbSub.v.ID, dbSubscriptionUpdate{
billingSubscriptionID: &sql.NullString{
String: stringValue(args.BillingSubscriptionID),
Valid: args.BillingSubscriptionID != nil,
},
}); err != nil {
return nil, err
}
return &graphqlbackend.EmptyResponse{}, nil
}
func (ProductSubscriptionLicensingResolver) CreatePaidProductSubscription(ctx context.Context, args *graphqlbackend.CreatePaidProductSubscriptionArgs) (*graphqlbackend.CreatePaidProductSubscriptionResult, error) {
user, err := graphqlbackend.UserByID(ctx, args.AccountID)
if err != nil {
return nil, err
}
// 🚨 SECURITY: Users may only create paid product subscriptions for themselves. Site admins may
// create them for any user.
if err := backend.CheckSiteAdminOrSameUser(ctx, user.SourcegraphID()); err != nil {
return nil, err
}
// Determine which license tags and min quantity to use for the purchased plan. Do this early on
// because it's the most likely place for a stupid mistake to cause a bug, and doing it early
// means the user hasn't been charged if there is an error.
licenseTags, minQuantity, err := billing.InfoForProductPlan(ctx, args.ProductSubscription.BillingPlanID)
if err != nil {
return nil, err
}
if minQuantity != nil && args.ProductSubscription.UserCount < *minQuantity {
args.ProductSubscription.UserCount = *minQuantity
}
// Create the subscription in our database first, before processing payment. If payment fails,
// users can retry payment on the already created subscription.
subID, err := dbSubscriptions{}.Create(ctx, user.SourcegraphID())
if err != nil {
return nil, err
}
// Get the billing customer for the current user, and update it to use the payment source
// provided to us.
custID, err := billing.GetOrAssignUserCustomerID(ctx, user.SourcegraphID())
if err != nil {
return nil, err
}
custUpdateParams := &stripe.CustomerParams{
Params: stripe.Params{Context: ctx},
}
custUpdateParams.SetSource(args.PaymentToken)
if _, err := customer.Update(custID, custUpdateParams); err != nil {
return nil, err
}
// Create the billing subscription.
billingSub, err := sub.New(&stripe.SubscriptionParams{
Params: stripe.Params{Context: ctx},
Customer: stripe.String(custID),
Items: []*stripe.SubscriptionItemsParams{billing.ToSubscriptionItemsParams(args.ProductSubscription)},
})
if err != nil {
return nil, err
}
// Link the billing subscription with the subscription in our database.
if err := (dbSubscriptions{}).Update(ctx, subID, dbSubscriptionUpdate{
billingSubscriptionID: &sql.NullString{
String: billingSub.ID,
Valid: true,
},
}); err != nil {
return nil, err
}
// Generate a new license key for the subscription.
if _, err := generateProductLicenseForSubscription(ctx, subID, &graphqlbackend.ProductLicenseInput{
Tags: licenseTags,
UserCount: args.ProductSubscription.UserCount,
ExpiresAt: int32(billingSub.CurrentPeriodEnd),
}); err != nil {
return nil, err
}
sub, err := productSubscriptionByDBID(ctx, subID)
if err != nil {
return nil, err
}
return &graphqlbackend.CreatePaidProductSubscriptionResult{ProductSubscriptionValue: sub}, nil
}
func (ProductSubscriptionLicensingResolver) UpdatePaidProductSubscription(ctx context.Context, args *graphqlbackend.UpdatePaidProductSubscriptionArgs) (*graphqlbackend.UpdatePaidProductSubscriptionResult, error) {
subToUpdate, err := productSubscriptionByID(ctx, args.SubscriptionID)
if err != nil {
return nil, err
}
// 🚨 SECURITY: Only site admins and the subscription's account owner may update product
// subscriptions.
if err := backend.CheckSiteAdminOrSameUser(ctx, subToUpdate.v.UserID); err != nil {
return nil, err
}
// Determine which license tags and min quantity to use for the purchased plan. Do this early on
// because it's the most likely place for a stupid mistake to cause a bug, and doing it early
// means the user hasn't been charged if there is an error.
licenseTags, minQuantity, err := billing.InfoForProductPlan(ctx, args.Update.BillingPlanID)
if err != nil {
return nil, err
}
if minQuantity != nil && args.Update.UserCount < *minQuantity {
args.Update.UserCount = *minQuantity
}
params := &stripe.SubscriptionParams{
Params: stripe.Params{Context: ctx},
Items: []*stripe.SubscriptionItemsParams{billing.ToSubscriptionItemsParams(args.Update)},
Prorate: stripe.Bool(true),
}
// Get the billing customer for the current user, and update it to use the payment source
// provided to us.
custID, err := billing.GetOrAssignUserCustomerID(ctx, subToUpdate.v.UserID)
if err != nil {
return nil, err
}
custUpdateParams := &stripe.CustomerParams{
Params: stripe.Params{Context: ctx},
}
custUpdateParams.SetSource(args.PaymentToken)
if _, err := customer.Update(custID, custUpdateParams); err != nil {
return nil, err
}
if subToUpdate.v.BillingSubscriptionID == nil {
return nil, errors.New("unable to update product subscription that has no associated billing information")
}
subParams := &stripe.SubscriptionParams{Params: stripe.Params{Context: ctx}}
subParams.AddExpand("plan")
billingSubToUpdate, err := sub.Get(*subToUpdate.v.BillingSubscriptionID, subParams)
if err != nil {
return nil, err
}
idToReplace, err := billing.GetSubscriptionItemIDToReplace(billingSubToUpdate, custID)
if err != nil {
return nil, err
}
params.Items[0].ID = stripe.String(idToReplace)
// Forbid self-service downgrades. (Reason: We can't revoke licenses, so we want to manually
// intervene to ensure that customers who downgrade are not using the previous license.)
{
planParams := &stripe.PlanParams{Params: stripe.Params{Context: ctx}}
afterPlan, err := plan.Get(args.Update.BillingPlanID, planParams)
if err != nil {
return nil, err
}
if isDowngradeRequiringManualIntervention(int32(billingSubToUpdate.Quantity), billingSubToUpdate.Plan.Amount, args.Update.UserCount, afterPlan.Amount) {
return nil, errors.New("self-service downgrades are not yet supported")
}
}
// Update the billing subscription.
billingSub, err := sub.Update(*subToUpdate.v.BillingSubscriptionID, params)
if err != nil {
return nil, err
}
// Generate an invoice and charge so that payment is performed immediately. See
// https://stripe.com/docs/billing/subscriptions/upgrading-downgrading.
//
// TODO(sqs): use webhooks to ensure the subscription is rolled back if the invoice payment
// fails.
{
inv, err := invoice.New(&stripe.InvoiceParams{
Params: stripe.Params{Context: ctx},
Customer: stripe.String(custID),
Subscription: stripe.String(*subToUpdate.v.BillingSubscriptionID),
})
if err != nil {
return nil, err
}
if _, err := invoice.Pay(inv.ID, &stripe.InvoicePayParams{
Params: stripe.Params{Context: ctx},
}); err != nil {
return nil, err
}
}
// Generate a new license key for the subscription with the updated parameters.
if _, err := generateProductLicenseForSubscription(ctx, subToUpdate.v.ID, &graphqlbackend.ProductLicenseInput{
Tags: licenseTags,
UserCount: args.Update.UserCount,
ExpiresAt: int32(billingSub.CurrentPeriodEnd),
}); err != nil {
return nil, err
}
return &graphqlbackend.UpdatePaidProductSubscriptionResult{ProductSubscriptionValue: subToUpdate}, nil
}
func (ProductSubscriptionLicensingResolver) ArchiveProductSubscription(ctx context.Context, args *graphqlbackend.ArchiveProductSubscriptionArgs) (*graphqlbackend.EmptyResponse, error) {
// 🚨 SECURITY: Only site admins may archive product subscriptions.
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil, err
}
sub, err := productSubscriptionByID(ctx, args.ID)
if err != nil {
return nil, err
}
if err := (dbSubscriptions{}).Archive(ctx, sub.v.ID); err != nil {
return nil, err
}
return &graphqlbackend.EmptyResponse{}, nil
}
func (ProductSubscriptionLicensingResolver) ProductSubscription(ctx context.Context, args *graphqlbackend.ProductSubscriptionArgs) (graphqlbackend.ProductSubscription, error) {
// 🚨 SECURITY: Only site admins and the subscription's account owner may get a product
// subscription. This check is performed in productSubscriptionByDBID.
return productSubscriptionByDBID(ctx, args.UUID)
}
func (ProductSubscriptionLicensingResolver) ProductSubscriptions(ctx context.Context, args *graphqlbackend.ProductSubscriptionsArgs) (graphqlbackend.ProductSubscriptionConnection, error) {
var accountUser *graphqlbackend.UserResolver
if args.Account != nil {
var err error
accountUser, err = graphqlbackend.UserByID(ctx, *args.Account)
if err != nil {
return nil, err
}
}
// 🚨 SECURITY: Users may only list their own product subscriptions. Site admins may list
// licenses for all users, or for any other user.
if accountUser == nil {
if err := backend.CheckCurrentUserIsSiteAdmin(ctx); err != nil {
return nil, err
}
} else {
if err := backend.CheckSiteAdminOrSameUser(ctx, accountUser.SourcegraphID()); err != nil {
return nil, err
}
}
var opt dbSubscriptionsListOptions
if accountUser != nil {
opt.UserID = accountUser.SourcegraphID()
}
args.ConnectionArgs.Set(&opt.LimitOffset)
return &productSubscriptionConnection{opt: opt}, nil
}
// productSubscriptionConnection implements the GraphQL type ProductSubscriptionConnection.
//
// 🚨 SECURITY: When instantiating a productSubscriptionConnection value, the caller MUST
// check permissions.
type productSubscriptionConnection struct {
opt dbSubscriptionsListOptions
// cache results because they are used by multiple fields
once sync.Once
results []*dbSubscription
err error
}
func (r *productSubscriptionConnection) compute(ctx context.Context) ([]*dbSubscription, error) {
r.once.Do(func() {
opt2 := r.opt
if opt2.LimitOffset != nil {
tmp := *opt2.LimitOffset
opt2.LimitOffset = &tmp
opt2.Limit++ // so we can detect if there is a next page
}
r.results, r.err = dbSubscriptions{}.List(ctx, opt2)
})
return r.results, r.err
}
func (r *productSubscriptionConnection) Nodes(ctx context.Context) ([]graphqlbackend.ProductSubscription, error) {
results, err := r.compute(ctx)
if err != nil {
return nil, err
}
var l []graphqlbackend.ProductSubscription
for _, result := range results {
l = append(l, &productSubscription{v: result})
}
return l, nil
}
func (r *productSubscriptionConnection) TotalCount(ctx context.Context) (int32, error) {
count, err := dbSubscriptions{}.Count(ctx, r.opt)
return int32(count), err
}
func (r *productSubscriptionConnection) PageInfo(ctx context.Context) (*graphqlutil.PageInfo, error) {
results, err := r.compute(ctx)
if err != nil {
return nil, err
}
return graphqlutil.HasNextPage(r.opt.LimitOffset != nil && len(results) > r.opt.Limit), nil
}

View File

@ -0,0 +1,18 @@
package graphqlbackend
import (
"github.com/sourcegraph/enterprise/cmd/frontend/internal/dotcom/billing"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/dotcom/productsubscription"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
)
func init() {
// Contribute the GraphQL types DotcomMutation and DotcomQuery.
graphqlbackend.Dotcom = dotcomResolver{}
}
// dotcomResolver implements the GraphQL types DotcomMutation and DotcomQuery.
type dotcomResolver struct {
productsubscription.ProductSubscriptionLicensingResolver
billing.BillingResolver
}

View File

@ -0,0 +1,3 @@
// Package httpapi should be imported for side-effects to register enterprise-specific hooks into
// corresponding httpapi package in the open-source repository.
package httpapi

View File

@ -0,0 +1,16 @@
package httpapi
import (
"github.com/sourcegraph/sourcegraph/cmd/frontend/httpapi"
"github.com/sourcegraph/sourcegraph/xlang"
)
func init() {
httpapi.XLangNewClient = func() (httpapi.XLangClient, error) {
c, err := xlang.UnsafeNewDefaultClient()
if err != nil {
return nil, err
}
return &xclient{Client: c}, nil
}
}

View File

@ -0,0 +1,394 @@
package httpapi
import (
"context"
"encoding/json"
"net/url"
opentracing "github.com/opentracing/opentracing-go"
"github.com/opentracing/opentracing-go/ext"
otlog "github.com/opentracing/opentracing-go/log"
"github.com/pkg/errors"
"github.com/sourcegraph/go-langserver/pkg/lsp"
"github.com/sourcegraph/go-langserver/pkg/lspext"
vcsurl "github.com/sourcegraph/go-vcsurl"
"github.com/sourcegraph/jsonrpc2"
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/pkg/api"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/pkg/errcode"
"github.com/sourcegraph/sourcegraph/pkg/vcs"
"github.com/sourcegraph/sourcegraph/xlang"
xlang_lspext "github.com/sourcegraph/sourcegraph/xlang/lspext"
)
// xclient is an LSP client that transparently wraps xlang.Client,
// except that it translates textDocument/definition requests into a
// series of requests that computes the cross-repo jump-to-definition
// result.
type xclient struct {
*xlang.Client
hasXDefinitionAndXPackages bool
hasCrossRepoHover bool
mode string
}
// Call transparently wraps xlang.Client.Call *except* for `textDocument/definition` if the language
// server is a textDocument/xdefinition provider. In that case, this method invokes
// `textDocument/xdefinition` instead. If the result contains a non-zero `Location` field, then that
// is returned to the client as if it came from `textDocument/definition`. If the location is zero,
// then that means the definition did not exist locally. The method will locate the definition in an
// external repository and return that to the client as if it came from a single
// `textDocument/definition` call.
//
// SECURITY NOTE: Call also verifies permissions for cross-repo jumps. Any changes to this method
// should preserve this property.
func (c *xclient) Call(ctx context.Context, method string, params, result interface{}, opt ...jsonrpc2.CallOption) (err error) {
span, ctx := opentracing.StartSpanFromContext(ctx, "xclient.Call")
defer func() {
if err != nil {
ext.Error.Set(span, true)
span.SetTag("err", err.Error())
}
span.Finish()
}()
span.SetTag("Method", method)
// marshalResult takes an existing interface and marshals it into result
// via JSON.
marshalResult := func(v interface{}, err error) error {
if err != nil {
return err
}
b, err := json.Marshal(v)
if err != nil {
return errors.Wrap(err, "marshaling result")
}
return json.Unmarshal(b, result)
}
switch {
case method == "initialize":
var init xlang_lspext.ClientProxyInitializeParams
if err := json.Unmarshal(*params.(*json.RawMessage), &init); err != nil {
return err
}
c.mode = init.InitializationOptions.Mode
if c.mode == "" {
// DEPRECATED: Use old Mode field if the new one is not set.
c.mode = init.Mode
}
var resultRaw json.RawMessage
if err := c.Client.Call(ctx, method, params, &resultRaw, opt...); err != nil {
return err
}
// We only care about the XDefinitionProvider. Right now it implies
// the support of XPackages as well :'(
var initResultSubset struct {
Capabilities struct {
// XDefinitionProvider indicates the server provides support for
// textDocument/xdefinition. This is a Sourcegraph extension.
XDefinitionProvider bool `json:"xdefinitionProvider,omitempty"`
} `json:"capabilities,omitempty"`
}
if err := json.Unmarshal(resultRaw, &initResultSubset); err != nil {
return err
}
c.hasXDefinitionAndXPackages = initResultSubset.Capabilities.XDefinitionProvider
//_, c.hasXDefinitionAndXPackages = xlang.HasXDefinitionAndXPackages[c.mode]
_, c.hasCrossRepoHover = xlang.HasCrossRepoHover[c.mode]
return json.Unmarshal(resultRaw, result)
case !c.hasXDefinitionAndXPackages:
break
case method == "textDocument/definition":
span.SetTag("LocationAbsent", "true")
return marshalResult(c.jumpToDefCrossRepo(ctx, params, opt...))
case method == "textDocument/hover" && c.hasCrossRepoHover:
return marshalResult(c.hoverCrossRepo(ctx, params, opt...))
case method == "xsymbol/hover":
// Federation. This will only run on sourcegraph.com
var syms []lspext.SymbolLocationInformation
if err := json.Unmarshal(*params.(*json.RawMessage), &syms); err != nil {
return err
}
return marshalResult(c.symbolHover(ctx, syms))
case method == "xsymbol/definition":
// Federation. This will only run on sourcegraph.com
var syms []lspext.SymbolLocationInformation
if err := json.Unmarshal(*params.(*json.RawMessage), &syms); err != nil {
return err
}
return marshalResult(c.symbolDefinition(ctx, syms))
}
return c.Client.Call(ctx, method, params, result, opt...)
}
func (c *xclient) Notify(ctx context.Context, method string, params interface{}, opt ...jsonrpc2.CallOption) error {
return c.Client.Notify(ctx, method, params, opt...)
}
func (c *xclient) Close() error {
return c.Client.Close()
}
func (c *xclient) xdefQuery(ctx context.Context, syms []lspext.SymbolLocationInformation) (map[lsp.DocumentURI][]lsp.SymbolInformation, error) {
span := opentracing.SpanFromContext(ctx)
symInfos := make(map[lsp.DocumentURI][]lsp.SymbolInformation)
// For each symbol in the xdefinition-result-derived query, compute the symbol information for that symbol
for _, sym := range syms {
var rootURIs []lsp.DocumentURI
// If we can extract the repository URL from the symbol metadata, do so
if repoURL := xlang.SymbolRepoURL(sym.Symbol); repoURL != "" {
span.LogFields(otlog.String("event", "extracted repo directly from symbol metadata"),
otlog.String("repoURL", repoURL))
repoInfo, err := vcsurl.Parse(repoURL)
if err != nil {
return nil, errors.Wrap(err, "extract repo URL from symbol metadata")
}
repoURI := api.RepoURI(string(repoInfo.RepoHost) + "/" + repoInfo.FullName)
// We issue a workspace/symbols on the URL, so ensure we have the repo / it exists.
repo, err := backend.Repos.GetByURI(ctx, repoURI)
if err != nil {
span.LogFields(otlog.Error(err))
if _, isSeeOther := err.(backend.ErrRepoSeeOther); isSeeOther || errcode.IsNotFound(err) {
span.LogFields(otlog.String("event", "ignoring not found error"))
continue
}
return nil, errors.Wrap(err, "extract repo URL from symbol metadata")
}
rev, err := backend.Repos.ResolveRev(ctx, repo, "")
if err != nil {
span.LogFields(otlog.Error(err))
if vcs.IsRepoNotExist(err) {
span.LogFields(otlog.String("event", "ignoring not found error"))
continue
}
return nil, errors.Wrap(err, "extract repo URL from symbol metadata")
}
rootURIs = append(rootURIs, lsp.DocumentURI(string(repoInfo.VCS)+"://"+string(repoURI)+"?"+string(rev)))
} else { // if we can't extract the repository URL directly, we have to consult the pkgs database
pkgDescriptor, ok := xlang.SymbolPackageDescriptor(sym.Symbol, c.mode)
if !ok {
continue
}
span.LogFields(otlog.String("event", "cross-repo jump to def"))
pkgs, err := db.Pkgs.ListPackages(ctx, &api.ListPackagesOp{PkgQuery: pkgDescriptor, Lang: c.mode, Limit: 1})
if err != nil {
return nil, errors.Wrap(err, "getting repo by package db query")
}
span.LogFields(otlog.String("event", "listed repository packages"))
for _, pkg := range pkgs {
repo, err := backend.Repos.Get(ctx, pkg.RepoID)
if err != nil {
return nil, errors.Wrap(err, "fetch repo for package")
}
var commit api.CommitID
if repo.IndexedRevision != nil {
commit = *repo.IndexedRevision
} else {
var err error
commit, err = backend.Repos.ResolveRev(ctx, repo, "")
if err != nil {
return nil, errors.Wrap(err, "resolve revision for package repo")
}
}
rootURIs = append(rootURIs, lsp.DocumentURI("git://"+string(repo.URI)+"?"+string(commit)))
}
span.LogFields(otlog.String("event", "resolved rootURIs"))
}
// Issue a workspace/symbol for each repository that provides a definition for the symbol
for _, rootURI := range rootURIs {
params := &lspext.WorkspaceSymbolParams{Symbol: sym.Symbol, Limit: 10}
var repoSymInfos []lsp.SymbolInformation
if err := xlang.UnsafeOneShotClientRequest(ctx, c.mode, rootURI, "workspace/symbol", params, &repoSymInfos); err != nil {
return nil, errors.Wrap(err, "resolving symbol to location")
}
symInfos[rootURI] = repoSymInfos
}
span.LogFields(otlog.String("event", "done issuing workspace/symbol requests"))
}
return symInfos, nil
}
func (c *xclient) jumpToDefCrossRepo(ctx context.Context, params interface{}, opt ...jsonrpc2.CallOption) ([]lsp.Location, error) {
// Issue xdefinition request
var syms []lspext.SymbolLocationInformation
if err := c.Client.Call(ctx, "textDocument/xdefinition", params, &syms, opt...); err != nil {
return nil, err
}
locs := make([]lsp.Location, 0, len(syms))
var nolocSyms []lspext.SymbolLocationInformation
for _, sym := range syms {
// If a concrete location is already present, just use that
if sym.Location != (lsp.Location{}) {
locs = append(locs, sym.Location)
} else {
nolocSyms = append(nolocSyms, sym)
}
}
symLocs, err := c.symbolDefinition(ctx, nolocSyms)
if err != nil {
return nil, err
}
locs = append(locs, symLocs...)
// Failed to find the definition locally, try symbolDefinition on Sourcegraph.com
// which may have indexed the OSS repo used.
if len(locs) == 0 && len(nolocSyms) > 0 && conf.JumpToDefOSSIndexEnabled() {
// HACK we need a valid rootURI, even though we are doing symbol queries.
rootURI := lsp.DocumentURI("git://github.com/gorilla/mux?4dbd923b0c9e99ff63ad54b0e9705ff92d3cdb06")
err := xlang.RemoteOneShotClientRequest(ctx, sourcegraphDotComBaseURL, c.mode, rootURI, "xsymbol/definition", nolocSyms, &locs)
if err != nil {
return nil, err
}
return locs, nil
}
return locs, nil
}
func (c *xclient) symbolDefinition(ctx context.Context, syms []lspext.SymbolLocationInformation) ([]lsp.Location, error) {
symInfos, err := c.xdefQuery(ctx, syms)
if err != nil {
return nil, err
}
locs := make([]lsp.Location, 0, len(symInfos))
for _, repoSymInfos := range symInfos {
for _, symInfo := range repoSymInfos {
locs = append(locs, symInfo.Location)
}
}
return locs, nil
}
// hoverCrossRepo translates hover requests in the current repository to a
// hover request on the definition in the definition's repository.
//
// Algorithm:
//
// 1. If we are hovering over a symbol in the current repository, use the
// normal textDocument/hover.
// 2. Use textDocument/xdefinition (sg extension) to retrieve symbol information.
// 3. symbolHover: Using the symbols, use the first successful hover in the
// definition repos.
// 4. If we do not find a non-empty hover and federation is enabled, we send
// the package query to Sourcegraph.com's xlang API. The assumption is the
// dependency is an OSS package so we can consult our public index. If the
// response is non-empty we return it.
// 5. If we do not find a non-empty hover, fallback to the normal hover.
func (c *xclient) hoverCrossRepo(ctx context.Context, params interface{}, opt ...jsonrpc2.CallOption) (*lsp.Hover, error) {
// Note: we can't parallelize the hover and xdefinition requests
// without breaking the request cancellation logic used by LSP
// proxy
// xdefinition request
var syms []lspext.SymbolLocationInformation
if err := c.Client.Call(ctx, "textDocument/xdefinition", params, &syms, opt...); err != nil {
return nil, errors.Wrap(err, "hoverCrossRepo: textDocument/xdefinition error")
}
// hover request
var hover lsp.Hover
if err := c.Client.Call(ctx, "textDocument/hover", params, &hover, opt...); err != nil {
return nil, errors.Wrap(err, "hoverCrossRepo: textDocument/hover error")
}
// return local hover if local definition found
for _, sym := range syms {
if sym.Location != (lsp.Location{}) {
return &hover, nil
}
}
// Cross repo hover is done via the symbols only.
xhov, err := c.symbolHover(ctx, syms)
if err != nil {
return nil, err
}
if len(xhov.Contents) > 0 {
// Range is for the queried token, so we need to use the local
// hover range.
xhov.Range = hover.Range
return xhov, nil
}
// Failed to find the hover locally, try symbolHover on Sourcegraph.com
// which may have indexed the OSS repo used.
if conf.JumpToDefOSSIndexEnabled() {
// HACK we need a valid rootURI, even though we are doing symbol queries.
rootURI := lsp.DocumentURI("git://github.com/gorilla/mux?4dbd923b0c9e99ff63ad54b0e9705ff92d3cdb06")
var remoteHov lsp.Hover
err := xlang.RemoteOneShotClientRequest(ctx, sourcegraphDotComBaseURL, c.mode, rootURI, "xsymbol/hover", syms, &remoteHov)
if err != nil {
return nil, err
}
if len(remoteHov.Contents) > 0 {
// Range is for the queried token, so we need to use the local
// hover range.
remoteHov.Range = hover.Range
return &remoteHov, nil
}
}
// Fallback to local hover contents.
return &hover, nil
}
// symbolHover finds a hover contents for the given symbols.
//
// Algorithm:
//
// 1. xdefQuery: Consult our packages index to find potential repositories
// containing the symbols.
// 2. xdefQuery: For each potential repository use workspace/symbol with our
// symbol query (sg extension).
// 3. For each symbol do a textDocument/hover. The first non-empty hover
// content we return to the user.
func (c *xclient) symbolHover(ctx context.Context, syms []lspext.SymbolLocationInformation) (*lsp.Hover, error) {
symInfos, err := c.xdefQuery(ctx, syms)
if err != nil {
return nil, err
}
// return first hover found
for rootURI, repoSymInfos := range symInfos {
for _, symInfo := range repoSymInfos {
pos := symInfo.Location.Range.Start
pos.Character++
p := lsp.TextDocumentPositionParams{
TextDocument: lsp.TextDocumentIdentifier{URI: symInfo.Location.URI},
Position: pos,
}
var xhov lsp.Hover
if err := xlang.UnsafeOneShotClientRequest(ctx, c.mode, rootURI, "textDocument/hover", p, &xhov); err != nil {
return nil, errors.Wrap(err, "hoverCrossRepo: external textDocument/hover error")
}
if len(xhov.Contents) > 0 {
return &xhov, nil
}
}
}
// nothing found, so empty response
return &lsp.Hover{}, nil
}
var sourcegraphDotComBaseURL = &url.URL{Scheme: "https", Host: "sourcegraph.com"}

View File

@ -0,0 +1,3 @@
// Package licensing handles parsing, verifying, and enforcing the product subscription (specified in
// site configuration).
package licensing

View File

@ -0,0 +1,125 @@
package licensing
import (
"context"
"fmt"
"html"
"net/http"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/hooks"
"github.com/sourcegraph/sourcegraph/pkg/errcode"
log15 "gopkg.in/inconshreveable/log15.v2"
)
// Enforce the use of a valid license key by preventing all HTTP requests if the license is invalid
// (due to a error in parsing or verification, or because the license has expired).
func init() {
hooks.PreAuthMiddleware = func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
info, err := GetConfiguredProductLicenseInfo()
if err != nil {
log15.Error("Error reading license key for Sourcegraph subscription.", "err", err)
WriteSubscriptionErrorResponse(w, http.StatusInternalServerError, "Error reading Sourcegraph license key", "Site admins may check the logs for more information.")
return
}
if info != nil && info.IsExpiredWithGracePeriod() {
WriteSubscriptionErrorResponse(w, http.StatusForbidden, "Sourcegraph license expired", "To continue using Sourcegraph, a site admin must renew the Sourcegraph license (or downgrade to only using Sourcegraph Core features).")
return
}
next.ServeHTTP(w, r)
})
}
}
// WriteSubscriptionErrorResponseForFeature is a wrapper around WriteSubscriptionErrorResponse that
// generates the error title and message indicating that the current license does not active the
// given feature.
func WriteSubscriptionErrorResponseForFeature(w http.ResponseWriter, featureNameHumanReadable string) {
WriteSubscriptionErrorResponse(
w, http.StatusForbidden,
fmt.Sprintf("License is not valid for %s", featureNameHumanReadable),
fmt.Sprintf("To use the %s feature, a site admin must upgrade the Sourcegraph license. (The site admin may also remove the site configuration that enables this feature to dismiss this message.)", featureNameHumanReadable),
)
}
// WriteSubscriptionErrorResponse writes an HTTP response that displays a standalone error page to
// the user.
//
// The title and message should be full sentences that describe the problem and how to fix it. Use
// WriteSubscriptionErrorResponseForFeature to generate these for the common case of a failed
// license feature check.
func WriteSubscriptionErrorResponse(w http.ResponseWriter, statusCode int, title, message string) {
w.WriteHeader(statusCode)
w.Header().Set("Cache-Control", "no-cache")
w.Header().Set("Content-Type", "text/html; charset=utf-8")
// Inline all styles and resources because those requests will fail (our middleware
// intercepts all HTTP requests).
fmt.Fprintln(w, `
<title>`+html.EscapeString(title)+` - Sourcegraph</title>
<style>
.bg {
position: absolute;
user-select: none;
pointer-events: none;
z-index: -1;
top: 0;
bottom: 0;
left: 0;
right: 0;
/* The Sourcegraph logo in SVG. */
background-image: url('data:image/svg+xml,<svg xmlns="http://www.w3.org/2000/svg" viewBox="0 0 124 127"><g fill="none" fill-rule="evenodd"><path fill="%23F96316" d="M35.942 16.276L63.528 117.12c1.854 6.777 8.85 10.768 15.623 8.912 6.778-1.856 10.765-8.854 8.91-15.63L60.47 9.555C58.615 2.78 51.62-1.212 44.847.645c-6.772 1.853-10.76 8.853-8.905 15.63z"/><path fill="%23B200F8" d="M87.024 15.894L17.944 93.9c-4.66 5.26-4.173 13.303 1.082 17.964 5.255 4.66 13.29 4.174 17.95-1.084l69.08-78.005c4.66-5.26 4.173-13.3-1.082-17.962-5.257-4.664-13.294-4.177-17.95 1.08z"/><path fill="%2300B4F2" d="M8.75 59.12l98.516 32.595c6.667 2.205 13.86-1.414 16.065-8.087 2.21-6.672-1.41-13.868-8.08-16.076L16.738 34.96c-6.67-2.207-13.86 1.412-16.066 8.085-2.204 6.672 1.416 13.87 8.08 16.075z"/></g></svg>');
background-repeat: repeat;
background-size: 5rem;
opacity: 0.1;
}
.msg {
font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, 'Helvetica Neue', Arial, sans-serif;
max-width: 30rem;
margin: 20vh auto 0;
border: solid 2px rgba(255,0,0,0.8);
background-color: rgba(255,255,255,0.8);
color: rgb(30, 0, 0);
padding: 1rem 2rem;
}
h1 {
font-size: 1.5rem;
}
</style>
<meta name="robots" content="noindex">
<body>
<div class=bg></div>
<div class=msg><h1>`+html.EscapeString(title)+`</h1><p>`+html.EscapeString(message)+`</p><p>See <a href="https://about.sourcegraph.com/pricing">about.sourcegraph.com</a> for more information.</p></div>`)
}
// Enforce the license's max user count by preventing the creation of new users when the max is
// reached.
func init() {
db.Users.PreCreateUser = func(ctx context.Context) error {
info, err := GetConfiguredProductLicenseInfo()
if info == nil || err != nil {
return err
}
// Block creation of a new user beyond the licensed user count (unless true-up is allowed).
userCount, err := db.Users.Count(ctx, nil)
if err != nil {
return err
}
// Be conservative and treat 0 as unlimited. We don't plan to intentionally generate
// licenses with UserCount == 0, but that might result from a bug in license decoding, and
// we don't want that to immediately disable Sourcegraph instances.
if info.UserCount > 0 && userCount >= int(info.UserCount) {
if info.HasTag(TrueUpUserCountTag) {
log15.Info("Licensed user count exceeded, but license supports true-up and will not block creation of new user. The new user will be retroactively charged for in the next billing period. Contact sales@sourcegraph.com for help.", "activeUserCount", userCount, "licensedUserCount", info.UserCount)
} else {
return errcode.NewPresentationError("Unable to create user account: the Sourcegraph license's maximum user count has been reached. A site admin must upgrade the Sourcegraph subscription to allow for more users.")
}
}
return nil
}
}

View File

@ -0,0 +1,86 @@
package licensing
import (
"context"
"fmt"
"testing"
"github.com/sourcegraph/enterprise/pkg/license"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
)
func TestEnforcementPreCreateUser(t *testing.T) {
tests := []struct {
license *license.Info
activeUserCount uint
wantErr bool
}{
// See the impl for why we treat UserCount == 0 as unlimited.
{
license: &license.Info{UserCount: 0},
activeUserCount: 5,
wantErr: false,
},
// Non-true-up licenses.
{
license: &license.Info{UserCount: 10},
activeUserCount: 0,
wantErr: false,
},
{
license: &license.Info{UserCount: 10},
activeUserCount: 5,
wantErr: false,
},
{
license: &license.Info{UserCount: 10},
activeUserCount: 9,
wantErr: false,
},
{
license: &license.Info{UserCount: 10},
activeUserCount: 10,
wantErr: true,
},
{
license: &license.Info{UserCount: 10},
activeUserCount: 11,
wantErr: true,
},
{
license: &license.Info{UserCount: 10},
activeUserCount: 12,
wantErr: true,
},
// True-up licenses.
{
license: &license.Info{Tags: []string{TrueUpUserCountTag}, UserCount: 10},
activeUserCount: 5,
wantErr: false,
},
{
license: &license.Info{Tags: []string{TrueUpUserCountTag}, UserCount: 10},
activeUserCount: 15,
wantErr: false,
},
}
for _, test := range tests {
t.Run(fmt.Sprintf("license %s with %d active users", test.license, test.activeUserCount), func(t *testing.T) {
MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
return test.license, nil
}
defer func() { MockGetConfiguredProductLicenseInfo = nil }()
db.Mocks.Users.Count = func(context.Context, *db.UsersListOptions) (int, error) {
return int(test.activeUserCount), nil
}
defer func() { db.Mocks = db.MockStores{} }()
err := db.Users.PreCreateUser(context.Background())
if gotErr := (err != nil); gotErr != test.wantErr {
t.Errorf("got error %v, want %v", gotErr, test.wantErr)
}
})
}
}

View File

@ -0,0 +1,102 @@
package licensing
import (
"fmt"
"github.com/pkg/errors"
"github.com/sourcegraph/enterprise/pkg/license"
"github.com/sourcegraph/sourcegraph/pkg/errcode"
)
// Feature is a product feature that is selectively activated based on the current license key.
type Feature string
// The list of features. For each feature, add a new const here and the checking logic in
// isFeatureEnabled below.
const (
// FeatureExternalAuthProvider is whether external user authentication providers (aka "SSO") may
// be used.
FeatureExternalAuthProvider Feature = "sso-external-user-auth-provider"
// FeatureExtensionRegistry is whether publishing extensions to this Sourcegraph instance is
// allowed. If not, then extensions must be published to Sourcegraph.com. All instances may use
// extensions published to Sourcegraph.com.
FeatureExtensionRegistry Feature = "private-extension-registry"
// FeatureRemoteExtensionsAllowDisallow is whether the site admin may explictly specify a list
// of allowed remote extensions and prevent any other remote extensions from being used. It does
// not apply to locally published extensions.
FeatureRemoteExtensionsAllowDisallow = "remote-extensions-allow-disallow"
)
func isFeatureEnabled(info license.Info, feature Feature) bool {
// Add feature-specific logic here.
switch feature {
case FeatureExternalAuthProvider:
// Enterprise Starter and Enterprise both allow SSO. Core doesn't, but this func is only
// called when there is a valid license.
return true
case FeatureExtensionRegistry:
// Enterprise Starter does not support a local extension registry.
return !info.HasTag(EnterpriseStarterTag)
case FeatureRemoteExtensionsAllowDisallow:
// Enterprise Starter does not support explictly allowing/disallowing remote extensions by
// extension ID.
return !info.HasTag(EnterpriseStarterTag)
}
return false
}
// CheckFeature checks whether the feature is activated based on the current license. If it is
// disabled, it returns a non-nil error.
//
// The returned error may implement errcode.PresentationError to indicate that it can be displayed
// directly to the user. Use IsFeatureNotActivated to distinguish between the error reasons.
func CheckFeature(feature Feature) error {
info, err := GetConfiguredProductLicenseInfo()
if err != nil {
return errors.WithMessage(err, fmt.Sprintf("checking feature %q activation", feature))
}
if info == nil {
return newFeatureNotActivatedError(fmt.Sprintf("The feature %q is not activated because it requires a valid Sourcegraph license. Purchase a Sourcegraph subscription to activate this feature.", feature))
}
if !isFeatureEnabled(*info, feature) {
return newFeatureNotActivatedError(fmt.Sprintf("The feature %q is not activated for Sourcegraph Enterprise Starter. Upgrade to Sourcegraph Enterprise to use this feature.", feature))
}
return nil // feature is activated for current license
}
func newFeatureNotActivatedError(message string) featureNotActivatedError {
e := errcode.NewPresentationError(message).(errcode.PresentationError)
return featureNotActivatedError{e}
}
type featureNotActivatedError struct{ errcode.PresentationError }
// IsFeatureNotActivated reports whether err indicates that the license is valid but does not
// activate the feature.
//
// It is used to distinguish between the multiple reasons for errors from CheckFeature: either
// failed license verification, or a valid license that does not activate a feature (e.g.,
// Enterprise Starter not including an Enterprise-only feature).
func IsFeatureNotActivated(err error) bool {
if err == nil {
return false
}
_, ok := err.(featureNotActivatedError)
if !ok {
// Also check for the pointer type to guard against stupid mistakes.
_, ok = err.(*featureNotActivatedError)
}
return ok
}
// IsFeatureEnabledLenient reports whether the current license enables the given feature. If there
// is an error reading the license, it is lenient and returns true.
//
// This is useful for callers who don't want to handle errors (usually because the user would be
// prevented from getting to this point if license verification had failed, so it's not necessary to
// handle license verification errors here).
func IsFeatureEnabledLenient(feature Feature) bool {
return !IsFeatureNotActivated(CheckFeature(feature))
}

View File

@ -0,0 +1,27 @@
package licensing
import (
"testing"
"github.com/sourcegraph/enterprise/pkg/license"
)
func TestIsFeatureEnabled(t *testing.T) {
check := func(t *testing.T, feature Feature, licenseTags []string, wantEnabled bool) {
t.Helper()
got := isFeatureEnabled(license.Info{Tags: licenseTags}, feature)
if got != wantEnabled {
t.Errorf("got %v, want %v", got, wantEnabled)
}
}
t.Run(string(FeatureExternalAuthProvider), func(t *testing.T) {
check(t, FeatureExternalAuthProvider, EnterpriseStarterTags, true)
check(t, FeatureExternalAuthProvider, EnterpriseTags, true)
})
t.Run(string(FeatureExtensionRegistry), func(t *testing.T) {
check(t, FeatureExtensionRegistry, EnterpriseStarterTags, false)
check(t, FeatureExtensionRegistry, EnterpriseTags, true)
})
}

View File

@ -0,0 +1,162 @@
package licensing
import (
"fmt"
"log"
"os"
"sync"
"github.com/pkg/errors"
"github.com/sourcegraph/enterprise/pkg/license"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/pkg/env"
"golang.org/x/crypto/ssh"
)
// publicKey is the public key used to verify product license keys.
//
// It is hardcoded here intentionally (we only have one private signing key, and we don't yet
// support/need key rotation). The corresponding private key is at
// https://team-sourcegraph.1password.com/vaults/dnrhbauihkhjs5ag6vszsme45a/allitems/zkdx6gpw4uqejs3flzj7ef5j4i
// and set below in SOURCEGRAPH_LICENSE_GENERATION_KEY.
var publicKey = func() ssh.PublicKey {
// To convert PKCS#8 format (which `openssl rsa -in key.pem -pubout` produces) to the format
// that ssh.ParseAuthorizedKey reads here, use `ssh-keygen -i -mPKCS8 -f key.pub`.
const publicKeyData = `ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDUUd9r83fGmYVLzcqQp5InyAoJB5lLxlM7s41SUUtxfnG6JpmvjNd+WuEptJGk0C/Zpyp/cCjCV4DljDs8Z7xjRbvJYW+vklFFxXrMTBs/+HjpIBKlYTmG8SqTyXyu1s4485Kh1fEC5SK6z2IbFaHuSHUXgDi/IepSOg1QudW4n8J91gPtT2E30/bPCBRq8oz/RVwJSDMvYYjYVb//LhV0Mx3O6hg4xzUNuwiCtNjCJ9t4YU2sV87+eJwWtQNbSQ8TelQa8WjG++XSnXUHw12bPDe7wGL/7/EJb7knggKSAMnpYpCyV35dyi4DsVc46c+b6P0gbVSosh3Uc3BJHSWF`
var err error
publicKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(publicKeyData))
if err != nil {
panic("failed to parse public key for license verification: " + err.Error())
}
return publicKey
}()
// ParseProductLicenseKey parses and verifies the license key using the license verification public
// key (publicKey in this package).
func ParseProductLicenseKey(licenseKey string) (*license.Info, error) {
return license.ParseSignedKey(licenseKey, publicKey)
}
// ParseProductLicenseKeyWithBuiltinOrGenerationKey is like ParseProductLicenseKey, except it tries
// parsing and verifying the license key with the license generation key (if set), instead of always
// using the builtin license key.
//
// It is useful for local development when using a test license generation key (whose signatures
// aren't considered valid when verified using the builtin public key).
func ParseProductLicenseKeyWithBuiltinOrGenerationKey(licenseKey string) (*license.Info, error) {
var k ssh.PublicKey
if licenseGenerationPrivateKey != nil {
k = licenseGenerationPrivateKey.PublicKey()
} else {
k = publicKey
}
return license.ParseSignedKey(licenseKey, k)
}
// Cache the parsing of the license key because public key crypto can be slow.
var (
mu sync.Mutex
lastKeyText string
lastInfo *license.Info
)
var MockGetConfiguredProductLicenseInfo func() (*license.Info, error)
// GetConfiguredProductLicenseInfo returns information about the current product license key
// specified in site configuration.
func GetConfiguredProductLicenseInfo() (*license.Info, error) {
if MockGetConfiguredProductLicenseInfo != nil {
return MockGetConfiguredProductLicenseInfo()
}
// Support reading the license key from the environment (intended for development, because we
// don't want to commit a valid license key to dev/config.json in the OSS repo).
keyText := os.Getenv("SOURCEGRAPH_LICENSE_KEY")
if keyText == "" {
keyText = conf.Get().LicenseKey
}
if keyText != "" {
mu.Lock()
defer mu.Unlock()
var info *license.Info
if keyText == lastKeyText {
info = lastInfo
} else {
var err error
info, err = ParseProductLicenseKey(keyText)
if err != nil {
return nil, err
}
lastKeyText = keyText
lastInfo = info
}
return info, nil
}
// No license key.
return nil, nil
}
// Make the Site.productSubscription GraphQL field return the actual info about the product license,
// if any.
func init() {
graphqlbackend.GetConfiguredProductLicenseInfo = func() (*graphqlbackend.ProductLicenseInfo, error) {
info, err := GetConfiguredProductLicenseInfo()
if info == nil || err != nil {
return nil, err
}
return &graphqlbackend.ProductLicenseInfo{
TagsValue: info.Tags,
UserCountValue: info.UserCount,
ExpiresAtValue: info.ExpiresAt,
}, nil
}
}
// licenseGenerationPrivateKeyURL is the URL where Sourcegraph staff can find the private key for
// generating licenses.
//
// NOTE: If you change this, use text search to replace other instances of it (in source code
// comments).
const licenseGenerationPrivateKeyURL = "https://team-sourcegraph.1password.com/vaults/dnrhbauihkhjs5ag6vszsme45a/allitems/zkdx6gpw4uqejs3flzj7ef5j4i"
// envLicenseGenerationPrivateKey (the env var SOURCEGRAPH_LICENSE_GENERATION_KEY) is the
// PEM-encoded form of the private key used to sign product license keys. It is stored at
// https://team-sourcegraph.1password.com/vaults/dnrhbauihkhjs5ag6vszsme45a/allitems/zkdx6gpw4uqejs3flzj7ef5j4i.
var envLicenseGenerationPrivateKey = env.Get("SOURCEGRAPH_LICENSE_GENERATION_KEY", "", "the PEM-encoded form of the private key used to sign product license keys ("+licenseGenerationPrivateKeyURL+")")
// licenseGenerationPrivateKey is the private key used to generate license keys.
var licenseGenerationPrivateKey = func() ssh.Signer {
if envLicenseGenerationPrivateKey == "" {
// Most Sourcegraph instances don't use/need this key. Generally only Sourcegraph.com and
// local dev will have this key set.
return nil
}
privateKey, err := ssh.ParsePrivateKey([]byte(envLicenseGenerationPrivateKey))
if err != nil {
log.Fatalf("Failed to parse private key in SOURCEGRAPH_LICENSE_GENERATION_KEY env var: %s.", err)
}
return privateKey
}()
// GenerateProductLicenseKey generates a product license key using the license generation private
// key configured in site configuration.
func GenerateProductLicenseKey(info license.Info) (string, error) {
if envLicenseGenerationPrivateKey == "" {
const msg = "no product license generation private key was configured"
if env.InsecureDev {
// Show more helpful error message in local dev.
return "", fmt.Errorf("%s (for testing by Sourcegraph staff: set the SOURCEGRAPH_LICENSE_GENERATION_KEY env var to the key obtained at %s)", msg, licenseGenerationPrivateKeyURL)
}
return "", errors.New(msg)
}
licenseKey, err := license.GenerateSignedKey(info, licenseGenerationPrivateKey)
if err != nil {
return "", err
}
return licenseKey, nil
}

View File

@ -0,0 +1,67 @@
package licensing
import (
"strings"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
)
// Make the Site.productSubscription.productNameWithBrand GraphQL field (and other places) use the
// proper product name.
func init() {
graphqlbackend.GetProductNameWithBrand = productNameWithBrand
}
const (
// EnterpriseStarterTag is the license tag for Enterprise Starter (which includes only a subset
// of Enterprise features).
EnterpriseStarterTag = "starter"
// TrueUpUserCountTag is the license tag that indicates that the licensed user count can be
// exceeded and will be charged later.
TrueUpUserCountTag = "true-up"
)
var (
// EnterpriseStarterTags is the license tags for Enterprise Starter.
EnterpriseStarterTags = []string{EnterpriseStarterTag}
// EnterpriseTags is the license tags for Enterprise (intentionally empty because it has no
// feature restrictions)
EnterpriseTags = []string{}
)
// productNameWithBrand returns the product name with brand (e.g., "Sourcegraph Enterprise") based
// on the license info.
func productNameWithBrand(hasLicense bool, licenseTags []string) string {
if !hasLicense {
return "Sourcegraph Core"
}
hasTag := func(tag string) bool {
for _, t := range licenseTags {
if tag == t {
return true
}
}
return false
}
var name string
if hasTag("starter") {
name = " Starter"
}
var misc []string
if hasTag("trial") {
misc = append(misc, "trial")
}
if hasTag("dev") {
misc = append(misc, "dev use only")
}
if len(misc) > 0 {
name += " (" + strings.Join(misc, ", ") + ")"
}
return "Sourcegraph Enterprise" + name
}

View File

@ -0,0 +1,33 @@
package licensing
import (
"fmt"
"testing"
)
func TestProductNameWithBrand(t *testing.T) {
tests := []struct {
hasLicense bool
licenseTags []string
want string
}{
{hasLicense: false, want: "Sourcegraph Core"},
{hasLicense: true, licenseTags: nil, want: "Sourcegraph Enterprise"},
{hasLicense: true, licenseTags: []string{}, want: "Sourcegraph Enterprise"},
{hasLicense: true, licenseTags: []string{"x"}, want: "Sourcegraph Enterprise"}, // unrecognized tag "x" is ignored
{hasLicense: true, licenseTags: []string{"starter"}, want: "Sourcegraph Enterprise Starter"},
{hasLicense: true, licenseTags: []string{"trial"}, want: "Sourcegraph Enterprise (trial)"},
{hasLicense: true, licenseTags: []string{"dev"}, want: "Sourcegraph Enterprise (dev use only)"},
{hasLicense: true, licenseTags: []string{"starter", "trial"}, want: "Sourcegraph Enterprise Starter (trial)"},
{hasLicense: true, licenseTags: []string{"starter", "dev"}, want: "Sourcegraph Enterprise Starter (dev use only)"},
{hasLicense: true, licenseTags: []string{"starter", "trial", "dev"}, want: "Sourcegraph Enterprise Starter (trial, dev use only)"},
{hasLicense: true, licenseTags: []string{"trial", "dev"}, want: "Sourcegraph Enterprise (trial, dev use only)"},
}
for _, test := range tests {
t.Run(fmt.Sprintf("hasLicense=%v licenseTags=%v", test.hasLicense, test.licenseTags), func(t *testing.T) {
if got := productNameWithBrand(test.hasLicense, test.licenseTags); got != test.want {
t.Errorf("got %q, want %q", got, test.want)
}
})
}
}

View File

@ -0,0 +1,58 @@
package registry
import (
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/pkg/registry"
)
func init() {
frontendregistry.IsRemoteExtensionAllowed = func(extensionID string) bool {
allowedExtensions := getAllowedExtensionsFromSiteConfig()
if allowedExtensions == nil {
// Default is to allow all extensions.
return true
}
for _, x := range allowedExtensions {
if extensionID == x {
return true
}
}
return false
}
frontendregistry.FilterRemoteExtensions = func(extensions []*registry.Extension) []*registry.Extension {
allowedExtensions := getAllowedExtensionsFromSiteConfig()
if allowedExtensions == nil {
// Default is to allow all extensions.
return extensions
}
allow := make(map[string]interface{})
for _, id := range allowedExtensions {
allow[id] = struct{}{}
}
var keep []*registry.Extension
for _, x := range extensions {
if _, ok := allow[x.ExtensionID]; ok {
keep = append(keep, x)
}
}
return keep
}
}
func getAllowedExtensionsFromSiteConfig() []string {
// If the remote extension allow/disallow feature is not enabled, all remote extensions are
// allowed. This is achieved by a nil list.
if !licensing.IsFeatureEnabledLenient(licensing.FeatureRemoteExtensionsAllowDisallow) {
return nil
}
if c := conf.Get().Extensions; c != nil {
return c.AllowRemoteExtensions
}
return nil
}

View File

@ -0,0 +1,93 @@
package registry
import (
"reflect"
"sort"
"testing"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/enterprise/pkg/license"
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/pkg/registry"
"github.com/sourcegraph/sourcegraph/schema"
)
func TestIsRemoteExtensionAllowed(t *testing.T) {
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
return &license.Info{Tags: licensing.EnterpriseTags}, nil
}
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
defer conf.Mock(nil)
if !frontendregistry.IsRemoteExtensionAllowed("a") {
t.Errorf("want %q to be allowed", "a")
}
conf.Mock(&schema.SiteConfiguration{Extensions: &schema.Extensions{AllowRemoteExtensions: nil}})
if !frontendregistry.IsRemoteExtensionAllowed("a") {
t.Errorf("want %q to be allowed", "a")
}
conf.Mock(&schema.SiteConfiguration{Extensions: &schema.Extensions{AllowRemoteExtensions: []string{}}})
if frontendregistry.IsRemoteExtensionAllowed("a") {
t.Errorf("want %q to be disallowed", "a")
}
conf.Mock(&schema.SiteConfiguration{Extensions: &schema.Extensions{AllowRemoteExtensions: []string{"a"}}})
if !frontendregistry.IsRemoteExtensionAllowed("a") {
t.Errorf("want %q to be allowed", "a")
}
}
func sameElements(a, b []string) bool {
if len(a) != len(b) {
return false
}
aCopy := make([]string, len(a))
bCopy := make([]string, len(b))
copy(aCopy, a)
copy(bCopy, b)
sort.Strings(aCopy)
sort.Strings(bCopy)
return reflect.DeepEqual(aCopy, bCopy)
}
func TestFilterRemoteExtensions(t *testing.T) {
licensing.MockGetConfiguredProductLicenseInfo = func() (*license.Info, error) {
return &license.Info{Tags: licensing.EnterpriseTags}, nil
}
defer func() { licensing.MockGetConfiguredProductLicenseInfo = nil }()
run := func(allowRemoteExtensions *[]string, extensions []string, want []string) {
t.Helper()
if allowRemoteExtensions != nil {
conf.Mock(&schema.SiteConfiguration{Extensions: &schema.Extensions{AllowRemoteExtensions: *allowRemoteExtensions}})
defer conf.Mock(nil)
}
var xs []*registry.Extension
for _, id := range extensions {
xs = append(xs, &registry.Extension{ExtensionID: id})
}
got := []string{}
for _, x := range frontendregistry.FilterRemoteExtensions(xs) {
got = append(got, x.ExtensionID)
}
if !sameElements(got, want) {
t.Errorf("want %+v got %+v", want, got)
}
}
run(nil, []string{}, []string{})
run(nil, []string{"a"}, []string{"a"})
run(&[]string{}, []string{}, []string{})
run(&[]string{"a"}, []string{}, []string{})
run(&[]string{}, []string{"a"}, []string{})
run(&[]string{"a"}, []string{"b"}, []string{})
run(&[]string{"a"}, []string{"a"}, []string{"a"})
run(&[]string{"b", "c"}, []string{"a", "b", "c"}, []string{"b", "c"})
}

View File

@ -0,0 +1,2 @@
// Package registry contains the implementation of the extension registry.
package registry

View File

@ -0,0 +1,96 @@
package registry
import (
"fmt"
"net/http"
"net/url"
"path"
"path/filepath"
"strconv"
"strings"
"github.com/gorilla/mux"
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/pkg/errcode"
)
func init() {
frontendregistry.HandleRegistryExtensionBundle = handleRegistryExtensionBundle
}
// handleRegistryExtensionBundle serves the bundled JavaScript source file or the source map for an
// extension in the registry as a raw JavaScript or JSON file.
func handleRegistryExtensionBundle(w http.ResponseWriter, r *http.Request) {
if conf.Extensions() == nil {
w.WriteHeader(http.StatusNotFound)
return
}
filename := mux.Vars(r)["RegistryExtensionReleaseFilename"]
ext := filepath.Ext(filename)
wantSourceMap := ext == ".map"
releaseIDStr := strings.TrimSuffix(filename, ext)
releaseID, err := strconv.ParseInt(releaseIDStr, 10, 64)
if err != nil {
http.Error(w, err.Error(), http.StatusBadRequest)
return
}
bundle, sourceMap, err := dbReleases{}.GetArtifacts(r.Context(), releaseID)
if errcode.IsNotFound(err) {
http.Error(w, err.Error(), http.StatusNotFound)
return
} else if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
// 🚨 SECURITY: Prevent this URL from being used in a <script> tag from other sites, because
// hosting user-provided scripts on this domain would let attackers steal sensitive data from
// anyone they lure to the attacker's site.
w.Header().Set("Content-Type", "text/plain; charset=utf-8")
w.Header().Set("Content-Security-Policy", "default-src 'none'; style-src 'unsafe-inline'; sandbox")
w.Header().Set("X-Frame-Options", "deny")
w.Header().Set("X-Content-Type-Options", "nosniff")
w.Header().Set("X-XSS-Protection", "1; mode=block")
// Allow downstream Sourcegraph sites' clients to access this file directly.
w.Header().Del("Access-Control-Allow-Credentials") // credentials are not needed
w.Header().Set("Access-Control-Allow-Origin", "*")
// We want to cache forever because an extension release is immutable, except that if the
// database is reset and and the registry_extension_releases.id sequence starts over, we don't
// want stale data from pre-reset. So, assume that the presence of a query string means that it
// includes some identifier that changes when the database is reset.
if r.URL.RawQuery != "" {
w.Header().Set("Cache-Control", "max-age=604800, private, immutable")
}
var data []byte
if wantSourceMap {
if sourceMap == nil {
http.Error(w, "extension has no source map", http.StatusNotFound)
return
}
data = sourceMap
} else {
data = bundle
}
w.Write(data)
if !wantSourceMap && sourceMap != nil {
// Append `//# sourceMappingURL=` directive to JS bundle if we have a source map. It is
// necessary to provide the absolute URL because the JS bundle is not loaded directly (e.g.,
// via importScripts); it is saved to a blob URL and then executed, which means any relative
// source map URL would be interpreted relative to the blob URL (so a relative URL wouldn't
// work). Also, we can't rely on the original sourceMappingURL directive (if provided at
// publish time) because it has no way of knowing the absolute URL to the source map.
//
// This implementation is not ideal because it means the JS bundle's contents depend on the
// app URL, which makes it technically not immutable. But given the blob URL constraint
// mentioned above, it's the best known solution.
if appURL, _ := url.Parse(conf.Get().AppURL); appURL != nil {
sourceMapURL := appURL.ResolveReference(&url.URL{Path: path.Join(path.Dir(r.URL.Path), releaseIDStr+".map")}).String()
fmt.Fprintf(w, "\n//# sourceMappingURL=%s", sourceMapURL)
}
}
}

View File

@ -0,0 +1,81 @@
package registry
import (
"context"
"strings"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
)
func init() {
registry.ListLocalRegistryExtensions = listLocalRegistryExtensions
registry.CountLocalRegistryExtensions = countLocalRegistryExtensions
}
func listLocalRegistryExtensions(ctx context.Context, args graphqlbackend.RegistryExtensionConnectionArgs) ([]graphqlbackend.RegistryExtension, error) {
if args.PrioritizeExtensionIDs != nil {
ids := filterStripLocalExtensionIDs(*args.PrioritizeExtensionIDs)
args.PrioritizeExtensionIDs = &ids
}
opt, err := toDBExtensionsListOptions(args)
if err != nil {
return nil, err
}
xs, err := dbExtensions{}.List(ctx, opt)
if err != nil {
return nil, err
}
if err := prefixLocalExtensionID(xs...); err != nil {
return nil, err
}
xs2 := make([]graphqlbackend.RegistryExtension, len(xs))
for i, x := range xs {
xs2[i] = &extensionDBResolver{v: x}
}
return xs2, nil
}
func countLocalRegistryExtensions(ctx context.Context, args graphqlbackend.RegistryExtensionConnectionArgs) (int, error) {
opt, err := toDBExtensionsListOptions(args)
if err != nil {
return 0, err
}
return dbExtensions{}.Count(ctx, opt)
}
func toDBExtensionsListOptions(args graphqlbackend.RegistryExtensionConnectionArgs) (dbExtensionsListOptions, error) {
var opt dbExtensionsListOptions
args.ConnectionArgs.Set(&opt.LimitOffset)
if args.Publisher != nil {
p, err := unmarshalRegistryPublisherID(*args.Publisher)
if err != nil {
return opt, err
}
opt.Publisher.UserID = p.userID
opt.Publisher.OrgID = p.orgID
}
if args.Query != nil {
opt.Query = *args.Query
}
if args.PrioritizeExtensionIDs != nil {
opt.PrioritizeExtensionIDs = *args.PrioritizeExtensionIDs
}
return opt, nil
}
// filterStripLocalExtensionIDs filters to local extension IDs and strips the
// host prefix.
func filterStripLocalExtensionIDs(extensionIDs []string) []string {
prefix := registry.GetLocalRegistryExtensionIDPrefix()
local := []string{}
for _, id := range extensionIDs {
parts := strings.SplitN(id, "/", 3)
if prefix != nil && len(parts) == 3 && parts[0] == *prefix {
local = append(local, parts[1]+"/"+parts[2])
} else if (prefix == nil || *prefix == "") && len(parts) == 2 {
local = append(local, id)
}
}
return local
}

View File

@ -0,0 +1,35 @@
package registry
import (
"net/url"
"reflect"
"testing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/globals"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/schema"
)
func TestFilteringExtensionIDs(t *testing.T) {
t.Run("filterStripLocalExtensionIDs on localhost", func(t *testing.T) {
conf.Mock(&schema.SiteConfiguration{AppURL: "http://localhost:3080"})
defer conf.Mock(nil)
input := []string{"localhost:3080/owner1/name1", "owner2/name2"}
want := []string{"owner1/name1"}
got := filterStripLocalExtensionIDs(input)
if !reflect.DeepEqual(got, want) {
t.Fatalf("got %+v, want %+v", got, want)
}
})
t.Run("filterStripLocalExtensionIDs on Sourcegraph.com", func(t *testing.T) {
oldAppURL := globals.AppURL
globals.AppURL = &url.URL{Scheme: "https", Host: "sourcegraph.com"}
defer func() { globals.AppURL = oldAppURL }()
input := []string{"localhost:3080/owner1/name1", "owner2/name2"}
want := []string{"owner2/name2"}
got := filterStripLocalExtensionIDs(input)
if !reflect.DeepEqual(got, want) {
t.Fatalf("got %+v, want %+v", got, want)
}
})
}

View File

@ -0,0 +1,72 @@
package registry
import (
"context"
"strings"
"time"
graphql "github.com/graph-gophers/graphql-go"
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
)
// extensionDBResolver implements the GraphQL type RegistryExtension.
type extensionDBResolver struct {
v *dbExtension
}
func (r *extensionDBResolver) ID() graphql.ID {
return registry.MarshalRegistryExtensionID(registry.RegistryExtensionID{LocalID: r.v.ID})
}
func (r *extensionDBResolver) UUID() string { return r.v.UUID }
func (r *extensionDBResolver) ExtensionID() string { return r.v.NonCanonicalExtensionID }
func (r *extensionDBResolver) ExtensionIDWithoutRegistry() string {
if r.v.NonCanonicalRegistry != "" {
return strings.TrimPrefix(r.v.NonCanonicalExtensionID, r.v.NonCanonicalRegistry+"/")
}
return r.v.NonCanonicalExtensionID
}
func (r *extensionDBResolver) Publisher(ctx context.Context) (graphqlbackend.RegistryPublisher, error) {
return getRegistryPublisher(ctx, r.v.Publisher)
}
func (r *extensionDBResolver) Name() string { return r.v.Name }
func (r *extensionDBResolver) Manifest(ctx context.Context) (graphqlbackend.ExtensionManifest, error) {
manifest, err := getExtensionManifestWithBundleURL(ctx, r.v.NonCanonicalExtensionID, r.v.ID, "release")
if err != nil {
return nil, err
}
return registry.NewExtensionManifest(manifest), nil
}
func (r *extensionDBResolver) CreatedAt() *string {
return strptr(r.v.CreatedAt.Format(time.RFC3339))
}
func (r *extensionDBResolver) UpdatedAt() *string {
return strptr(r.v.UpdatedAt.Format(time.RFC3339))
}
func (r *extensionDBResolver) URL() string {
return registry.ExtensionURL(r.v.NonCanonicalExtensionID)
}
func (r *extensionDBResolver) RemoteURL() *string { return nil }
func (r *extensionDBResolver) RegistryName() (string, error) {
return r.v.NonCanonicalRegistry, nil
}
func (r *extensionDBResolver) IsLocal() bool { return true }
func (r *extensionDBResolver) ViewerCanAdminister(ctx context.Context) (bool, error) {
err := toRegistryPublisherID(r.v).viewerCanAdminister(ctx)
if err == backend.ErrMustBeSiteAdmin || err == backend.ErrNotAnOrgMember || err == backend.ErrNotAuthenticated {
return false, nil
}
return err == nil, err
}
func strptr(s string) *string { return &s }

View File

@ -0,0 +1,71 @@
package registry
import (
"context"
"encoding/json"
"fmt"
"net/url"
"path"
"strconv"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/pkg/errcode"
"github.com/sourcegraph/sourcegraph/pkg/jsonc"
)
// validateExtensionManifest validates a JSON extension manifest for syntax.
//
// TODO(sqs): Also validate it against the JSON Schema.
func validateExtensionManifest(text string) error {
var o interface{}
return jsonc.Unmarshal(text, &o)
}
// getExtensionManifestWithBundleURL returns the extension manifest as JSON. If there are no
// releases, it returns a nil manifest. If the manifest has no "url" field itself, a "url" field
// pointing to the extension's bundle is inserted.
func getExtensionManifestWithBundleURL(ctx context.Context, extensionID string, registryExtensionID int32, releaseTag string) (*string, error) {
var manifest *string
release, err := dbReleases{}.GetLatest(ctx, registryExtensionID, releaseTag, false)
if err != nil && !errcode.IsNotFound(err) {
return nil, err
}
if release != nil {
// Add URL to bundle if necessary.
var o map[string]interface{}
if err := jsonc.Unmarshal(release.Manifest, &o); err != nil {
return nil, fmt.Errorf("parsing extension manifest for extension with ID %d (release tag %q): %s", registryExtensionID, releaseTag, err)
}
if o == nil {
o = map[string]interface{}{}
}
urlStr, _ := o["url"].(string)
if urlStr == "" {
// Insert "url" field with link to bundle file on this site.
bundleURL, err := makeExtensionBundleURL(release.ID, release.CreatedAt.UnixNano(), extensionID)
if err != nil {
return nil, err
}
o["url"] = bundleURL
b, err := json.MarshalIndent(o, "", " ")
if err != nil {
return nil, err
}
release.Manifest = string(b)
}
manifest = &release.Manifest
}
return manifest, nil
}
func makeExtensionBundleURL(registryExtensionReleaseID int64, timestamp int64, extensionIDHint string) (string, error) {
u, err := url.Parse(conf.Get().AppURL)
if err != nil {
return "", err
}
u.Path = path.Join(u.Path, fmt.Sprintf("/-/static/extension/%d.js", registryExtensionReleaseID))
u.RawQuery = extensionIDHint + "--" + strconv.FormatInt(timestamp, 36) // meaningless value, just for cache-busting
return u.String(), nil
}

View File

@ -0,0 +1,63 @@
package registry
import (
"context"
"encoding/json"
"reflect"
"testing"
)
func TestGetExtensionManifestWithBundleURL(t *testing.T) {
resetMocks()
ctx := context.Background()
nilOrEmpty := func(s *string) string {
if s == nil {
return ""
}
return *s
}
t.Run(`manifest with "url"`, func(t *testing.T) {
mocks.releases.GetLatest = func(registryExtensionID int32, releaseTag string, includeArtifacts bool) (*dbRelease, error) {
return &dbRelease{
Manifest: `{"name":"x","url":"u"}`,
}, nil
}
defer func() { mocks.releases.GetLatest = nil }()
manifest, err := getExtensionManifestWithBundleURL(ctx, "x", 1, "t")
if err != nil {
t.Fatal(err)
}
if want := `{"name":"x","url":"u"}`; manifest == nil || !jsonDeepEqual(*manifest, want) {
t.Errorf("got %q, want %q", nilOrEmpty(manifest), want)
}
})
t.Run(`manifest without "url"`, func(t *testing.T) {
mocks.releases.GetLatest = func(registryExtensionID int32, releaseTag string, includeArtifacts bool) (*dbRelease, error) {
return &dbRelease{
Manifest: `{"name":"x"}`,
}, nil
}
defer func() { mocks.releases.GetLatest = nil }()
manifest, err := getExtensionManifestWithBundleURL(ctx, "x", 1, "t")
if err != nil {
t.Fatal(err)
}
if want := `{"name":"x","url":"/-/static/extension/0.js?x---1fmlvpbbdw2yo"}`; manifest == nil || !jsonDeepEqual(*manifest, want) {
t.Errorf("got %q, want %q", nilOrEmpty(manifest), want)
}
})
}
func jsonDeepEqual(a, b string) bool {
var va, vb interface{}
if err := json.Unmarshal([]byte(a), &va); err != nil {
panic(err)
}
if err := json.Unmarshal([]byte(b), &vb); err != nil {
panic(err)
}
return reflect.DeepEqual(va, vb)
}

View File

@ -0,0 +1,37 @@
package registry
import (
"context"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
"github.com/sourcegraph/sourcegraph/pkg/conf"
)
func init() {
conf.DefaultRemoteRegistry = "https://sourcegraph.com/.api/registry"
registry.GetLocalExtensionByExtensionID = func(ctx context.Context, extensionIDWithoutPrefix string) (graphqlbackend.RegistryExtension, error) {
x, err := dbExtensions{}.GetByExtensionID(ctx, extensionIDWithoutPrefix)
if err != nil {
return nil, err
}
if err := prefixLocalExtensionID(x); err != nil {
return nil, err
}
return &extensionDBResolver{v: x}, nil
}
}
// prefixLocalExtensionID adds the local registry's extension ID prefix (from
// GetLocalRegistryExtensionIDPrefix) to all extensions' extension IDs in the list.
func prefixLocalExtensionID(xs ...*dbExtension) error {
prefix := registry.GetLocalRegistryExtensionIDPrefix()
if prefix == nil {
return nil
}
for _, x := range xs {
x.NonCanonicalExtensionID = *prefix + "/" + x.NonCanonicalExtensionID
x.NonCanonicalRegistry = *prefix
}
return nil
}

View File

@ -0,0 +1,323 @@
package registry
import (
"context"
"database/sql"
"fmt"
"strings"
"time"
"github.com/google/uuid"
"github.com/keegancsmith/sqlf"
"github.com/pkg/errors"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db/dbconn"
)
// dbExtension describes an extension in the extension registry.
//
// It is the internal form of github.com/sourcegraph/sourcegraph/pkg/registry.Extension (which is
// the external API type). These types should generally be kept in sync, but registry.Extension
// updates require backcompat.
type dbExtension struct {
ID int32
UUID string
Publisher dbPublisher
Name string
CreatedAt time.Time
UpdatedAt time.Time
// NonCanonicalExtensionID is the denormalized fully qualified extension ID
// ("[registry/]publisher/name" format), using the username/name of the extension's publisher
// (joined from another table) as of when the query executed. Do not persist this, because the
// (denormalized) registry and publisher names can change.
//
// If this value is obtained directly from a method on RegistryExtensions, this field will never
// contain the registry name prefix (which is necessary to distinguish local extensions from
// remote extensions). Call prefixLocalExtensionID to add it. The recommended way to apply this
// automatically (when needed) is to use registry.GetExtensionByExtensionID instead of
// (dbExtensions).GetByExtensionID.
NonCanonicalExtensionID string
// NonCanonicalRegistry is the denormalized registry name (as of when this field was set). This
// field is only set by prefixLocalExtensionID and is always empty if this value is obtained
// directly from a method on RegistryExtensions. Do not persist this value, because the
// (denormalized) registry name can change.
NonCanonicalRegistry string
}
type dbExtensions struct{}
// extensionNotFoundError occurs when an extension is not found in the extension registry.
type extensionNotFoundError struct {
args []interface{}
}
// NotFound implements errcode.NotFounder.
func (err extensionNotFoundError) NotFound() bool { return true }
func (err extensionNotFoundError) Error() string {
return fmt.Sprintf("registry extension not found: %v", err.args)
}
// Create creates a new extension in the extension registry. Exactly 1 of publisherUserID and publisherOrgID must be nonzero.
func (s dbExtensions) Create(ctx context.Context, publisherUserID, publisherOrgID int32, name string) (id int32, err error) {
if mocks.extensions.Create != nil {
return mocks.extensions.Create(publisherUserID, publisherOrgID, name)
}
if publisherUserID != 0 && publisherOrgID != 0 {
return 0, errors.New("at most 1 of the publisher user/org may be set")
}
uuid, err := uuid.NewRandom()
if err != nil {
return 0, err
}
if err := dbconn.Global.QueryRowContext(ctx,
// Include users/orgs table query (with "FOR UPDATE") to ensure that the publisher user/org
// not been deleted. If it was deleted, the query will return an error.
`
INSERT INTO registry_extensions(uuid, publisher_user_id, publisher_org_id, name)
VALUES(
$1,
(SELECT id FROM users WHERE id=$2 AND deleted_at IS NULL FOR UPDATE),
(SELECT id FROM orgs WHERE id=$3 AND deleted_at IS NULL FOR UPDATE),
$4
)
RETURNING id
`,
uuid, publisherUserID, publisherOrgID, name,
).Scan(&id); err != nil {
return 0, err
}
return id, nil
}
// GetByID retrieves the registry extension (if any) given its ID.
//
// 🚨 SECURITY: The caller must ensure that the actor is permitted to view this registry extension.
func (s dbExtensions) GetByID(ctx context.Context, id int32) (*dbExtension, error) {
if mocks.extensions.GetByID != nil {
return mocks.extensions.GetByID(id)
}
results, err := s.list(ctx, []*sqlf.Query{sqlf.Sprintf("x.id=%d", id)}, nil, nil)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, extensionNotFoundError{[]interface{}{id}}
}
return results[0], nil
}
// GetByUUID retrieves the registry extension (if any) given its UUID.
//
// 🚨 SECURITY: The caller must ensure that the actor is permitted to view this registry extension.
func (s dbExtensions) GetByUUID(ctx context.Context, uuid string) (*dbExtension, error) {
if mocks.extensions.GetByUUID != nil {
return mocks.extensions.GetByUUID(uuid)
}
results, err := s.list(ctx, []*sqlf.Query{sqlf.Sprintf("x.uuid=%d", uuid)}, nil, nil)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, extensionNotFoundError{[]interface{}{uuid}}
}
return results[0], nil
}
const (
// extensionPublisherNameExpr is the SQL expression for the extension's publisher's name (using
// the table aliases created by (dbExtensions).listCountSQL.
extensionPublisherNameExpr = "COALESCE(users.username, orgs.name)"
// extensionIDExpr is the SQL expression for the extension ID (using the table aliases created by
// (dbExtensions).listCountSQL.
extensionIDExpr = "CONCAT(" + extensionPublisherNameExpr + ", '/', x.name)"
)
// GetByExtensionID retrieves the registry extension (if any) given its extension ID, which is the
// concatenation of the publisher name, a slash ("/"), and the extension name.
//
// 🚨 SECURITY: The caller must ensure that the actor is permitted to view this registry extension.
func (s dbExtensions) GetByExtensionID(ctx context.Context, extensionID string) (*dbExtension, error) {
if mocks.extensions.GetByExtensionID != nil {
return mocks.extensions.GetByExtensionID(extensionID)
}
// TODO(sqs): prevent the creation of an org with the same name as a user so that there is no
// ambiguity as to whether the publisher refers to a user or org by the given name
// (https://github.com/sourcegraph/sourcegraph/issues/12068).
parts := strings.SplitN(extensionID, "/", 2)
if len(parts) < 2 {
return nil, extensionNotFoundError{[]interface{}{fmt.Sprintf("extensionID %q", extensionID)}}
}
publisherName := parts[0]
extensionName := parts[1]
results, err := s.list(ctx, []*sqlf.Query{
sqlf.Sprintf("x.name=%s", extensionName),
sqlf.Sprintf("(users.username=%s OR orgs.name=%s)", publisherName, publisherName),
}, nil, nil)
if err != nil {
return nil, err
}
if len(results) == 0 {
return nil, extensionNotFoundError{[]interface{}{fmt.Sprintf("extensionID %q", extensionID)}}
}
return results[0], nil
}
// dbExtensionsListOptions contains options for listing registry extensions.
type dbExtensionsListOptions struct {
Publisher dbPublisher
Query string // matches the extension ID
PrioritizeExtensionIDs []string
*db.LimitOffset
}
func (o dbExtensionsListOptions) sqlConditions() []*sqlf.Query {
var conds []*sqlf.Query
if o.Publisher.UserID != 0 {
conds = append(conds, sqlf.Sprintf("x.publisher_user_id=%d", o.Publisher.UserID))
}
if o.Publisher.OrgID != 0 {
conds = append(conds, sqlf.Sprintf("x.publisher_org_id=%d", o.Publisher.OrgID))
}
if o.Query != "" {
conds = append(conds, sqlf.Sprintf(extensionIDExpr+" ILIKE %s", "%"+strings.Replace(strings.ToLower(o.Query), " ", "%", -1)+"%"))
}
if len(conds) == 0 {
conds = append(conds, sqlf.Sprintf("TRUE"))
}
return conds
}
func (o dbExtensionsListOptions) sqlOrder() []*sqlf.Query {
ids := make([]*sqlf.Query, len(o.PrioritizeExtensionIDs)+1)
for i, id := range o.PrioritizeExtensionIDs {
ids[i] = sqlf.Sprintf("%v", string(id))
}
ids[len(o.PrioritizeExtensionIDs)] = sqlf.Sprintf("NULL")
return []*sqlf.Query{sqlf.Sprintf(extensionIDExpr+` IN (%v) ASC`, sqlf.Join(ids, ","))}
}
// List lists all registry extensions that satisfy the options.
//
// 🚨 SECURITY: The caller must ensure that the actor is permitted to list with the specified
// options.
func (s dbExtensions) List(ctx context.Context, opt dbExtensionsListOptions) ([]*dbExtension, error) {
return s.list(ctx, opt.sqlConditions(), opt.sqlOrder(), opt.LimitOffset)
}
func (dbExtensions) listCountSQL(conds []*sqlf.Query) *sqlf.Query {
return sqlf.Sprintf(`
FROM registry_extensions x
LEFT JOIN users ON users.id=publisher_user_id AND users.deleted_at IS NULL
LEFT JOIN orgs ON orgs.id=publisher_org_id AND orgs.deleted_at IS NULL
WHERE (%s) AND x.deleted_at IS NULL`,
sqlf.Join(conds, ") AND ("))
}
func (s dbExtensions) list(ctx context.Context, conds, order []*sqlf.Query, limitOffset *db.LimitOffset) ([]*dbExtension, error) {
order = append(order, sqlf.Sprintf("TRUE"))
q := sqlf.Sprintf(`
SELECT x.id, x.uuid, x.publisher_user_id, x.publisher_org_id, x.name, x.created_at, x.updated_at,
`+extensionIDExpr+` AS non_canonical_extension_id, `+extensionPublisherNameExpr+` AS non_canonical_publisher_name
%s
ORDER BY %s, x.id ASC
%s`,
s.listCountSQL(conds),
sqlf.Join(order, ","),
limitOffset.SQL(),
)
rows, err := dbconn.Global.QueryContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...)
if err != nil {
return nil, err
}
defer rows.Close()
var results []*dbExtension
for rows.Next() {
var t dbExtension
var publisherUserID, publisherOrgID sql.NullInt64
if err := rows.Scan(&t.ID, &t.UUID, &publisherUserID, &publisherOrgID, &t.Name, &t.CreatedAt, &t.UpdatedAt, &t.NonCanonicalExtensionID, &t.Publisher.NonCanonicalName); err != nil {
return nil, err
}
t.Publisher.UserID = int32(publisherUserID.Int64)
t.Publisher.OrgID = int32(publisherOrgID.Int64)
results = append(results, &t)
}
return results, nil
}
// Count counts all registry extensions that satisfy the options (ignoring limit and offset).
//
// 🚨 SECURITY: The caller must ensure that the actor is permitted to count the results.
func (s dbExtensions) Count(ctx context.Context, opt dbExtensionsListOptions) (int, error) {
q := sqlf.Sprintf("SELECT COUNT(*) %s", s.listCountSQL(opt.sqlConditions()))
var count int
if err := dbconn.Global.QueryRowContext(ctx, q.Query(sqlf.PostgresBindVar), q.Args()...).Scan(&count); err != nil {
return 0, err
}
return count, nil
}
// Update updates information about the registry extension.
func (dbExtensions) Update(ctx context.Context, id int32, name *string) error {
if mocks.extensions.Update != nil {
return mocks.extensions.Update(id, name)
}
res, err := dbconn.Global.ExecContext(ctx,
"UPDATE registry_extensions SET name=COALESCE($2, name), updated_at=now() WHERE id=$1 AND deleted_at IS NULL",
id, name,
)
if err != nil {
return err
}
nrows, err := res.RowsAffected()
if err != nil {
return err
}
if nrows == 0 {
return extensionNotFoundError{[]interface{}{id}}
}
return nil
}
// Delete marks an registry extension as deleted.
func (dbExtensions) Delete(ctx context.Context, id int32) error {
if mocks.extensions.Delete != nil {
return mocks.extensions.Delete(id)
}
res, err := dbconn.Global.ExecContext(ctx, "UPDATE registry_extensions SET deleted_at=now() WHERE id=$1 AND deleted_at IS NULL", id)
if err != nil {
return err
}
nrows, err := res.RowsAffected()
if err != nil {
return err
}
if nrows == 0 {
return extensionNotFoundError{[]interface{}{id}}
}
return nil
}
// mockExtensions mocks the registry extensions store.
type mockExtensions struct {
Create func(publisherUserID, publisherOrgID int32, name string) (int32, error)
GetByID func(id int32) (*dbExtension, error)
GetByUUID func(uuid string) (*dbExtension, error)
GetByExtensionID func(extensionID string) (*dbExtension, error)
Update func(id int32, name *string) error
Delete func(id int32) error
}

View File

@ -0,0 +1,273 @@
package registry
import (
"encoding/json"
"reflect"
"testing"
"time"
"github.com/lib/pq"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
dbtesting "github.com/sourcegraph/sourcegraph/cmd/frontend/db/testing"
"github.com/sourcegraph/sourcegraph/pkg/errcode"
)
// registryExtensionNamesForTests is a list of test cases containing valid and invalid registry
// extension names.
var registryExtensionNamesForTests = []struct {
name string
wantValid bool
}{
{"", false},
{"a", true},
{"-a", false},
{"a-", false},
{"a-b", true},
{"a--b", false},
{"a---b", false},
{"a.b", true},
{"a..b", false},
{"a...b", false},
{"a_b", true},
{"a__b", false},
{"a___b", false},
{"a-.b", false},
{"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx", false},
}
func TestRegistryExtensions_validUsernames(t *testing.T) {
if testing.Short() {
t.Skip()
}
ctx := dbtesting.TestContext(t)
user, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
if err != nil {
t.Fatal(err)
}
for _, test := range registryExtensionNamesForTests {
t.Run(test.name, func(t *testing.T) {
valid := true
if _, err := (dbExtensions{}).Create(ctx, user.ID, 0, test.name); err != nil {
if e, ok := err.(*pq.Error); ok && (e.Constraint == "registry_extensions_name_valid_chars" || e.Constraint == "registry_extensions_name_length") {
valid = false
} else {
t.Fatal(err)
}
}
if valid != test.wantValid {
t.Errorf("%q: got valid %v, want %v", test.name, valid, test.wantValid)
}
})
}
}
func TestRegistryExtensions(t *testing.T) {
if testing.Short() {
t.Skip()
}
ctx := dbtesting.TestContext(t)
testGetByID := func(t *testing.T, id int32, want *dbExtension, wantPublisherName string) {
t.Helper()
x, err := dbExtensions{}.GetByID(ctx, id)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(x, want) {
t.Errorf("got %+v, want %+v", x, want)
}
if x.Publisher.NonCanonicalName != wantPublisherName {
t.Errorf("got publisher name %q, want %q", x.Publisher.NonCanonicalName, wantPublisherName)
}
}
testGetByExtensionID := func(t *testing.T, extensionID string, want *dbExtension) {
t.Helper()
x, err := dbExtensions{}.GetByExtensionID(ctx, extensionID)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(x, want) {
t.Errorf("got %+v, want %+v", x, want)
}
if x.NonCanonicalExtensionID != extensionID {
t.Errorf("got extension ID %q, want %q", x.NonCanonicalExtensionID, extensionID)
}
}
testList := func(t *testing.T, opt dbExtensionsListOptions, want []*dbExtension) {
t.Helper()
if ois, err := (dbExtensions{}).List(ctx, opt); err != nil {
t.Fatal(err)
} else if !reflect.DeepEqual(ois, want) {
t.Errorf("got %s, want %s", asJSON(t, ois), asJSON(t, want))
}
}
testListCount := func(t *testing.T, opt dbExtensionsListOptions, want []*dbExtension) {
t.Helper()
testList(t, opt, want)
if n, err := (dbExtensions{}).Count(ctx, opt); err != nil {
t.Fatal(err)
} else if want := len(want); n != want {
t.Errorf("got %d, want %d", n, want)
}
}
user, err := db.Users.Create(ctx, db.NewUser{Username: "u"})
if err != nil {
t.Fatal(err)
}
org, err := db.Orgs.Create(ctx, "o", nil)
if err != nil {
t.Fatal(err)
}
createAndGet := func(t *testing.T, publisherUserID, publisherOrgID int32, name string) *dbExtension {
t.Helper()
xID, err := dbExtensions{}.Create(ctx, publisherUserID, publisherOrgID, name)
if err != nil {
t.Fatal(err)
}
x, err := dbExtensions{}.GetByID(ctx, xID)
if err != nil {
t.Fatal(err)
}
return x
}
xu := createAndGet(t, user.ID, 0, "xu")
xo := createAndGet(t, 0, org.ID, "xo")
t.Run("List/Count/Get publishers", func(t *testing.T) {
publishers, err := dbExtensions{}.ListPublishers(ctx, dbPublishersListOptions{})
if err != nil {
t.Fatal(err)
}
if want := []*dbPublisher{
&xo.Publisher,
&xu.Publisher,
}; !reflect.DeepEqual(publishers, want) {
t.Errorf("got publishers %+v, want %+v", publishers, want)
}
if n, err := (dbExtensions{}).CountPublishers(ctx, dbPublishersListOptions{}); err != nil {
t.Fatal(err)
} else if want := 2; n != 2 {
t.Errorf("got count %d, want %d", n, want)
}
for _, p := range []*dbPublisher{&xo.Publisher, &xu.Publisher} {
got, err := dbExtensions{}.GetPublisher(ctx, p.NonCanonicalName)
if err != nil {
t.Fatal(err)
}
if !reflect.DeepEqual(got, p) {
t.Errorf("got %+v, want %+v", got, p)
}
}
if _, err := (dbExtensions{}).GetPublisher(ctx, "doesntexist"); !errcode.IsNotFound(err) {
t.Errorf("got err %v, want errcode.IsNotFound", err)
}
})
publishers := map[string]struct {
publisherUserID, publisherOrgID int32
publisherName string
}{
"user": {publisherUserID: user.ID, publisherName: "u"},
"org": {publisherOrgID: org.ID, publisherName: "o"},
}
for name, c := range publishers {
t.Run(name+" publisher", func(t *testing.T) {
x := createAndGet(t, c.publisherUserID, c.publisherOrgID, "x")
t.Run("GetByID", func(t *testing.T) {
testGetByID(t, x.ID, x, c.publisherName)
if _, err := (dbExtensions{}).GetByID(ctx, 12345 /* doesn't exist */); !errcode.IsNotFound(err) {
t.Errorf("got err %v, want errcode.IsNotFound", err)
}
})
t.Run("GetByExtensionID", func(t *testing.T) {
testGetByExtensionID(t, c.publisherName+"/"+x.Name, x)
if _, err := (dbExtensions{}).GetByExtensionID(ctx, "foo.bar"); !errcode.IsNotFound(err) {
t.Errorf("got err %v, want errcode.IsNotFound", err)
}
})
t.Run("List/Count all", func(t *testing.T) {
testListCount(t, dbExtensionsListOptions{}, []*dbExtension{xu, xo, x})
})
wantByPublisherUser := []*dbExtension{xu}
wantByPublisherOrg := []*dbExtension{xo}
var wantByCurrent []*dbExtension
if c.publisherUserID != 0 {
wantByPublisherUser = append(wantByPublisherUser, x)
wantByCurrent = wantByPublisherUser
} else {
wantByPublisherOrg = append(wantByPublisherOrg, x)
wantByCurrent = wantByPublisherOrg
}
t.Run("List/Count by PublisherUserID", func(t *testing.T) {
testListCount(t, dbExtensionsListOptions{Publisher: dbPublisher{UserID: user.ID}}, wantByPublisherUser)
})
t.Run("List/Count by Publisher.OrgID", func(t *testing.T) {
testListCount(t, dbExtensionsListOptions{Publisher: dbPublisher{OrgID: org.ID}}, wantByPublisherOrg)
})
t.Run("List/Count by Publisher.Query all", func(t *testing.T) {
testListCount(t, dbExtensionsListOptions{Query: "x"}, []*dbExtension{xu, xo, x})
})
t.Run("List/Count by Publisher.Query one", func(t *testing.T) {
testListCount(t, dbExtensionsListOptions{Query: c.publisherName + "/" + x.Name}, wantByCurrent)
})
t.Run("List/Count with prioritizeExtensionIDs", func(t *testing.T) {
testList(t, dbExtensionsListOptions{PrioritizeExtensionIDs: []string{xu.NonCanonicalExtensionID}, LimitOffset: &db.LimitOffset{Limit: 1}}, []*dbExtension{xu})
testList(t, dbExtensionsListOptions{PrioritizeExtensionIDs: []string{xo.NonCanonicalExtensionID}, LimitOffset: &db.LimitOffset{Limit: 1}}, []*dbExtension{xo})
})
if err := (dbExtensions{}).Delete(ctx, x.ID); err != nil {
t.Fatal(err)
}
if err := (dbExtensions{}).Delete(ctx, x.ID); !errcode.IsNotFound(err) {
t.Errorf("2nd Delete: got err %v, want errcode.IsNotFound", err)
}
if _, err := (dbExtensions{}).GetByID(ctx, x.ID); !errcode.IsNotFound(err) {
t.Errorf("GetByID after Delete: got err %v, want errcode.IsNotFound", err)
}
})
}
t.Run("Update", func(t *testing.T) {
x := xu
if err := (dbExtensions{}).Update(ctx, x.ID, nil); err != nil {
t.Fatal(err)
}
x1, err := dbExtensions{}.GetByID(ctx, x.ID)
if err != nil {
t.Fatal(err)
}
if time.Since(x1.UpdatedAt) > 1*time.Minute {
t.Errorf("got UpdatedAt %v, want recent", x1.UpdatedAt)
}
if x1.Name != x.Name {
t.Errorf("got name %q, want %q", x1.Name, x.Name)
}
})
t.Run("Create with same publisher and name", func(t *testing.T) {
_, err := dbExtensions{}.Create(ctx, user.ID, 0, "zzz")
if err != nil {
t.Fatal(err)
}
if _, err := (dbExtensions{}).Create(ctx, user.ID, 0, "zzz"); err == nil {
t.Fatal("err == nil")
}
})
}
func asJSON(t *testing.T, v interface{}) string {
b, err := json.MarshalIndent(v, "", " ")
if err != nil {
t.Fatal(err)
}
return string(b)
}

View File

@ -0,0 +1,239 @@
package registry
import (
"context"
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"github.com/prometheus/client_golang/prometheus"
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
"github.com/sourcegraph/sourcegraph/pkg/conf"
"github.com/sourcegraph/sourcegraph/pkg/errcode"
"github.com/sourcegraph/sourcegraph/pkg/honey"
"github.com/sourcegraph/sourcegraph/pkg/registry"
)
func init() {
frontendregistry.HandleRegistry = handleRegistry
}
// Funcs called by serveRegistry to get registry data. If fakeRegistryData is set, it is used as
// the data source instead of the database.
var (
registryList = func(ctx context.Context, opt dbExtensionsListOptions) ([]*registry.Extension, error) {
vs, err := dbExtensions{}.List(ctx, opt)
if err != nil {
return nil, err
}
xs := make([]*registry.Extension, len(vs))
for i, v := range vs {
x, err := toRegistryAPIExtension(ctx, v)
if err != nil {
return nil, err
}
xs[i] = x
}
return xs, nil
}
registryGetByUUID = func(ctx context.Context, uuid string) (*registry.Extension, error) {
x, err := dbExtensions{}.GetByUUID(ctx, uuid)
if err != nil {
return nil, err
}
return toRegistryAPIExtension(ctx, x)
}
registryGetByExtensionID = func(ctx context.Context, extensionID string) (*registry.Extension, error) {
x, err := dbExtensions{}.GetByExtensionID(ctx, extensionID)
if err != nil {
return nil, err
}
return toRegistryAPIExtension(ctx, x)
}
)
func toRegistryAPIExtension(ctx context.Context, v *dbExtension) (*registry.Extension, error) {
manifest, err := getExtensionManifestWithBundleURL(ctx, v.NonCanonicalExtensionID, v.ID, "release")
if err != nil {
return nil, err
}
baseURL := strings.TrimSuffix(conf.Get().AppURL, "/")
return &registry.Extension{
UUID: v.UUID,
ExtensionID: v.NonCanonicalExtensionID,
Publisher: registry.Publisher{
Name: v.Publisher.NonCanonicalName,
URL: baseURL + frontendregistry.PublisherExtensionsURL(v.Publisher.UserID != 0, v.Publisher.OrgID != 0, v.Publisher.NonCanonicalName),
},
Name: v.Name,
Manifest: manifest,
CreatedAt: v.CreatedAt,
UpdatedAt: v.UpdatedAt,
URL: baseURL + frontendregistry.ExtensionURL(v.NonCanonicalExtensionID),
}, nil
}
// handleRegistry serves the external HTTP API for the extension registry.
func handleRegistry(w http.ResponseWriter, r *http.Request) (err error) {
if conf.Extensions() == nil {
w.WriteHeader(http.StatusNotFound)
return
}
builder := honey.Builder("registry")
builder.AddField("api_version", r.Header.Get("Accept"))
builder.AddField("url", r.URL.String())
ev := builder.NewEvent()
defer func() {
ev.AddField("success", err == nil)
if err == nil {
registryRequestsSuccessCounter.Inc()
} else {
registryRequestsErrorCounter.Inc()
ev.AddField("error", err.Error())
}
ev.Send()
}()
// Identify this response as coming from the registry API.
w.Header().Set(registry.MediaTypeHeaderName, registry.MediaType)
w.Header().Set("Vary", registry.MediaTypeHeaderName)
// Validate API version.
if v := r.Header.Get("Accept"); v != registry.AcceptHeader {
http.Error(w, fmt.Sprintf("invalid Accept header: expected %q", registry.AcceptHeader), http.StatusBadRequest)
return nil
}
// This handler can be mounted at either /.internal or /.api.
urlPath := r.URL.Path
switch {
case strings.HasPrefix(urlPath, "/.internal"):
urlPath = strings.TrimPrefix(urlPath, "/.internal")
case strings.HasPrefix(urlPath, "/.api"):
urlPath = strings.TrimPrefix(urlPath, "/.api")
}
const extensionsPath = "/registry/extensions"
var result interface{}
switch {
case urlPath == extensionsPath:
query := r.URL.Query().Get("q")
ev.AddField("query", query)
xs, err := registryList(r.Context(), dbExtensionsListOptions{Query: query})
if err != nil {
return err
}
ev.AddField("results_count", len(xs))
result = xs
case strings.HasPrefix(urlPath, extensionsPath+"/"):
var (
spec = strings.TrimPrefix(urlPath, extensionsPath+"/")
x *registry.Extension
err error
)
switch {
case strings.HasPrefix(spec, "uuid/"):
x, err = registryGetByUUID(r.Context(), strings.TrimPrefix(spec, "uuid/"))
case strings.HasPrefix(spec, "extension-id/"):
x, err = registryGetByExtensionID(r.Context(), strings.TrimPrefix(spec, "extension-id/"))
default:
w.WriteHeader(http.StatusNotFound)
return nil
}
if x == nil || err != nil {
if x == nil || errcode.IsNotFound(err) {
w.Header().Set("Cache-Control", "max-age=5, private")
http.Error(w, "extension not found", http.StatusNotFound)
return nil
}
return err
}
ev.AddField("extension-id", x.ExtensionID)
result = x
default:
w.WriteHeader(http.StatusNotFound)
return nil
}
w.Header().Set("Cache-Control", "max-age=30, private")
data, err := json.Marshal(result)
if err != nil {
return err
}
w.Write(data)
return nil
}
var (
registryRequestsSuccessCounter = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "src",
Subsystem: "registry",
Name: "requests_success",
Help: "Number of successful requests (HTTP 200) to the HTTP registry API",
})
registryRequestsErrorCounter = prometheus.NewCounter(prometheus.CounterOpts{
Namespace: "src",
Subsystem: "registry",
Name: "requests_error",
Help: "Number of failed (non-HTTP 200) requests to the HTTP registry API",
})
)
func init() {
prometheus.MustRegister(registryRequestsSuccessCounter)
prometheus.MustRegister(registryRequestsErrorCounter)
}
func init() {
// Allow providing fake registry data for local dev (intended for use in local dev only).
//
// If FAKE_REGISTRY is set and refers to a valid JSON file (of []*registry.Extension), is used
// by serveRegistry (instead of the DB) as the source for registry data.
path := os.Getenv("FAKE_REGISTRY")
if path == "" {
return
}
readFakeExtensions := func() ([]*registry.Extension, error) {
data, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
var xs []*registry.Extension
if err := json.Unmarshal(data, &xs); err != nil {
return nil, err
}
return xs, nil
}
registryList = func(ctx context.Context, opt dbExtensionsListOptions) ([]*registry.Extension, error) {
xs, err := readFakeExtensions()
if err != nil {
return nil, err
}
return frontendregistry.FilterRegistryExtensions(xs, opt.Query), nil
}
registryGetByUUID = func(ctx context.Context, uuid string) (*registry.Extension, error) {
xs, err := readFakeExtensions()
if err != nil {
return nil, err
}
return frontendregistry.FindRegistryExtension(xs, "uuid", uuid), nil
}
registryGetByExtensionID = func(ctx context.Context, extensionID string) (*registry.Extension, error) {
xs, err := readFakeExtensions()
if err != nil {
return nil, err
}
return frontendregistry.FindRegistryExtension(xs, "extensionID", extensionID), nil
}
}

View File

@ -0,0 +1,12 @@
package registry
func resetMocks() {
mocks = dbMocks{}
}
type dbMocks struct {
extensions mockExtensions
releases mockReleases
}
var mocks dbMocks

View File

@ -0,0 +1,74 @@
package registry
import (
"context"
"sync"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend/graphqlutil"
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
)
func init() {
frontendregistry.ExtensionRegistry.PublishersFunc = extensionRegistryPublishers
}
func extensionRegistryPublishers(ctx context.Context, args *graphqlutil.ConnectionArgs) (graphqlbackend.RegistryPublisherConnection, error) {
var opt dbPublishersListOptions
args.Set(&opt.LimitOffset)
return &registryPublisherConnection{opt: opt}, nil
}
// registryPublisherConnection resolves a list of registry publishers.
type registryPublisherConnection struct {
opt dbPublishersListOptions
// cache results because they are used by multiple fields
once sync.Once
registryPublishers []*dbPublisher
err error
}
func (r *registryPublisherConnection) compute(ctx context.Context) ([]*dbPublisher, error) {
r.once.Do(func() {
opt2 := r.opt
if opt2.LimitOffset != nil {
tmp := *opt2.LimitOffset
opt2.LimitOffset = &tmp
opt2.Limit++ // so we can detect if there is a next page
}
r.registryPublishers, r.err = dbExtensions{}.ListPublishers(ctx, opt2)
})
return r.registryPublishers, r.err
}
func (r *registryPublisherConnection) Nodes(ctx context.Context) ([]graphqlbackend.RegistryPublisher, error) {
publishers, err := r.compute(ctx)
if err != nil {
return nil, err
}
var l []graphqlbackend.RegistryPublisher
for _, publisher := range publishers {
p, err := getRegistryPublisher(ctx, *publisher)
if err != nil {
return nil, err
}
l = append(l, p)
}
return l, nil
}
func (r *registryPublisherConnection) TotalCount(ctx context.Context) (int32, error) {
count, err := dbExtensions{}.CountPublishers(ctx, r.opt)
return int32(count), err
}
func (r *registryPublisherConnection) PageInfo(ctx context.Context) (*graphqlutil.PageInfo, error) {
publishers, err := r.compute(ctx)
if err != nil {
return nil, err
}
return graphqlutil.HasNextPage(r.opt.LimitOffset != nil && len(publishers) > r.opt.Limit), nil
}

View File

@ -0,0 +1,148 @@
package registry
import (
"context"
"errors"
"fmt"
graphql "github.com/graph-gophers/graphql-go"
"github.com/graph-gophers/graphql-go/relay"
"github.com/sourcegraph/enterprise/cmd/frontend/internal/licensing"
"github.com/sourcegraph/sourcegraph/cmd/frontend/backend"
"github.com/sourcegraph/sourcegraph/cmd/frontend/db"
"github.com/sourcegraph/sourcegraph/cmd/frontend/graphqlbackend"
frontendregistry "github.com/sourcegraph/sourcegraph/cmd/frontend/registry"
)
func init() {
frontendregistry.ExtensionRegistry.ViewerPublishersFunc = extensionRegistryViewerPublishers
}
func extensionRegistryViewerPublishers(ctx context.Context) ([]graphqlbackend.RegistryPublisher, error) {
// The feature check here makes it so the any "New extension" form will show an error, so the
// user finds out before trying to submit the form that the feature is disabled.
if err := licensing.CheckFeature(licensing.FeatureExtensionRegistry); err != nil {
return nil, err
}
var publishers []graphqlbackend.RegistryPublisher
user, err := graphqlbackend.CurrentUser(ctx)
if err != nil || user == nil {
return nil, err
}
publishers = append(publishers, &registryPublisher{user: user})
orgs, err := db.Orgs.GetByUserID(ctx, user.SourcegraphID())
if err != nil {
return nil, err
}
for _, org := range orgs {
publishers = append(publishers, &registryPublisher{org: graphqlbackend.NewOrg(org)})
}
return publishers, nil
}
// registryPublisher implements the GraphQL type RegistryPublisher.
type registryPublisher struct {
user *graphqlbackend.UserResolver
org *graphqlbackend.OrgResolver
}
var _ graphqlbackend.RegistryPublisher = &registryPublisher{}
func (r *registryPublisher) ToUser() (*graphqlbackend.UserResolver, bool) {
return r.user, r.user != nil
}
func (r *registryPublisher) ToOrg() (*graphqlbackend.OrgResolver, bool) { return r.org, r.org != nil }
func (r *registryPublisher) toDBRegistryPublisher() dbPublisher {
switch {
case r.user != nil:
return dbPublisher{UserID: r.user.SourcegraphID(), NonCanonicalName: r.user.Username()}
case r.org != nil:
return dbPublisher{OrgID: r.org.OrgID(), NonCanonicalName: r.org.Name()}
default:
return dbPublisher{}
}
}
func (r *registryPublisher) RegistryExtensionConnectionURL() (*string, error) {
p := r.toDBRegistryPublisher()
url := frontendregistry.PublisherExtensionsURL(p.UserID != 0, p.OrgID != 0, p.NonCanonicalName)
if url == "" {
return nil, errRegistryUnknownPublisher
}
return &url, nil
}
var errRegistryUnknownPublisher = errors.New("unknown registry extension publisher")
func getRegistryPublisher(ctx context.Context, publisher dbPublisher) (*registryPublisher, error) {
switch {
case publisher.UserID != 0:
user, err := graphqlbackend.UserByIDInt32(ctx, publisher.UserID)
if err != nil {
return nil, err
}
return &registryPublisher{user: user}, nil
case publisher.OrgID != 0:
org, err := graphqlbackend.OrgByIDInt32(ctx, publisher.OrgID)
if err != nil {
return nil, err
}
return &registryPublisher{org: org}, nil
default:
return nil, errRegistryUnknownPublisher
}
}
type registryPublisherID struct {
userID, orgID int32
}
func toRegistryPublisherID(extension *dbExtension) *registryPublisherID {
return &registryPublisherID{
userID: extension.Publisher.UserID,
orgID: extension.Publisher.OrgID,
}
}
// unmarshalRegistryPublisherID unmarshals the GraphQL ID into the possible publisher ID
// types.
//
// 🚨 SECURITY
func unmarshalRegistryPublisherID(id graphql.ID) (*registryPublisherID, error) {
var (
p registryPublisherID
err error
)
switch kind := relay.UnmarshalKind(id); kind {
case "User":
p.userID, err = graphqlbackend.UnmarshalUserID(id)
case "Org":
p.orgID, err = graphqlbackend.UnmarshalOrgID(id)
default:
return nil, fmt.Errorf("unknown registry extension publisher type: %q", kind)
}
if err != nil {
return nil, err
}
return &p, nil
}
// viewerCanAdminister returns whether the current user is allowed to perform mutations on a
// registry extension with the given publisher.
//
// 🚨 SECURITY
func (p *registryPublisherID) viewerCanAdminister(ctx context.Context) error {
switch {
case p.userID != 0:
// 🚨 SECURITY: Check that the current user is either the publisher or a site admin.
return backend.CheckSiteAdminOrSameUser(ctx, p.userID)
case p.orgID != 0:
// 🚨 SECURITY: Check that the current user is a member of the publisher org.
return backend.CheckOrgAccess(ctx, p.orgID)
default:
return errRegistryUnknownPublisher
}
}

Some files were not shown because too many files have changed in this diff Show More