From a4acc92b6acc3682efcded79a7dec51757d4aba7 Mon Sep 17 00:00:00 2001 From: Robert Lin Date: Mon, 12 Jun 2023 15:46:32 -0700 Subject: [PATCH] cody-gateway: unify limiter handlers (#53346) Unifies limiter middleware under `httpapi/featurelimiter`. The two existing implementations are very similar but not quite, leading to deviations in error handling and so on. ## Test plan CI --- .../cody-gateway/internal/actor/actor_test.go | 5 +- .../internal/httpapi/completions/BUILD.bazel | 3 +- .../internal/httpapi/completions/upstream.go | 9 +- .../internal/httpapi/embeddings/BUILD.bazel | 2 +- .../internal/httpapi/embeddings/handler.go | 4 +- .../internal/httpapi/embeddings/limiter.go | 88 ------------------- .../httpapi/featurelimiter/BUILD.bazel | 20 +++++ .../featurelimiter.go} | 73 ++++++++++----- .../internal/completions/types/errors_test.go | 2 +- 9 files changed, 85 insertions(+), 121 deletions(-) delete mode 100644 enterprise/cmd/cody-gateway/internal/httpapi/embeddings/limiter.go create mode 100644 enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter/BUILD.bazel rename enterprise/cmd/cody-gateway/internal/httpapi/{completions/limiter.go => featurelimiter/featurelimiter.go} (74%) diff --git a/enterprise/cmd/cody-gateway/internal/actor/actor_test.go b/enterprise/cmd/cody-gateway/internal/actor/actor_test.go index 533c5c5217a..9668c21b291 100644 --- a/enterprise/cmd/cody-gateway/internal/actor/actor_test.go +++ b/enterprise/cmd/cody-gateway/internal/actor/actor_test.go @@ -188,9 +188,8 @@ func TestConcurrencyLimiter_TryAcquire(t *testing.T) { } func TestAsErrConcurrencyLimitExceeded(t *testing.T) { - var concurrencyLimitExceeded ErrConcurrencyLimitExceeded var err error err = ErrConcurrencyLimitExceeded{} - assert.True(t, errors.As(err, &concurrencyLimitExceeded)) - assert.True(t, errors.As(errors.Wrap(err, "foo"), &concurrencyLimitExceeded)) + assert.True(t, errors.As(err, &ErrConcurrencyLimitExceeded{})) + assert.True(t, errors.As(errors.Wrap(err, "foo"), &ErrConcurrencyLimitExceeded{})) } diff --git a/enterprise/cmd/cody-gateway/internal/httpapi/completions/BUILD.bazel b/enterprise/cmd/cody-gateway/internal/httpapi/completions/BUILD.bazel index b7042fa6874..86f59eaa355 100644 --- a/enterprise/cmd/cody-gateway/internal/httpapi/completions/BUILD.bazel +++ b/enterprise/cmd/cody-gateway/internal/httpapi/completions/BUILD.bazel @@ -4,7 +4,6 @@ go_library( name = "completions", srcs = [ "anthropic.go", - "limiter.go", "openai.go", "upstream.go", ], @@ -13,13 +12,13 @@ go_library( deps = [ "//enterprise/cmd/cody-gateway/internal/actor", "//enterprise/cmd/cody-gateway/internal/events", + "//enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter", "//enterprise/cmd/cody-gateway/internal/limiter", "//enterprise/cmd/cody-gateway/internal/notify", "//enterprise/cmd/cody-gateway/internal/response", "//enterprise/internal/codygateway", "//enterprise/internal/completions/client/anthropic", "//enterprise/internal/completions/client/openai", - "//enterprise/internal/completions/types", "//internal/httpcli", "//internal/trace", "//lib/errors", diff --git a/enterprise/cmd/cody-gateway/internal/httpapi/completions/upstream.go b/enterprise/cmd/cody-gateway/internal/httpapi/completions/upstream.go index bc23a635d32..d63a5e649cd 100644 --- a/enterprise/cmd/cody-gateway/internal/httpapi/completions/upstream.go +++ b/enterprise/cmd/cody-gateway/internal/httpapi/completions/upstream.go @@ -16,6 +16,7 @@ import ( "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/actor" "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/events" + "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter" "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/limiter" "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/notify" "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/response" @@ -79,7 +80,7 @@ func makeUpstreamHandler[ReqT any]( allowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, allowedModels[i]) } - return rateLimit( + return featurelimiter.Handle( baseLogger, eventLogger, limiter.NewPrefixRedisStore("rate_limit:", rs), @@ -88,9 +89,9 @@ func makeUpstreamHandler[ReqT any]( act := actor.FromContext(r.Context()) logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger)) - feature, err := extractFeature(r) - if err != nil { - response.JSONError(logger, w, http.StatusBadRequest, err) + feature := featurelimiter.GetFeature(r.Context()) + if feature == "" { + response.JSONError(logger, w, http.StatusBadRequest, errors.New("no feature provided")) return } diff --git a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/BUILD.bazel b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/BUILD.bazel index 7971dcf826e..c1d88185ba7 100644 --- a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/BUILD.bazel +++ b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/BUILD.bazel @@ -4,7 +4,6 @@ go_library( name = "embeddings", srcs = [ "handler.go", - "limiter.go", "models.go", "openai.go", ], @@ -13,6 +12,7 @@ go_library( deps = [ "//enterprise/cmd/cody-gateway/internal/actor", "//enterprise/cmd/cody-gateway/internal/events", + "//enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter", "//enterprise/cmd/cody-gateway/internal/limiter", "//enterprise/cmd/cody-gateway/internal/notify", "//enterprise/cmd/cody-gateway/internal/response", diff --git a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/handler.go b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/handler.go index 31567c8c5d4..3309c768109 100644 --- a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/handler.go +++ b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/handler.go @@ -13,6 +13,7 @@ import ( "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/actor" "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/events" + "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter" "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/limiter" "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/notify" "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/response" @@ -33,11 +34,12 @@ func NewHandler( ) http.Handler { baseLogger = baseLogger.Scoped("embeddingshandler", "The HTTP API handler for the embeddings endpoint.") - return rateLimit( + return featurelimiter.HandleFeature( baseLogger, eventLogger, limiter.NewPrefixRedisStore("rate_limit:", rs), rateLimitNotifier, + codygateway.FeatureEmbeddings, http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { act := actor.FromContext(r.Context()) logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger)) diff --git a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/limiter.go b/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/limiter.go deleted file mode 100644 index 5ad11550bb8..00000000000 --- a/enterprise/cmd/cody-gateway/internal/httpapi/embeddings/limiter.go +++ /dev/null @@ -1,88 +0,0 @@ -package embeddings - -import ( - "net/http" - "strconv" - - "github.com/sourcegraph/log" - - "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/actor" - "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/events" - "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/limiter" - "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/notify" - "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/response" - "github.com/sourcegraph/sourcegraph/enterprise/internal/codygateway" - sgtrace "github.com/sourcegraph/sourcegraph/internal/trace" - "github.com/sourcegraph/sourcegraph/lib/errors" -) - -func rateLimit( - baseLogger log.Logger, - eventLogger events.Logger, - cache limiter.RedisStore, - rateLimitNotifier notify.RateLimitNotifier, - next http.Handler, -) http.Handler { - baseLogger = baseLogger.Scoped("rateLimit", "rate limit handler") - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - act := actor.FromContext(r.Context()) - logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger)) - - l, ok := act.Limiter(logger, cache, codygateway.FeatureEmbeddings, rateLimitNotifier) - if !ok { - response.JSONError(logger, w, http.StatusForbidden, errors.New("no access to embeddings")) - return - } - - commit, err := l.TryAcquire(r.Context()) - if err != nil { - if loggerErr := eventLogger.LogEvent( - r.Context(), - events.Event{ - Name: codygateway.EventNameRateLimited, - Source: act.Source.Name(), - Identifier: act.ID, - Metadata: map[string]any{ - "error": err.Error(), - codygateway.CompletionsEventFeatureMetadataField: "embeddings", - }, - }, - ); loggerErr != nil { - logger.Error("failed to log event", log.Error(loggerErr)) - } - - var rateLimitExceeded limiter.RateLimitExceededError - if errors.As(err, &rateLimitExceeded) { - rateLimitExceeded.WriteResponse(w) - return - } - - if errors.Is(err, limiter.NoAccessError{}) { - response.JSONError(logger, w, http.StatusForbidden, err) - return - } - - response.JSONError(logger, w, http.StatusInternalServerError, err) - return - } - - responseRecorder := response.NewStatusHeaderRecorder(w) - next.ServeHTTP(responseRecorder, r) - - // If response is healthy, consume the rate limit - if responseRecorder.StatusCode >= 200 && responseRecorder.StatusCode < 300 { - uh := w.Header().Get(usageHeaderName) - if uh == "" { - logger.Error("no usage header set on response") - } - usage, err := strconv.Atoi(uh) - if err != nil { - logger.Error("failed to parse usage header as number", log.Error(err)) - } - if err := commit(usage); err != nil { - logger.Error("failed to commit rate limit consumption", log.Error(err)) - } - } - }) -} diff --git a/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter/BUILD.bazel b/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter/BUILD.bazel new file mode 100644 index 00000000000..917d97be8e1 --- /dev/null +++ b/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter/BUILD.bazel @@ -0,0 +1,20 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_library") + +go_library( + name = "featurelimiter", + srcs = ["featurelimiter.go"], + importpath = "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter", + visibility = ["//enterprise/cmd/cody-gateway:__subpackages__"], + deps = [ + "//enterprise/cmd/cody-gateway/internal/actor", + "//enterprise/cmd/cody-gateway/internal/events", + "//enterprise/cmd/cody-gateway/internal/limiter", + "//enterprise/cmd/cody-gateway/internal/notify", + "//enterprise/cmd/cody-gateway/internal/response", + "//enterprise/internal/codygateway", + "//enterprise/internal/completions/types", + "//internal/trace", + "//lib/errors", + "@com_github_sourcegraph_log//:log", + ], +) diff --git a/enterprise/cmd/cody-gateway/internal/httpapi/completions/limiter.go b/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter/featurelimiter.go similarity index 74% rename from enterprise/cmd/cody-gateway/internal/httpapi/completions/limiter.go rename to enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter/featurelimiter.go index 80f46a1a087..fa88f8dc89d 100644 --- a/enterprise/cmd/cody-gateway/internal/httpapi/completions/limiter.go +++ b/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter/featurelimiter.go @@ -1,6 +1,7 @@ -package completions +package featurelimiter import ( + "context" "net/http" "strings" @@ -17,24 +18,67 @@ import ( "github.com/sourcegraph/sourcegraph/lib/errors" ) -func rateLimit( +type contextKey string + +const contextKeyFeature contextKey = "feature" + +// GetFeature gets the feature used by Handle or HandleFeature. +func GetFeature(ctx context.Context) codygateway.Feature { + if f, ok := ctx.Value(contextKeyFeature).(codygateway.Feature); ok { + return f + } + return "" +} + +// Handle extracts features from codygateway.FeatureHeaderName and uses it to +// determine the appropriate per-feature rate limits applied for an actor. +func Handle( baseLogger log.Logger, eventLogger events.Logger, cache limiter.RedisStore, rateLimitNotifier notify.RateLimitNotifier, next http.Handler, ) http.Handler { - baseLogger = baseLogger.Scoped("rateLimit", "rate limit handler") + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + feature, err := extractFeature(r) + if err != nil { + response.JSONError(baseLogger, w, http.StatusBadRequest, err) + return + } + HandleFeature(baseLogger, eventLogger, cache, rateLimitNotifier, feature, next). + ServeHTTP(w, r) + }) +} + +func extractFeature(r *http.Request) (codygateway.Feature, error) { + h := strings.TrimSpace(r.Header.Get(codygateway.FeatureHeaderName)) + if h == "" { + return "", errors.Newf("%s header is required", codygateway.FeatureHeaderName) + } + feature := types.CompletionsFeature(h) + if !feature.IsValid() { + return "", errors.Newf("invalid value for %s", codygateway.FeatureHeaderName) + } + // codygateway.Feature and types.CompletionsFeature map 1:1 for completions. + return codygateway.Feature(feature), nil +} + +// Handle uses a predefined feature to determine the appropriate per-feature +// rate limits applied for an actor. +func HandleFeature( + baseLogger log.Logger, + eventLogger events.Logger, + cache limiter.RedisStore, + rateLimitNotifier notify.RateLimitNotifier, + feature codygateway.Feature, + next http.Handler, +) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { act := actor.FromContext(r.Context()) logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger)) - feature, err := extractFeature(r) - if err != nil { - response.JSONError(logger, w, http.StatusBadRequest, err) - return - } + r = r.WithContext(context.WithValue(r.Context(), contextKeyFeature, feature)) l, ok := act.Limiter(logger, cache, feature, rateLimitNotifier) if !ok { @@ -96,16 +140,3 @@ func rateLimit( } }) } - -func extractFeature(r *http.Request) (codygateway.Feature, error) { - h := strings.TrimSpace(r.Header.Get(codygateway.FeatureHeaderName)) - if h == "" { - return "", errors.Newf("%s header is required", codygateway.FeatureHeaderName) - } - feature := types.CompletionsFeature(h) - if !feature.IsValid() { - return "", errors.Newf("invalid value for %s", codygateway.FeatureHeaderName) - } - // codygateway.Feature and types.CompletionsFeature map 1:1 for completions. - return codygateway.Feature(feature), nil -} diff --git a/enterprise/internal/completions/types/errors_test.go b/enterprise/internal/completions/types/errors_test.go index be2521cac62..0091f9cd023 100644 --- a/enterprise/internal/completions/types/errors_test.go +++ b/enterprise/internal/completions/types/errors_test.go @@ -47,7 +47,7 @@ func TestErrStatusNotOK(t *testing.T) { assert.Equal(t, resp.Header, writtenResp.Header) // Should not have written the response body. - writtenBody, err := io.ReadAll(resp.Body) + writtenBody, err := io.ReadAll(writtenResp.Body) assert.NoError(t, err) assert.Empty(t, writtenBody) })