mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 18:11:48 +00:00
feat/cody-gateway: support wildcard models (#62909)
In https://sourcegraph.slack.com/archives/C05SZB829D0/p1715638980052279 we shared a decision we landed on as part of #62263: > Ignoring (then removing) per-subscription model allowlists: As part of the API discussions, we've also surfaced some opportunities for improvements - to make it easier to roll out new models to Enterprise, we're not including per-subscription model allowlists in the new API, and as part of the Cody Gateway migration (by end-of-June), we will update Cody Gateway to stop enforcing per-subscription model allowlists. Cody Gateway will still retain a Cody-Gateway-wide model allowlist. [@chrsmith](https://sourcegraph.slack.com/team/U061QHKUBJ8) is working on a broader design here and will have more to share on this later. To support this, we first need to extend Cody Gateway's model allowlist enforcement to respect a notion of "allow all models that are allowed in Cody Gateway". To ensure models are explicitly provided today, an empty `AllowedModels` is considered invalid, so we add a special single-element-slice-`*` configuration that can be used to indicate an actor's rate limit allows all models (`prefixedMasterAllowlist`). This change also unifies somewhat the way we enforce allowed models in various places by introducing `(*RateLimit).EvaluateAllowedModels(...)` as the unified way to construct the final allowlist for a given rate limit. I'm planning to roll this out before rolling out actual functionality changes (https://github.com/sourcegraph/sourcegraph/pull/62911) to ensure changes in cached rate limits don't end up confusing an older revision of Cody Gateway that doesn't yet support wildcard models. With #62911, rolling out new models to Enterprise customers no longer requires additional code/override changes. Part of https://linear.app/sourcegraph/issue/CORE-135 ## Test plan Unit tests, and E2E test of this in https://github.com/sourcegraph/sourcegraph/pull/62911
This commit is contained in:
parent
9d0131f936
commit
5833a98185
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
@ -19,7 +20,11 @@ import (
|
||||
|
||||
type RateLimit struct {
|
||||
// AllowedModels is a set of models in Cody Gateway's model configuration
|
||||
// format, "$PROVIDER/$MODEL_NAME".
|
||||
// format, "$PROVIDER/$MODEL_NAME". A single-item slice with value '*' means
|
||||
// that all models in the 'master allowlist' are allowed.
|
||||
//
|
||||
// DO NOT USE DIRECTLY when enforcing permissions: use EvaluateAllowedModels(...)
|
||||
// instead.
|
||||
AllowedModels []string `json:"allowedModels"`
|
||||
|
||||
Limit int64 `json:"limit"`
|
||||
@ -58,6 +63,26 @@ func (r *RateLimit) IsValid() bool {
|
||||
return r != nil && r.Interval > 0 && r.Limit > 0 && len(r.AllowedModels) > 0
|
||||
}
|
||||
|
||||
// EvaluateAllowedModels returns the intersection of a 'master' allowlist and
|
||||
// the actor's allowlist, where only values on the 'master' allowlist are returned.
|
||||
// The provided allowlist MUST be prefixed with the provider name (e.g. "anthropic/").
|
||||
//
|
||||
// If the actor's allowlist is a single value '*', then the master allowlist is
|
||||
// returned (i.e. all models are allowed).
|
||||
func (r *RateLimit) EvaluateAllowedModels(prefixedMasterAllowlist []string) []string {
|
||||
if len(r.AllowedModels) == 1 && r.AllowedModels[0] == "*" {
|
||||
return prefixedMasterAllowlist // all models allowed
|
||||
}
|
||||
|
||||
var result []string
|
||||
for _, val := range r.AllowedModels {
|
||||
if slices.Contains(prefixedMasterAllowlist, val) {
|
||||
result = append(result, val)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type concurrencyLimiter struct {
|
||||
logger log.Logger
|
||||
actor *Actor
|
||||
|
||||
@ -173,3 +173,51 @@ func TestAsErrConcurrencyLimitExceeded(t *testing.T) {
|
||||
assert.True(t, errors.As(err, &ErrConcurrencyLimitExceeded{}))
|
||||
assert.True(t, errors.As(errors.Wrap(err, "foo"), &ErrConcurrencyLimitExceeded{}))
|
||||
}
|
||||
|
||||
func TestRateLimit_EvaluateAllowedModels(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
allowedModels []string
|
||||
prefixedMasterAllowlist []string
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "all models allowed",
|
||||
allowedModels: []string{"*"},
|
||||
prefixedMasterAllowlist: []string{"provider/model1", "provider/model2", "provider/model3"},
|
||||
want: []string{"provider/model1", "provider/model2", "provider/model3"},
|
||||
},
|
||||
{
|
||||
name: "no models allowed",
|
||||
allowedModels: []string{},
|
||||
prefixedMasterAllowlist: []string{"provider/model1", "provider/model2", "provider/model3"},
|
||||
want: []string{},
|
||||
},
|
||||
{
|
||||
name: "some models allowed",
|
||||
allowedModels: []string{"provider/model1", "provider/model3"},
|
||||
prefixedMasterAllowlist: []string{"provider/model1", "provider/model2", "provider/model3"},
|
||||
want: []string{"provider/model1", "provider/model3"},
|
||||
},
|
||||
{
|
||||
name: "non-existent models allowed",
|
||||
allowedModels: []string{"provider/model4", "provider/model5"},
|
||||
prefixedMasterAllowlist: []string{"provider/model1", "provider/model2", "provider/model3"},
|
||||
want: []string{},
|
||||
},
|
||||
{
|
||||
name: "multiple models with wildcard is ignored",
|
||||
allowedModels: []string{"provider/model1", "*", "provider/model4"},
|
||||
prefixedMasterAllowlist: []string{"provider/model1"},
|
||||
want: []string{"provider/model1"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
r := &RateLimit{AllowedModels: tt.allowedModels}
|
||||
got := r.EvaluateAllowedModels(tt.prefixedMasterAllowlist)
|
||||
assert.ElementsMatch(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -137,7 +136,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
|
||||
// upstreamName is the name of the upstream provider. It MUST match the
|
||||
// provider names defined clientside, i.e. "anthropic" or "openai".
|
||||
upstreamName string,
|
||||
|
||||
// unprefixed upstream model names
|
||||
allowedModels []string,
|
||||
|
||||
methods upstreamHandlerMethods[ReqT],
|
||||
@ -150,10 +149,10 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
|
||||
// Convert allowedModels to the Cody Gateway configuration format with the
|
||||
// provider as a prefix. This aligns with the models returned when we query
|
||||
// for rate limits from actor sources.
|
||||
clonedAllowedModels := make([]string, len(allowedModels))
|
||||
copy(clonedAllowedModels, allowedModels)
|
||||
for i := range clonedAllowedModels {
|
||||
clonedAllowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, clonedAllowedModels[i])
|
||||
prefixedAllowedModels := make([]string, len(allowedModels))
|
||||
copy(prefixedAllowedModels, allowedModels)
|
||||
for i := range prefixedAllowedModels {
|
||||
prefixedAllowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, prefixedAllowedModels[i])
|
||||
}
|
||||
|
||||
// upstreamHandler is the actual HTTP handle that will perform "all of the things"
|
||||
@ -200,6 +199,9 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
|
||||
// This isn't very robust, but should tide us through a brief transition
|
||||
// period until everything deploys and our caches refresh.
|
||||
for i := range rateLimit.AllowedModels {
|
||||
if rateLimit.AllowedModels[i] == "*" {
|
||||
continue // special wildcard value
|
||||
}
|
||||
if !strings.Contains(rateLimit.AllowedModels[i], "/") {
|
||||
rateLimit.AllowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, rateLimit.AllowedModels[i])
|
||||
}
|
||||
@ -305,7 +307,7 @@ func makeUpstreamHandler[ReqT UpstreamRequest](
|
||||
// the prefix yet when extracted - we need to add it back here. This
|
||||
// full gatewayModel is also used in events tracking.
|
||||
gatewayModel := fmt.Sprintf("%s/%s", upstreamName, model)
|
||||
if allowed := intersection(clonedAllowedModels, rateLimit.AllowedModels); !isAllowedModel(allowed, gatewayModel) {
|
||||
if allowed := rateLimit.EvaluateAllowedModels(prefixedAllowedModels); !isAllowedModel(allowed, gatewayModel) {
|
||||
response.JSONError(logger, w, http.StatusBadRequest,
|
||||
errors.Newf("model %q is not allowed, allowed: [%s]",
|
||||
gatewayModel, strings.Join(allowed, ", ")))
|
||||
@ -535,12 +537,3 @@ func isAllowedModel(allowedModels []string, model string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// intersection returns every element of a that also occurs in b, preserving
// the order from a. Duplicates in a are kept when present in b; the result is
// nil when there is no overlap.
func intersection(a, b []string) (c []string) {
	for _, candidate := range a {
		if !slices.Contains(b, candidate) {
			continue
		}
		c = append(c, candidate)
	}
	return c
}
|
||||
|
||||
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
@ -34,7 +33,7 @@ func NewHandler(
|
||||
rs limiter.RedisStore,
|
||||
rateLimitNotifier notify.RateLimitNotifier,
|
||||
mf ModelFactory,
|
||||
allowedModels []string,
|
||||
prefixedAllowedModels []string,
|
||||
) http.Handler {
|
||||
baseLogger = baseLogger.Scoped("embeddingshandler")
|
||||
|
||||
@ -64,7 +63,7 @@ func NewHandler(
|
||||
return
|
||||
}
|
||||
|
||||
if !isAllowedModel(intersection(allowedModels, rateLimit.AllowedModels), body.Model) {
|
||||
if !isAllowedModel(rateLimit.EvaluateAllowedModels(prefixedAllowedModels), body.Model) {
|
||||
response.JSONError(logger, w, http.StatusBadRequest, errors.Newf("model %q is not allowed", body.Model))
|
||||
return
|
||||
}
|
||||
@ -200,12 +199,3 @@ func isAllowedModel(allowedModels []string, model string) bool {
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// intersection computes the ordered intersection of a and b: each value from
// a is kept if and only if it also appears somewhere in b. A nil slice is
// returned when nothing matches.
func intersection(a, b []string) (c []string) {
	for i := range a {
		if slices.Contains(b, a[i]) {
			c = append(c, a[i])
		}
	}
	return c
}
|
||||
|
||||
@ -10,11 +10,11 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/internal/codygateway"
|
||||
)
|
||||
|
||||
type ModelName string
|
||||
type PrefixedModelName string
|
||||
|
||||
const (
|
||||
ModelNameOpenAIAda ModelName = "openai/text-embedding-ada-002"
|
||||
ModelNameSourcegraphSTMultiQA ModelName = "sourcegraph/st-multi-qa-mpnet-base-dot-v1"
|
||||
ModelNameOpenAIAda PrefixedModelName = "openai/text-embedding-ada-002"
|
||||
ModelNameSourcegraphSTMultiQA PrefixedModelName = "sourcegraph/st-multi-qa-mpnet-base-dot-v1"
|
||||
)
|
||||
|
||||
type EmbeddingsClient interface {
|
||||
@ -26,23 +26,23 @@ type ModelFactory interface {
|
||||
ForModel(model string) (_ EmbeddingsClient, ok bool)
|
||||
}
|
||||
|
||||
type ModelFactoryMap map[ModelName]EmbeddingsClient
|
||||
type ModelFactoryMap map[PrefixedModelName]EmbeddingsClient
|
||||
|
||||
func (mf ModelFactoryMap) ForModel(model string) (EmbeddingsClient, bool) {
|
||||
c, ok := mf[ModelName(model)]
|
||||
c, ok := mf[PrefixedModelName(model)]
|
||||
return c, ok
|
||||
}
|
||||
|
||||
func NewListHandler() http.Handler {
|
||||
func NewListHandler(prefixedAllowedModels []string) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
act := actor.FromContext(r.Context())
|
||||
|
||||
modelEnabled := func(model ModelName) bool {
|
||||
modelEnabled := func(model PrefixedModelName) bool {
|
||||
rl, ok := act.RateLimits[codygateway.FeatureEmbeddings]
|
||||
if !act.AccessEnabled || !ok || !rl.IsValid() {
|
||||
return false
|
||||
}
|
||||
return slices.Contains(rl.AllowedModels, string(model))
|
||||
return slices.Contains(rl.EvaluateAllowedModels(prefixedAllowedModels), string(model))
|
||||
}
|
||||
|
||||
models := modelsResponse{
|
||||
|
||||
@ -158,7 +158,8 @@ func NewHandler(
|
||||
attributesOpenAICompletions,
|
||||
openAIHandler)
|
||||
|
||||
registerSimpleGETEndpoint("v1.embeddings.models", "/embeddings/models", embeddings.NewListHandler())
|
||||
registerSimpleGETEndpoint("v1.embeddings.models", "/embeddings/models",
|
||||
embeddings.NewListHandler(config.EmbeddingsAllowedModels))
|
||||
|
||||
factoryMap := embeddings.ModelFactoryMap{
|
||||
embeddings.ModelNameOpenAIAda: embeddings.NewOpenAIClient(httpClient, config.OpenAI.AccessToken),
|
||||
|
||||
@ -45,6 +45,7 @@ type Config struct {
|
||||
|
||||
Fireworks FireworksConfig
|
||||
|
||||
// Prefixed model names
|
||||
AllowedEmbeddingsModels []string
|
||||
|
||||
AllowAnonymous bool
|
||||
@ -88,12 +89,14 @@ type OpenTelemetryConfig struct {
|
||||
}
|
||||
|
||||
type AnthropicConfig struct {
|
||||
// Non-prefixed model names
|
||||
AllowedModels []string
|
||||
AccessToken string
|
||||
FlaggingConfig FlaggingConfig
|
||||
}
|
||||
|
||||
type FireworksConfig struct {
|
||||
// Non-prefixed model names
|
||||
AllowedModels []string
|
||||
AccessToken string
|
||||
StarcoderCommunitySingleTenantPercent int
|
||||
@ -102,6 +105,7 @@ type FireworksConfig struct {
|
||||
}
|
||||
|
||||
type OpenAIConfig struct {
|
||||
// Non-prefixed model names
|
||||
AllowedModels []string
|
||||
AccessToken string
|
||||
OrgID string
|
||||
|
||||
Loading…
Reference in New Issue
Block a user