Refresh SAMS access tokens as needed (#62869)

* Refresh SAMS access tokens as needed

---------

Co-authored-by: Robert Lin <robert@bobheadxi.dev>
Co-authored-by: David Veszelovszki <veszelovszki@gmail.com>
This commit is contained in:
Chris Smith 2024-05-24 00:33:13 -07:00 committed by GitHub
parent ebe14c4e01
commit 38e84990e7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 125 additions and 22 deletions

View File

@ -227,11 +227,22 @@ func NewHandler(
// This means that for any cookie-based authentication method, we need to have
// CSRF protection. (However, that appears to be the case, see `newExternalHTTPHandler`
// and its use of `CookieMiddlewareWithCSRFSafety`.)
samsOAuthConfig, err := ssc.GetSAMSOAuthContext()
if err != nil {
// This situation is pretty bad, as it means no Cody Pro-related functionality
// can work properly. So while the site can continue to load as expected,
// we will supply a zero-value OAuth config that will only serve 503s.
//
// This makes the failure a lot more obvious than not registering the routes
// at all, and trying to figure out why we are seeing 404s or 405s.
logger.Error("error loading SAMS config, unable to register SSC API proxy", sglog.Error(err))
}
sscBackendProxy := ssc.APIProxyHandler{
CodyProConfig: conf.Get().Dotcom.CodyProConfig,
DB: db,
Logger: logger.Scoped("SSC Proxy"),
URLPrefix: "/.api/ssc/proxy",
CodyProConfig: conf.Get().Dotcom.CodyProConfig,
DB: db,
Logger: logger.Scoped("sscProxy"),
URLPrefix: "/.api/ssc/proxy",
SAMSOAuthContext: samsOAuthConfig,
}
m.PathPrefix("/ssc/proxy/").Handler(&sscBackendProxy)
}

View File

@ -16,6 +16,11 @@ go_library(
"//internal/conf",
"//internal/database",
"//internal/encryption",
"//internal/extsvc",
"//internal/extsvc/auth",
"//internal/httpcli",
"//internal/oauthtoken",
"//internal/oauthutil",
"//internal/trace",
"//lib/errors",
"//schema",

View File

@ -2,6 +2,7 @@ package ssc
import (
"context"
"fmt"
"io"
"net/http"
"net/url"
@ -9,13 +10,19 @@ import (
"strings"
"time"
"github.com/sourcegraph/sourcegraph/internal/httpcli"
"golang.org/x/oauth2"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/internal/actor"
"github.com/sourcegraph/sourcegraph/internal/conf"
"github.com/sourcegraph/sourcegraph/internal/database"
"github.com/sourcegraph/sourcegraph/internal/encryption"
"github.com/sourcegraph/sourcegraph/internal/extsvc"
"github.com/sourcegraph/sourcegraph/internal/extsvc/auth"
"github.com/sourcegraph/sourcegraph/internal/oauthtoken"
"github.com/sourcegraph/sourcegraph/internal/oauthutil"
"github.com/sourcegraph/sourcegraph/lib/errors"
"github.com/sourcegraph/sourcegraph/schema"
)
@ -40,10 +47,40 @@ type APIProxyHandler struct {
// URLPrefix of where the handler is served. e.g. ".api/ssc/proxy/". This
// will be replaced with the SSC-specific URL prefix ()"cody/api/v1/").
URLPrefix string
// SAMSOAuthContext is the metadata necessary for contacting SAMS. Used
// when we notice a Sourcegraph account's SAMS identity has an expired
// access token.
SAMSOAuthContext *oauthutil.OAuthContext
}
var _ http.Handler = (*APIProxyHandler)(nil)
// GetSAMSOAuthContext returns the OAuthContext object to describe the SAMS
// IdP registered to the current Sourcegraph instance. (As identified by
// `GETSAMSServiceID()`)
func GetSAMSOAuthContext() (*oauthutil.OAuthContext, error) {
for _, provider := range conf.Get().AuthProviders {
oidcInfo := provider.Openidconnect
if oidcInfo == nil {
continue
}
if oidcInfo.Issuer == GetSAMSServiceID() {
oauthCtx := oauthutil.OAuthContext{
ClientID: oidcInfo.ClientID,
ClientSecret: oidcInfo.ClientSecret,
Endpoint: oauth2.Endpoint{
AuthURL: fmt.Sprintf("%s/oauth/authorize", oidcInfo.Issuer),
TokenURL: fmt.Sprintf("%s/oauth/token", oidcInfo.Issuer),
},
}
return &oauthCtx, nil
}
}
return nil, errors.New("no SAMS configuration found")
}
// getUserIDFromRequest extracts the Sourcegraph User ID from the incomming request,
// or returns an error suitable for sending to the end user.
func (p *APIProxyHandler) getUserIDFromContext(ctx context.Context) (int32, error) {
@ -104,8 +141,10 @@ func (p *APIProxyHandler) buildProxyRequest(sourceReq *http.Request, token strin
return proxyReq, nil
}
// getSAMSCredentialsForUser fetches the OAuth token from the user's SAMS external identity.
func (p *APIProxyHandler) getSAMSCredentialsForUser(ctx context.Context, userID int32) (*oauth2.Token, error) {
// getSAMSCredentialsForUser fetches the SAMS identity for the the given Sourcegraph user ID, and
// decrypts the OAuth token stored within.
func (p *APIProxyHandler) getSAMSCredentialsForUser(ctx context.Context, userID int32) (
*extsvc.Account, *oauth2.Token, error) {
// NOTE: It's possible for a user to have multiple SAMS identities attached to the same Sourcegraph
// user account. The underlying implementation provides a stable result sorting by ID, so we
// just return the first SAMS identity found.
@ -124,11 +163,11 @@ func (p *APIProxyHandler) getSAMSCredentialsForUser(ctx context.Context, userID
ExcludeExpired: true,
})
if err != nil {
return nil, errors.Wrap(err, "listing user external accounts")
return nil, nil, errors.Wrap(err, "listing user external accounts")
}
switch len(extAccounts) {
case 0:
return nil, errors.New("user does not have a SAMS identity")
return nil, nil, errors.New("user does not have a SAMS identity")
case 1:
// Expected, AOK
default:
@ -140,21 +179,57 @@ func (p *APIProxyHandler) getSAMSCredentialsForUser(ctx context.Context, userID
// Load the specific external account (SAMS identity).
samsIdentity, err := p.DB.UserExternalAccounts().Get(ctx, extAccounts[0].ID)
if err != nil {
return nil, errors.Wrap(err, "getting user SAMS identity")
return nil, nil, errors.Wrap(err, "getting user SAMS identity")
}
// Decrypt and unmarshall as an OAuth token.
token, err := encryption.DecryptJSON[oauth2.Token](ctx, samsIdentity.AuthData)
if err != nil {
return nil, errors.Wrap(err, "decrypting/unmarshalling SAMS auth data")
return nil, nil, errors.Wrap(err, "decrypting/unmarshalling SAMS auth data")
}
return token, nil
return samsIdentity, token, nil
}
// tryRefreshSAMSCredentials attempts to refresh the user's SAMS credentials by
// exchanging the OAuth refresh token we have on file for a new access/refresh token.
//
// Upon success, the new tokens will be persisted in the database.
func (p *APIProxyHandler) tryRefreshSAMSCredentials(
ctx context.Context, samsIdent *extsvc.Account, currentToken *oauth2.Token) (string, error) {
if samsIdent == nil || currentToken == nil {
return "", errors.New("current identity or current token not provided")
}
externalAccountID := samsIdent.ID // ID of the external identity, not the user ID.
refreshFn := oauthtoken.GetAccountRefreshAndStoreOAuthTokenFunc(
p.DB.UserExternalAccounts(), externalAccountID, p.SAMSOAuthContext)
// Perform the refresh.
userBearerToken := auth.OAuthBearerToken{
Token: currentToken.AccessToken,
RefreshToken: currentToken.RefreshToken,
Expiry: currentToken.Expiry,
}
newToken, _ /* newRefreshToken */, newExpiry, err := refreshFn(
ctx, httpcli.UncachedExternalDoer, &userBearerToken)
if err != nil {
return "", errors.Wrap(err, "refreshing SAMS token")
}
p.Logger.Info("refresh user's SAMS token", log.Time("new expiration", newExpiry))
return newToken, nil
}
func (p *APIProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
p.Logger.Info("proxying SSC API request", log.String("url", r.URL.String()))
// Confirm the proxy is configured correctly.
if p.CodyProConfig == nil || p.SAMSOAuthContext == nil || p.SAMSOAuthContext.ClientID == "" {
http.Error(w, "proxy not configured", http.StatusServiceUnavailable)
return
}
sgUserID, err := p.getUserIDFromContext(ctx)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
@ -162,7 +237,7 @@ func (p *APIProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}
// Lookup the user's SAMS credentials.
samsToken, err := p.getSAMSCredentialsForUser(ctx, sgUserID)
samsIdentity, samsToken, err := p.getSAMSCredentialsForUser(ctx, sgUserID)
if err != nil {
// Here we assume that the function will only fail because of an IO problem.
// And not that a user simply doesn't have a SAMS identity. (Since for dotcom
@ -177,19 +252,28 @@ func (p *APIProxyHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
// the request to the SSC backend will fail because the OAuth credentials
// associated with their SAMS login has expired.
//
// The frontend needs to expect this 401 response, and force the user to
// reauthenticate. (Which would then pick up a new SAMS auth token.) Or
// we need to have a background process that will periodically refresh the
// user's SAMS credentials, so that the refresh and access token for the
// user's SAMS account are sufficiently fresh.
// If we detect this, we first try to refresh the user's SAMS access token. But
// that may also fail (if the underlying SAMS refresh token has also expired). So
// the frontend must expect this situation via a 401 response, and force the user
// to reauthenticate. (Which would then pick up a new SAMS auth token.)
accessToken := samsToken.AccessToken
if samsToken.Expiry.Before(time.Now()) {
p.Logger.Warn("the user's SAMS token has expired", log.Time("expiry", samsToken.Expiry))
http.Error(w, "Sourcegraph Accounts identity has expired", http.StatusUnauthorized)
return
newToken, err := p.tryRefreshSAMSCredentials(ctx, samsIdentity, samsToken)
if err != nil {
p.Logger.Error("error trying to refresh the user's SAMS credentials", log.Error(err))
// Just fail here since there is nothing we can do. We know the token is invalid,
// and we were unable to create a new one.
http.Error(w, "Sourcegraph Accounts identity has expired", http.StatusUnauthorized)
return
}
accessToken = newToken
}
// Copy the incoming request and send it to the SSC backend.
proxyRequest, err := p.buildProxyRequest(r, samsToken.AccessToken)
proxyRequest, err := p.buildProxyRequest(r, accessToken)
if err != nil {
p.Logger.Error("building SSC proxy request", log.Error(err))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

View File

@ -203,7 +203,7 @@ func TestSSCAPIProxy(t *testing.T) {
// has no available SAMS identity.
t.Run("ErrorNoSAMSIdentity", func(t *testing.T) {
ctx := context.Background()
_, err := testHandler.getSAMSCredentialsForUser(ctx, testUserID)
_, _, err := testHandler.getSAMSCredentialsForUser(ctx, testUserID)
assert.ErrorContains(t, err, "user does not have a SAMS identity")
})
@ -229,9 +229,12 @@ func TestSSCAPIProxy(t *testing.T) {
// Try to get the user's SAMS creds given this.
ctx := context.Background()
token, err := testHandler.getSAMSCredentialsForUser(ctx, testUserID)
ident, token, err := testHandler.getSAMSCredentialsForUser(ctx, testUserID)
require.NoError(t, err)
assert.Equal(t, validSAMSIdentity.ID, ident.ID)
assert.Equal(t, validSAMSIdentity.ServiceID, ident.ServiceID)
assert.Equal(t, testToken.AccessToken, token.AccessToken)
assert.Equal(t, testToken.RefreshToken, token.RefreshToken)
assert.WithinDuration(t, testToken.Expiry, token.Expiry, time.Second)