feat/cody-gateway: support wildcard models (#62909)

In https://sourcegraph.slack.com/archives/C05SZB829D0/p1715638980052279 we shared a decision we landed on as part of #62263:

> Ignoring (then removing) per-subscription model allowlists: As part of the API discussions, we've also surfaced some opportunities for improvements - to make it easier to roll out new models to Enterprise, we're not including per-subscription model allowlists in the new API, and as part of the Cody Gateway migration (by end-of-June), we will update Cody Gateway to stop enforcing per-subscription model allowlists. Cody Gateway will still retain a Cody-Gateway-wide model allowlist. [@chrsmith](https://sourcegraph.slack.com/team/U061QHKUBJ8) is working on a broader design here and will have more to share on this later.

To support this, we first need to extend Cody Gateway's model allowlist enforcement to respect a notion of "allow all models that are allowed in Cody Gateway". To ensure models are explicitly provided today, an empty `AllowedModels` is considered invalid, so we add a special single-element-slice-`*` configuration that can be used to indicate an actor's rate limit allows all models (`prefixedMasterAllowlist`).

This change also somewhat unifies how we enforce allowed models in various places by introducing `(*RateLimit).EvaluateAllowedModels(...)` as the single, unified way to construct the final allowlist for a given rate limit.

I'm planning to roll this out before rolling out actual functionality changes (https://github.com/sourcegraph/sourcegraph/pull/62911) to ensure changes in cached rate limits don't end up confusing an older revision of Cody Gateway that doesn't yet support wildcard models. With #62911, rolling out new models to Enterprise customers no longer requires additional code/override changes.

Part of https://linear.app/sourcegraph/issue/CORE-135

## Test plan

Unit tests, and E2E test of this in https://github.com/sourcegraph/sourcegraph/pull/62911
This commit is contained in:
Robert Lin 2024-05-31 13:09:01 -07:00 committed by GitHub
parent 9d0131f936
commit 5833a98185
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 99 additions and 38 deletions

View File

@ -4,6 +4,7 @@ import (
"context"
"fmt"
"net/http"
"slices"
"strconv"
"time"
@ -19,7 +20,11 @@ import (
type RateLimit struct {
// AllowedModels is a set of models in Cody Gateway's model configuration
// format, "$PROVIDER/$MODEL_NAME".
// format, "$PROVIDER/$MODEL_NAME". A single-item slice with value '*' means
// that all models in the 'master allowlist' are allowed.
//
// DO NOT USE DIRECTLY when enforcing permissions: use EvaluateAllowedModels(...)
// instead.
AllowedModels []string `json:"allowedModels"`
Limit int64 `json:"limit"`
@ -58,6 +63,26 @@ func (r *RateLimit) IsValid() bool {
return r != nil && r.Interval > 0 && r.Limit > 0 && len(r.AllowedModels) > 0
}
// EvaluateAllowedModels computes the effective allowlist for this rate limit:
// the intersection of the given 'master' allowlist with the actor's own
// AllowedModels, keeping only values present on the master list. Entries in
// prefixedMasterAllowlist MUST carry the provider prefix (e.g. "anthropic/").
//
// As a special case, an actor allowlist consisting of the single value '*'
// yields the master allowlist unchanged (i.e. all models are allowed).
func (r *RateLimit) EvaluateAllowedModels(prefixedMasterAllowlist []string) []string {
	// A lone wildcard entry grants everything on the master allowlist.
	if len(r.AllowedModels) == 1 && r.AllowedModels[0] == "*" {
		return prefixedMasterAllowlist
	}
	// Index the master allowlist for constant-time membership checks.
	permitted := make(map[string]struct{}, len(prefixedMasterAllowlist))
	for _, model := range prefixedMasterAllowlist {
		permitted[model] = struct{}{}
	}
	// Keep only the actor's models that the master list permits, preserving
	// the actor's ordering. The result intentionally stays nil when nothing
	// matches (append on a nil slice), matching the previous behavior.
	var allowed []string
	for _, model := range r.AllowedModels {
		if _, ok := permitted[model]; ok {
			allowed = append(allowed, model)
		}
	}
	return allowed
}
type concurrencyLimiter struct {
logger log.Logger
actor *Actor

View File

@ -173,3 +173,51 @@ func TestAsErrConcurrencyLimitExceeded(t *testing.T) {
assert.True(t, errors.As(err, &ErrConcurrencyLimitExceeded{}))
assert.True(t, errors.As(errors.Wrap(err, "foo"), &ErrConcurrencyLimitExceeded{}))
}
// TestRateLimit_EvaluateAllowedModels covers wildcard expansion, empty and
// disjoint actor allowlists, partial intersections, and the rule that '*'
// mixed with other entries is treated as a literal (non-matching) value.
func TestRateLimit_EvaluateAllowedModels(t *testing.T) {
	for _, tc := range []struct {
		name         string
		actorModels  []string
		masterModels []string
		wantModels   []string
	}{
		{
			name:         "all models allowed",
			actorModels:  []string{"*"},
			masterModels: []string{"provider/model1", "provider/model2", "provider/model3"},
			wantModels:   []string{"provider/model1", "provider/model2", "provider/model3"},
		},
		{
			name:         "no models allowed",
			actorModels:  []string{},
			masterModels: []string{"provider/model1", "provider/model2", "provider/model3"},
			wantModels:   []string{},
		},
		{
			name:         "some models allowed",
			actorModels:  []string{"provider/model1", "provider/model3"},
			masterModels: []string{"provider/model1", "provider/model2", "provider/model3"},
			wantModels:   []string{"provider/model1", "provider/model3"},
		},
		{
			name:         "non-existent models allowed",
			actorModels:  []string{"provider/model4", "provider/model5"},
			masterModels: []string{"provider/model1", "provider/model2", "provider/model3"},
			wantModels:   []string{},
		},
		{
			name:         "multiple models with wildcard is ignored",
			actorModels:  []string{"provider/model1", "*", "provider/model4"},
			masterModels: []string{"provider/model1"},
			wantModels:   []string{"provider/model1"},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			limit := &RateLimit{AllowedModels: tc.actorModels}
			assert.ElementsMatch(t, tc.wantModels, limit.EvaluateAllowedModels(tc.masterModels))
		})
	}
}

View File

@ -7,7 +7,6 @@ import (
"fmt"
"io"
"net/http"
"slices"
"strconv"
"strings"
"time"
@ -137,7 +136,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
// upstreamName is the name of the upstream provider. It MUST match the
// provider names defined clientside, i.e. "anthropic" or "openai".
upstreamName string,
// unprefixed upstream model names
allowedModels []string,
methods upstreamHandlerMethods[ReqT],
@ -150,10 +149,10 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
// Convert allowedModels to the Cody Gateway configuration format with the
// provider as a prefix. This aligns with the models returned when we query
// for rate limits from actor sources.
clonedAllowedModels := make([]string, len(allowedModels))
copy(clonedAllowedModels, allowedModels)
for i := range clonedAllowedModels {
clonedAllowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, clonedAllowedModels[i])
prefixedAllowedModels := make([]string, len(allowedModels))
copy(prefixedAllowedModels, allowedModels)
for i := range prefixedAllowedModels {
prefixedAllowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, prefixedAllowedModels[i])
}
// upstreamHandler is the actual HTTP handle that will perform "all of the things"
@ -200,6 +199,9 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
// This isn't very robust, but should tide us through a brief transition
// period until everything deploys and our caches refresh.
for i := range rateLimit.AllowedModels {
if rateLimit.AllowedModels[i] == "*" {
continue // special wildcard value
}
if !strings.Contains(rateLimit.AllowedModels[i], "/") {
rateLimit.AllowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, rateLimit.AllowedModels[i])
}
@ -305,7 +307,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
// the prefix yet when extracted - we need to add it back here. This
// full gatewayModel is also used in events tracking.
gatewayModel := fmt.Sprintf("%s/%s", upstreamName, model)
if allowed := intersection(clonedAllowedModels, rateLimit.AllowedModels); !isAllowedModel(allowed, gatewayModel) {
if allowed := rateLimit.EvaluateAllowedModels(prefixedAllowedModels); !isAllowedModel(allowed, gatewayModel) {
response.JSONError(logger, w, http.StatusBadRequest,
errors.Newf("model %q is not allowed, allowed: [%s]",
gatewayModel, strings.Join(allowed, ", ")))
@ -535,12 +537,3 @@ func isAllowedModel(allowedModels []string, model string) bool {
}
return false
}
// intersection returns the elements of a that also appear in b, preserving
// a's ordering (and duplicates in a). Returns nil when nothing matches.
func intersection(a, b []string) []string {
	// Index b once so each lookup is O(1) instead of a linear scan.
	lookup := make(map[string]struct{}, len(b))
	for _, v := range b {
		lookup[v] = struct{}{}
	}
	var shared []string
	for _, v := range a {
		if _, ok := lookup[v]; ok {
			shared = append(shared, v)
		}
	}
	return shared
}

View File

@ -4,7 +4,6 @@ import (
"context"
"encoding/json"
"net/http"
"slices"
"strconv"
"strings"
"time"
@ -34,7 +33,7 @@ func NewHandler(
rs limiter.RedisStore,
rateLimitNotifier notify.RateLimitNotifier,
mf ModelFactory,
allowedModels []string,
prefixedAllowedModels []string,
) http.Handler {
baseLogger = baseLogger.Scoped("embeddingshandler")
@ -64,7 +63,7 @@ func NewHandler(
return
}
if !isAllowedModel(intersection(allowedModels, rateLimit.AllowedModels), body.Model) {
if !isAllowedModel(rateLimit.EvaluateAllowedModels(prefixedAllowedModels), body.Model) {
response.JSONError(logger, w, http.StatusBadRequest, errors.Newf("model %q is not allowed", body.Model))
return
}
@ -200,12 +199,3 @@ func isAllowedModel(allowedModels []string, model string) bool {
}
return false
}
// intersection returns every element of a that is also present in b, in a's
// order (duplicates in a are kept). The result is nil when there is no overlap.
func intersection(a, b []string) (c []string) {
	for _, candidate := range a {
		for _, member := range b {
			if candidate == member {
				c = append(c, candidate)
				break // one match is enough; move to the next candidate
			}
		}
	}
	return c
}

View File

@ -10,11 +10,11 @@ import (
"github.com/sourcegraph/sourcegraph/internal/codygateway"
)
type ModelName string
type PrefixedModelName string
const (
ModelNameOpenAIAda ModelName = "openai/text-embedding-ada-002"
ModelNameSourcegraphSTMultiQA ModelName = "sourcegraph/st-multi-qa-mpnet-base-dot-v1"
ModelNameOpenAIAda PrefixedModelName = "openai/text-embedding-ada-002"
ModelNameSourcegraphSTMultiQA PrefixedModelName = "sourcegraph/st-multi-qa-mpnet-base-dot-v1"
)
type EmbeddingsClient interface {
@ -26,23 +26,23 @@ type ModelFactory interface {
ForModel(model string) (_ EmbeddingsClient, ok bool)
}
type ModelFactoryMap map[ModelName]EmbeddingsClient
type ModelFactoryMap map[PrefixedModelName]EmbeddingsClient
// ForModel returns the EmbeddingsClient registered under the given prefixed
// model name, and whether such a client exists in the map.
func (mf ModelFactoryMap) ForModel(model string) (EmbeddingsClient, bool) {
	if client, ok := mf[PrefixedModelName(model)]; ok {
		return client, true
	}
	return nil, false
}
func NewListHandler() http.Handler {
func NewListHandler(prefixedAllowedModels []string) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
act := actor.FromContext(r.Context())
modelEnabled := func(model ModelName) bool {
modelEnabled := func(model PrefixedModelName) bool {
rl, ok := act.RateLimits[codygateway.FeatureEmbeddings]
if !act.AccessEnabled || !ok || !rl.IsValid() {
return false
}
return slices.Contains(rl.AllowedModels, string(model))
return slices.Contains(rl.EvaluateAllowedModels(prefixedAllowedModels), string(model))
}
models := modelsResponse{

View File

@ -158,7 +158,8 @@ func NewHandler(
attributesOpenAICompletions,
openAIHandler)
registerSimpleGETEndpoint("v1.embeddings.models", "/embeddings/models", embeddings.NewListHandler())
registerSimpleGETEndpoint("v1.embeddings.models", "/embeddings/models",
embeddings.NewListHandler(config.EmbeddingsAllowedModels))
factoryMap := embeddings.ModelFactoryMap{
embeddings.ModelNameOpenAIAda: embeddings.NewOpenAIClient(httpClient, config.OpenAI.AccessToken),

View File

@ -45,6 +45,7 @@ type Config struct {
Fireworks FireworksConfig
// Prefixed model names
AllowedEmbeddingsModels []string
AllowAnonymous bool
@ -88,12 +89,14 @@ type OpenTelemetryConfig struct {
}
type AnthropicConfig struct {
// Non-prefixed model names
AllowedModels []string
AccessToken string
FlaggingConfig FlaggingConfig
}
type FireworksConfig struct {
// Non-prefixed model names
AllowedModels []string
AccessToken string
StarcoderCommunitySingleTenantPercent int
@ -102,6 +105,7 @@ type FireworksConfig struct {
}
type OpenAIConfig struct {
// Non-prefixed model names
AllowedModels []string
AccessToken string
OrgID string