cody-gateway: unify limiter handlers (#53346)

Unifies limiter middleware under `httpapi/featurelimiter`. The two
existing implementations are very similar but not quite, leading to
deviations in error handling and so on.

## Test plan

CI
This commit is contained in:
Robert Lin 2023-06-12 15:46:32 -07:00 committed by GitHub
parent 4302e13b99
commit a4acc92b6a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 85 additions and 121 deletions

View File

@ -188,9 +188,8 @@ func TestConcurrencyLimiter_TryAcquire(t *testing.T) {
}
func TestAsErrConcurrencyLimitExceeded(t *testing.T) {
var concurrencyLimitExceeded ErrConcurrencyLimitExceeded
var err error
err = ErrConcurrencyLimitExceeded{}
assert.True(t, errors.As(err, &concurrencyLimitExceeded))
assert.True(t, errors.As(errors.Wrap(err, "foo"), &concurrencyLimitExceeded))
assert.True(t, errors.As(err, &ErrConcurrencyLimitExceeded{}))
assert.True(t, errors.As(errors.Wrap(err, "foo"), &ErrConcurrencyLimitExceeded{}))
}

View File

@ -4,7 +4,6 @@ go_library(
name = "completions",
srcs = [
"anthropic.go",
"limiter.go",
"openai.go",
"upstream.go",
],
@ -13,13 +12,13 @@ go_library(
deps = [
"//enterprise/cmd/cody-gateway/internal/actor",
"//enterprise/cmd/cody-gateway/internal/events",
"//enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter",
"//enterprise/cmd/cody-gateway/internal/limiter",
"//enterprise/cmd/cody-gateway/internal/notify",
"//enterprise/cmd/cody-gateway/internal/response",
"//enterprise/internal/codygateway",
"//enterprise/internal/completions/client/anthropic",
"//enterprise/internal/completions/client/openai",
"//enterprise/internal/completions/types",
"//internal/httpcli",
"//internal/trace",
"//lib/errors",

View File

@ -16,6 +16,7 @@ import (
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/actor"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/events"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/limiter"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/notify"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/response"
@ -79,7 +80,7 @@ func makeUpstreamHandler[ReqT any](
allowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, allowedModels[i])
}
return rateLimit(
return featurelimiter.Handle(
baseLogger,
eventLogger,
limiter.NewPrefixRedisStore("rate_limit:", rs),
@ -88,9 +89,9 @@ func makeUpstreamHandler[ReqT any](
act := actor.FromContext(r.Context())
logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger))
feature, err := extractFeature(r)
if err != nil {
response.JSONError(logger, w, http.StatusBadRequest, err)
feature := featurelimiter.GetFeature(r.Context())
if feature == "" {
response.JSONError(logger, w, http.StatusBadRequest, errors.New("no feature provided"))
return
}

View File

@ -4,7 +4,6 @@ go_library(
name = "embeddings",
srcs = [
"handler.go",
"limiter.go",
"models.go",
"openai.go",
],
@ -13,6 +12,7 @@ go_library(
deps = [
"//enterprise/cmd/cody-gateway/internal/actor",
"//enterprise/cmd/cody-gateway/internal/events",
"//enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter",
"//enterprise/cmd/cody-gateway/internal/limiter",
"//enterprise/cmd/cody-gateway/internal/notify",
"//enterprise/cmd/cody-gateway/internal/response",

View File

@ -13,6 +13,7 @@ import (
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/actor"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/events"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/limiter"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/notify"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/response"
@ -33,11 +34,12 @@ func NewHandler(
) http.Handler {
baseLogger = baseLogger.Scoped("embeddingshandler", "The HTTP API handler for the embeddings endpoint.")
return rateLimit(
return featurelimiter.HandleFeature(
baseLogger,
eventLogger,
limiter.NewPrefixRedisStore("rate_limit:", rs),
rateLimitNotifier,
codygateway.FeatureEmbeddings,
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
act := actor.FromContext(r.Context())
logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger))

View File

@ -1,88 +0,0 @@
package embeddings
import (
"net/http"
"strconv"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/actor"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/events"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/limiter"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/notify"
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/response"
"github.com/sourcegraph/sourcegraph/enterprise/internal/codygateway"
sgtrace "github.com/sourcegraph/sourcegraph/internal/trace"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
func rateLimit(
baseLogger log.Logger,
eventLogger events.Logger,
cache limiter.RedisStore,
rateLimitNotifier notify.RateLimitNotifier,
next http.Handler,
) http.Handler {
baseLogger = baseLogger.Scoped("rateLimit", "rate limit handler")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
act := actor.FromContext(r.Context())
logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger))
l, ok := act.Limiter(logger, cache, codygateway.FeatureEmbeddings, rateLimitNotifier)
if !ok {
response.JSONError(logger, w, http.StatusForbidden, errors.New("no access to embeddings"))
return
}
commit, err := l.TryAcquire(r.Context())
if err != nil {
if loggerErr := eventLogger.LogEvent(
r.Context(),
events.Event{
Name: codygateway.EventNameRateLimited,
Source: act.Source.Name(),
Identifier: act.ID,
Metadata: map[string]any{
"error": err.Error(),
codygateway.CompletionsEventFeatureMetadataField: "embeddings",
},
},
); loggerErr != nil {
logger.Error("failed to log event", log.Error(loggerErr))
}
var rateLimitExceeded limiter.RateLimitExceededError
if errors.As(err, &rateLimitExceeded) {
rateLimitExceeded.WriteResponse(w)
return
}
if errors.Is(err, limiter.NoAccessError{}) {
response.JSONError(logger, w, http.StatusForbidden, err)
return
}
response.JSONError(logger, w, http.StatusInternalServerError, err)
return
}
responseRecorder := response.NewStatusHeaderRecorder(w)
next.ServeHTTP(responseRecorder, r)
// If response is healthy, consume the rate limit
if responseRecorder.StatusCode >= 200 && responseRecorder.StatusCode < 300 {
uh := w.Header().Get(usageHeaderName)
if uh == "" {
logger.Error("no usage header set on response")
}
usage, err := strconv.Atoi(uh)
if err != nil {
logger.Error("failed to parse usage header as number", log.Error(err))
}
if err := commit(usage); err != nil {
logger.Error("failed to commit rate limit consumption", log.Error(err))
}
}
})
}

View File

@ -0,0 +1,20 @@
load("@io_bazel_rules_go//go:def.bzl", "go_library")
go_library(
name = "featurelimiter",
srcs = ["featurelimiter.go"],
importpath = "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter",
visibility = ["//enterprise/cmd/cody-gateway:__subpackages__"],
deps = [
"//enterprise/cmd/cody-gateway/internal/actor",
"//enterprise/cmd/cody-gateway/internal/events",
"//enterprise/cmd/cody-gateway/internal/limiter",
"//enterprise/cmd/cody-gateway/internal/notify",
"//enterprise/cmd/cody-gateway/internal/response",
"//enterprise/internal/codygateway",
"//enterprise/internal/completions/types",
"//internal/trace",
"//lib/errors",
"@com_github_sourcegraph_log//:log",
],
)

View File

@ -1,6 +1,7 @@
package completions
package featurelimiter
import (
"context"
"net/http"
"strings"
@ -17,24 +18,67 @@ import (
"github.com/sourcegraph/sourcegraph/lib/errors"
)
func rateLimit(
type contextKey string
const contextKeyFeature contextKey = "feature"
// GetFeature gets the feature used by Handle or HandleFeature.
func GetFeature(ctx context.Context) codygateway.Feature {
if f, ok := ctx.Value(contextKeyFeature).(codygateway.Feature); ok {
return f
}
return ""
}
// Handle extracts features from codygateway.FeatureHeaderName and uses it to
// determine the appropriate per-feature rate limits applied for an actor.
func Handle(
baseLogger log.Logger,
eventLogger events.Logger,
cache limiter.RedisStore,
rateLimitNotifier notify.RateLimitNotifier,
next http.Handler,
) http.Handler {
baseLogger = baseLogger.Scoped("rateLimit", "rate limit handler")
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
feature, err := extractFeature(r)
if err != nil {
response.JSONError(baseLogger, w, http.StatusBadRequest, err)
return
}
HandleFeature(baseLogger, eventLogger, cache, rateLimitNotifier, feature, next).
ServeHTTP(w, r)
})
}
func extractFeature(r *http.Request) (codygateway.Feature, error) {
h := strings.TrimSpace(r.Header.Get(codygateway.FeatureHeaderName))
if h == "" {
return "", errors.Newf("%s header is required", codygateway.FeatureHeaderName)
}
feature := types.CompletionsFeature(h)
if !feature.IsValid() {
return "", errors.Newf("invalid value for %s", codygateway.FeatureHeaderName)
}
// codygateway.Feature and types.CompletionsFeature map 1:1 for completions.
return codygateway.Feature(feature), nil
}
// Handle uses a predefined feature to determine the appropriate per-feature
// rate limits applied for an actor.
func HandleFeature(
baseLogger log.Logger,
eventLogger events.Logger,
cache limiter.RedisStore,
rateLimitNotifier notify.RateLimitNotifier,
feature codygateway.Feature,
next http.Handler,
) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
act := actor.FromContext(r.Context())
logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger))
feature, err := extractFeature(r)
if err != nil {
response.JSONError(logger, w, http.StatusBadRequest, err)
return
}
r = r.WithContext(context.WithValue(r.Context(), contextKeyFeature, feature))
l, ok := act.Limiter(logger, cache, feature, rateLimitNotifier)
if !ok {
@ -96,16 +140,3 @@ func rateLimit(
}
})
}
func extractFeature(r *http.Request) (codygateway.Feature, error) {
h := strings.TrimSpace(r.Header.Get(codygateway.FeatureHeaderName))
if h == "" {
return "", errors.Newf("%s header is required", codygateway.FeatureHeaderName)
}
feature := types.CompletionsFeature(h)
if !feature.IsValid() {
return "", errors.Newf("invalid value for %s", codygateway.FeatureHeaderName)
}
// codygateway.Feature and types.CompletionsFeature map 1:1 for completions.
return codygateway.Feature(feature), nil
}

View File

@ -47,7 +47,7 @@ func TestErrStatusNotOK(t *testing.T) {
assert.Equal(t, resp.Header, writtenResp.Header)
// Should not have written the response body.
writtenBody, err := io.ReadAll(resp.Body)
writtenBody, err := io.ReadAll(writtenResp.Body)
assert.NoError(t, err)
assert.Empty(t, writtenBody)
})