mirror of
https://github.com/sourcegraph/sourcegraph.git
synced 2026-02-06 17:31:43 +00:00
cody-gateway: unify limiter handlers (#53346)
Unifies limiter middleware under `httpapi/featurelimiter`. The two existing implementations are very similar but not quite, leading to deviations in error handling and so on. ## Test plan CI
This commit is contained in:
parent
4302e13b99
commit
a4acc92b6a
@ -188,9 +188,8 @@ func TestConcurrencyLimiter_TryAcquire(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestAsErrConcurrencyLimitExceeded(t *testing.T) {
|
||||
var concurrencyLimitExceeded ErrConcurrencyLimitExceeded
|
||||
var err error
|
||||
err = ErrConcurrencyLimitExceeded{}
|
||||
assert.True(t, errors.As(err, &concurrencyLimitExceeded))
|
||||
assert.True(t, errors.As(errors.Wrap(err, "foo"), &concurrencyLimitExceeded))
|
||||
assert.True(t, errors.As(err, &ErrConcurrencyLimitExceeded{}))
|
||||
assert.True(t, errors.As(errors.Wrap(err, "foo"), &ErrConcurrencyLimitExceeded{}))
|
||||
}
|
||||
|
||||
@ -4,7 +4,6 @@ go_library(
|
||||
name = "completions",
|
||||
srcs = [
|
||||
"anthropic.go",
|
||||
"limiter.go",
|
||||
"openai.go",
|
||||
"upstream.go",
|
||||
],
|
||||
@ -13,13 +12,13 @@ go_library(
|
||||
deps = [
|
||||
"//enterprise/cmd/cody-gateway/internal/actor",
|
||||
"//enterprise/cmd/cody-gateway/internal/events",
|
||||
"//enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter",
|
||||
"//enterprise/cmd/cody-gateway/internal/limiter",
|
||||
"//enterprise/cmd/cody-gateway/internal/notify",
|
||||
"//enterprise/cmd/cody-gateway/internal/response",
|
||||
"//enterprise/internal/codygateway",
|
||||
"//enterprise/internal/completions/client/anthropic",
|
||||
"//enterprise/internal/completions/client/openai",
|
||||
"//enterprise/internal/completions/types",
|
||||
"//internal/httpcli",
|
||||
"//internal/trace",
|
||||
"//lib/errors",
|
||||
|
||||
@ -16,6 +16,7 @@ import (
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/events"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/limiter"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/notify"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/response"
|
||||
@ -79,7 +80,7 @@ func makeUpstreamHandler[ReqT any](
|
||||
allowedModels[i] = fmt.Sprintf("%s/%s", upstreamName, allowedModels[i])
|
||||
}
|
||||
|
||||
return rateLimit(
|
||||
return featurelimiter.Handle(
|
||||
baseLogger,
|
||||
eventLogger,
|
||||
limiter.NewPrefixRedisStore("rate_limit:", rs),
|
||||
@ -88,9 +89,9 @@ func makeUpstreamHandler[ReqT any](
|
||||
act := actor.FromContext(r.Context())
|
||||
logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger))
|
||||
|
||||
feature, err := extractFeature(r)
|
||||
if err != nil {
|
||||
response.JSONError(logger, w, http.StatusBadRequest, err)
|
||||
feature := featurelimiter.GetFeature(r.Context())
|
||||
if feature == "" {
|
||||
response.JSONError(logger, w, http.StatusBadRequest, errors.New("no feature provided"))
|
||||
return
|
||||
}
|
||||
|
||||
|
||||
@ -4,7 +4,6 @@ go_library(
|
||||
name = "embeddings",
|
||||
srcs = [
|
||||
"handler.go",
|
||||
"limiter.go",
|
||||
"models.go",
|
||||
"openai.go",
|
||||
],
|
||||
@ -13,6 +12,7 @@ go_library(
|
||||
deps = [
|
||||
"//enterprise/cmd/cody-gateway/internal/actor",
|
||||
"//enterprise/cmd/cody-gateway/internal/events",
|
||||
"//enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter",
|
||||
"//enterprise/cmd/cody-gateway/internal/limiter",
|
||||
"//enterprise/cmd/cody-gateway/internal/notify",
|
||||
"//enterprise/cmd/cody-gateway/internal/response",
|
||||
|
||||
@ -13,6 +13,7 @@ import (
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/events"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/limiter"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/notify"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/response"
|
||||
@ -33,11 +34,12 @@ func NewHandler(
|
||||
) http.Handler {
|
||||
baseLogger = baseLogger.Scoped("embeddingshandler", "The HTTP API handler for the embeddings endpoint.")
|
||||
|
||||
return rateLimit(
|
||||
return featurelimiter.HandleFeature(
|
||||
baseLogger,
|
||||
eventLogger,
|
||||
limiter.NewPrefixRedisStore("rate_limit:", rs),
|
||||
rateLimitNotifier,
|
||||
codygateway.FeatureEmbeddings,
|
||||
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
act := actor.FromContext(r.Context())
|
||||
logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger))
|
||||
|
||||
@ -1,88 +0,0 @@
|
||||
package embeddings
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/sourcegraph/log"
|
||||
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/actor"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/events"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/limiter"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/notify"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/response"
|
||||
"github.com/sourcegraph/sourcegraph/enterprise/internal/codygateway"
|
||||
sgtrace "github.com/sourcegraph/sourcegraph/internal/trace"
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
func rateLimit(
|
||||
baseLogger log.Logger,
|
||||
eventLogger events.Logger,
|
||||
cache limiter.RedisStore,
|
||||
rateLimitNotifier notify.RateLimitNotifier,
|
||||
next http.Handler,
|
||||
) http.Handler {
|
||||
baseLogger = baseLogger.Scoped("rateLimit", "rate limit handler")
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
act := actor.FromContext(r.Context())
|
||||
logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger))
|
||||
|
||||
l, ok := act.Limiter(logger, cache, codygateway.FeatureEmbeddings, rateLimitNotifier)
|
||||
if !ok {
|
||||
response.JSONError(logger, w, http.StatusForbidden, errors.New("no access to embeddings"))
|
||||
return
|
||||
}
|
||||
|
||||
commit, err := l.TryAcquire(r.Context())
|
||||
if err != nil {
|
||||
if loggerErr := eventLogger.LogEvent(
|
||||
r.Context(),
|
||||
events.Event{
|
||||
Name: codygateway.EventNameRateLimited,
|
||||
Source: act.Source.Name(),
|
||||
Identifier: act.ID,
|
||||
Metadata: map[string]any{
|
||||
"error": err.Error(),
|
||||
codygateway.CompletionsEventFeatureMetadataField: "embeddings",
|
||||
},
|
||||
},
|
||||
); loggerErr != nil {
|
||||
logger.Error("failed to log event", log.Error(loggerErr))
|
||||
}
|
||||
|
||||
var rateLimitExceeded limiter.RateLimitExceededError
|
||||
if errors.As(err, &rateLimitExceeded) {
|
||||
rateLimitExceeded.WriteResponse(w)
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, limiter.NoAccessError{}) {
|
||||
response.JSONError(logger, w, http.StatusForbidden, err)
|
||||
return
|
||||
}
|
||||
|
||||
response.JSONError(logger, w, http.StatusInternalServerError, err)
|
||||
return
|
||||
}
|
||||
|
||||
responseRecorder := response.NewStatusHeaderRecorder(w)
|
||||
next.ServeHTTP(responseRecorder, r)
|
||||
|
||||
// If response is healthy, consume the rate limit
|
||||
if responseRecorder.StatusCode >= 200 && responseRecorder.StatusCode < 300 {
|
||||
uh := w.Header().Get(usageHeaderName)
|
||||
if uh == "" {
|
||||
logger.Error("no usage header set on response")
|
||||
}
|
||||
usage, err := strconv.Atoi(uh)
|
||||
if err != nil {
|
||||
logger.Error("failed to parse usage header as number", log.Error(err))
|
||||
}
|
||||
if err := commit(usage); err != nil {
|
||||
logger.Error("failed to commit rate limit consumption", log.Error(err))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
20
enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter/BUILD.bazel
generated
Normal file
20
enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter/BUILD.bazel
generated
Normal file
@ -0,0 +1,20 @@
|
||||
load("@io_bazel_rules_go//go:def.bzl", "go_library")
|
||||
|
||||
go_library(
|
||||
name = "featurelimiter",
|
||||
srcs = ["featurelimiter.go"],
|
||||
importpath = "github.com/sourcegraph/sourcegraph/enterprise/cmd/cody-gateway/internal/httpapi/featurelimiter",
|
||||
visibility = ["//enterprise/cmd/cody-gateway:__subpackages__"],
|
||||
deps = [
|
||||
"//enterprise/cmd/cody-gateway/internal/actor",
|
||||
"//enterprise/cmd/cody-gateway/internal/events",
|
||||
"//enterprise/cmd/cody-gateway/internal/limiter",
|
||||
"//enterprise/cmd/cody-gateway/internal/notify",
|
||||
"//enterprise/cmd/cody-gateway/internal/response",
|
||||
"//enterprise/internal/codygateway",
|
||||
"//enterprise/internal/completions/types",
|
||||
"//internal/trace",
|
||||
"//lib/errors",
|
||||
"@com_github_sourcegraph_log//:log",
|
||||
],
|
||||
)
|
||||
@ -1,6 +1,7 @@
|
||||
package completions
|
||||
package featurelimiter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
@ -17,24 +18,67 @@ import (
|
||||
"github.com/sourcegraph/sourcegraph/lib/errors"
|
||||
)
|
||||
|
||||
func rateLimit(
|
||||
type contextKey string
|
||||
|
||||
const contextKeyFeature contextKey = "feature"
|
||||
|
||||
// GetFeature gets the feature used by Handle or HandleFeature.
|
||||
func GetFeature(ctx context.Context) codygateway.Feature {
|
||||
if f, ok := ctx.Value(contextKeyFeature).(codygateway.Feature); ok {
|
||||
return f
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// Handle extracts features from codygateway.FeatureHeaderName and uses it to
|
||||
// determine the appropriate per-feature rate limits applied for an actor.
|
||||
func Handle(
|
||||
baseLogger log.Logger,
|
||||
eventLogger events.Logger,
|
||||
cache limiter.RedisStore,
|
||||
rateLimitNotifier notify.RateLimitNotifier,
|
||||
next http.Handler,
|
||||
) http.Handler {
|
||||
baseLogger = baseLogger.Scoped("rateLimit", "rate limit handler")
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
feature, err := extractFeature(r)
|
||||
if err != nil {
|
||||
response.JSONError(baseLogger, w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
|
||||
HandleFeature(baseLogger, eventLogger, cache, rateLimitNotifier, feature, next).
|
||||
ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func extractFeature(r *http.Request) (codygateway.Feature, error) {
|
||||
h := strings.TrimSpace(r.Header.Get(codygateway.FeatureHeaderName))
|
||||
if h == "" {
|
||||
return "", errors.Newf("%s header is required", codygateway.FeatureHeaderName)
|
||||
}
|
||||
feature := types.CompletionsFeature(h)
|
||||
if !feature.IsValid() {
|
||||
return "", errors.Newf("invalid value for %s", codygateway.FeatureHeaderName)
|
||||
}
|
||||
// codygateway.Feature and types.CompletionsFeature map 1:1 for completions.
|
||||
return codygateway.Feature(feature), nil
|
||||
}
|
||||
|
||||
// Handle uses a predefined feature to determine the appropriate per-feature
|
||||
// rate limits applied for an actor.
|
||||
func HandleFeature(
|
||||
baseLogger log.Logger,
|
||||
eventLogger events.Logger,
|
||||
cache limiter.RedisStore,
|
||||
rateLimitNotifier notify.RateLimitNotifier,
|
||||
feature codygateway.Feature,
|
||||
next http.Handler,
|
||||
) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
act := actor.FromContext(r.Context())
|
||||
logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger))
|
||||
|
||||
feature, err := extractFeature(r)
|
||||
if err != nil {
|
||||
response.JSONError(logger, w, http.StatusBadRequest, err)
|
||||
return
|
||||
}
|
||||
r = r.WithContext(context.WithValue(r.Context(), contextKeyFeature, feature))
|
||||
|
||||
l, ok := act.Limiter(logger, cache, feature, rateLimitNotifier)
|
||||
if !ok {
|
||||
@ -96,16 +140,3 @@ func rateLimit(
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func extractFeature(r *http.Request) (codygateway.Feature, error) {
|
||||
h := strings.TrimSpace(r.Header.Get(codygateway.FeatureHeaderName))
|
||||
if h == "" {
|
||||
return "", errors.Newf("%s header is required", codygateway.FeatureHeaderName)
|
||||
}
|
||||
feature := types.CompletionsFeature(h)
|
||||
if !feature.IsValid() {
|
||||
return "", errors.Newf("invalid value for %s", codygateway.FeatureHeaderName)
|
||||
}
|
||||
// codygateway.Feature and types.CompletionsFeature map 1:1 for completions.
|
||||
return codygateway.Feature(feature), nil
|
||||
}
|
||||
@ -47,7 +47,7 @@ func TestErrStatusNotOK(t *testing.T) {
|
||||
assert.Equal(t, resp.Header, writtenResp.Header)
|
||||
|
||||
// Should not have written the response body.
|
||||
writtenBody, err := io.ReadAll(resp.Body)
|
||||
writtenBody, err := io.ReadAll(writtenResp.Body)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, writtenBody)
|
||||
})
|
||||
|
||||
Loading…
Reference in New Issue
Block a user