cody-gateway: remove SyncOnce, use existing SourceUpdater (#59102)

Follow-up on https://github.com/sourcegraph/sourcegraph/pull/58732#discussion_r1414348506. While re-reading the code this morning after looking at some discussions, I noticed we can simplify this: this change removes the `SourceSingleSyncer` interface (and its `SyncOne` method) and uses our existing `SourceUpdater` instead.

There is more nuance here than I remembered, however: `Actor.Update` is also called when a user hits their rate limit. This could be a good thing (we poll for an update whenever a user is rate-limited) - if we increase a user's rate limit, the change gets picked up quickly. Because `Actor.Update` can therefore be called frequently, if the actor was recently updated we no-op and return an error instead of updating again. This could be useful in the general refresh endpoint use case as well, to avoid something repeatedly hitting dotcom in quick succession.

## Test plan

Updated tests assert same behaviour as existing implementation
This commit is contained in:
Robert Lin 2023-12-21 11:43:27 -08:00 committed by GitHub
parent d59dddf444
commit 9ce3f1858e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 97 additions and 75 deletions

View File

@ -113,10 +113,11 @@ func (a *Actor) Logger(logger log.Logger) log.Logger {
// does not necessarily occur on every call.
//
// If the actor has no source, this is a no-op.
func (a *Actor) Update(ctx context.Context) {
func (a *Actor) Update(ctx context.Context) error {
if su, ok := a.Source.(SourceUpdater); ok && su != nil {
su.Update(ctx, a)
return su.Update(ctx, a)
}
return nil
}
func (a *Actor) TraceAttributes() []attribute.KeyValue {
@ -197,7 +198,8 @@ func (a *Actor) Limiter(
concurrentInterval: limit.ConcurrentRequestsInterval,
nextLimiter: updateOnErrorLimiter{
actor: a,
logger: logger.Scoped("updateOnError"),
actor: a,
nextLimiter: baseLimiter,
},

View File

@ -10,6 +10,7 @@ import (
"github.com/Khan/genqlient/graphql"
graphqltypes "github.com/graph-gophers/graphql-go"
"github.com/graph-gophers/graphql-go/relay"
"github.com/sourcegraph/sourcegraph/internal/accesstoken"
"github.com/gregjones/httpcache"
@ -33,6 +34,10 @@ const tokenLength = 4 + 64
var (
defaultUpdateInterval = 15 * time.Minute
// defaultRefreshInterval is the minimum time between updates of a single
// actor via Update. Update is also called when a user's rate limit is hit,
// so we don't want to update every time. We use a shorter interval than
// defaultUpdateInterval in this case.
defaultRefreshInterval = 5 * time.Minute
)
type Source struct {
@ -42,7 +47,7 @@ type Source struct {
concurrencyConfig codygateway.ActorConcurrencyLimitConfig
}
var _ actor.SourceSingleSyncer = &Source{}
var _ actor.SourceUpdater = &Source{}
func NewSource(logger log.Logger, cache httpcache.Cache, dotComClient graphql.Client, concurrencyConfig codygateway.ActorConcurrencyLimitConfig) *Source {
return &Source{
@ -59,8 +64,14 @@ func (s *Source) Get(ctx context.Context, token string) (*actor.Actor, error) {
return s.get(ctx, token, false)
}
func (s *Source) SyncOne(ctx context.Context, token string) error {
_, err := s.get(ctx, token, true)
// Update refreshes the given actor from this source by calling s.get with the
// actor's key.
//
// If act.LastUpdated is set and less than defaultRefreshInterval has elapsed
// since then, no refresh is attempted and actor.ErrActorRecentlyUpdated is
// returned, with RetryAt set to the end of that interval. An actor with a nil
// LastUpdated is never throttled. Otherwise the error from s.get (if any) is
// returned.
func (s *Source) Update(ctx context.Context, act *actor.Actor) error {
if act.LastUpdated != nil && time.Since(*act.LastUpdated) < defaultRefreshInterval {
return actor.ErrActorRecentlyUpdated{
RetryAt: act.LastUpdated.Add(defaultRefreshInterval),
}
}
// NOTE(review): the `true` argument presumably forces a fresh fetch that
// bypasses the cache (Get passes false) - confirm against s.get.
_, err := s.get(ctx, act.Key, true)
return err
}

View File

@ -138,17 +138,28 @@ func (e ErrConcurrencyLimitExceeded) WriteResponse(w http.ResponseWriter) {
// updateOnErrorLimiter calls Actor.Update if nextLimiter responds with certain
// access errors.
type updateOnErrorLimiter struct {
actor *Actor
logger log.Logger
actor *Actor
nextLimiter limiter.Limiter
}
func (u updateOnErrorLimiter) TryAcquire(ctx context.Context) (func(context.Context, int) error, error) {
commit, err := u.nextLimiter.TryAcquire(ctx)
// If we have an access issue, try to update the actor in case they have
// been granted updated access.
if errors.As(err, &limiter.NoAccessError{}) || errors.As(err, &limiter.RateLimitExceededError{}) {
oteltrace.SpanFromContext(ctx).
SetAttributes(attribute.Bool("update-on-error", true))
u.actor.Update(ctx) // TODO: run this in goroutine+background context maybe?
// Do update transiently, outside request hotpath
go func() {
if updateErr := u.actor.Update(context.WithoutCancel(ctx)); updateErr != nil &&
!IsErrActorRecentlyUpdated(updateErr) {
u.logger.Warn("unexpected error updating actor",
log.Error(updateErr),
log.NamedError("originalError", err))
}
}()
}
return commit, err
}

View File

@ -169,8 +169,7 @@ func TestConcurrencyLimiter_TryAcquire(t *testing.T) {
}
func TestAsErrConcurrencyLimitExceeded(t *testing.T) {
var err error
err = ErrConcurrencyLimitExceeded{}
var err error = ErrConcurrencyLimitExceeded{}
assert.True(t, errors.As(err, &ErrConcurrencyLimitExceeded{}))
assert.True(t, errors.As(errors.Wrap(err, "foo"), &ErrConcurrencyLimitExceeded{}))
}

View File

@ -113,15 +113,16 @@ func (s *Source) Get(ctx context.Context, token string) (*actor.Actor, error) {
return act, nil
}
func (s *Source) Update(ctx context.Context, actor *actor.Actor) {
if time.Since(*actor.LastUpdated) < minUpdateInterval {
func (s *Source) Update(ctx context.Context, act *actor.Actor) error {
if time.Since(*act.LastUpdated) < minUpdateInterval {
// Last update was too recent - do it later.
return
return actor.ErrActorRecentlyUpdated{
RetryAt: act.LastUpdated.Add(minUpdateInterval),
}
}
if _, err := s.fetchAndCache(ctx, actor.Key); err != nil {
sgtrace.Logger(ctx, s.log).Info("failed to update actor", log.Error(err))
}
_, err := s.fetchAndCache(ctx, act.Key)
return err
}
// Sync retrieves all known actors from this source and updates its cache.

View File

@ -7,14 +7,13 @@ import (
"github.com/go-redsync/redsync/v4"
"github.com/sourcegraph/conc/pool"
"github.com/sourcegraph/sourcegraph/internal/codygateway"
"github.com/sourcegraph/log"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/codes"
"go.opentelemetry.io/otel/trace"
"github.com/sourcegraph/log"
"github.com/sourcegraph/sourcegraph/internal/codygateway"
"github.com/sourcegraph/sourcegraph/internal/goroutine"
"github.com/sourcegraph/sourcegraph/internal/observation"
sgtrace "github.com/sourcegraph/sourcegraph/internal/trace"
@ -45,11 +44,27 @@ type Source interface {
Get(ctx context.Context, token string) (*Actor, error)
}
// ErrActorRecentlyUpdated can be used to indicate that an actor cannot be
// updated because it was already updated more recently than allowed by a
// Source implementation.
type ErrActorRecentlyUpdated struct {
	// RetryAt is the time from which another update attempt is allowed.
	RetryAt time.Time
}

// Error implements the error interface. The message reports how long remains
// until RetryAt, truncated to whole seconds.
func (e ErrActorRecentlyUpdated) Error() string {
	wait := time.Until(e.RetryAt).Truncate(time.Second)
	// RetryAt may already have passed by the time the message is rendered;
	// never tell the caller to retry in a negative amount of time.
	if wait < 0 {
		wait = 0
	}
	return fmt.Sprintf("actor was recently updated - try again in %s",
		wait.String())
}

// IsErrActorRecentlyUpdated reports whether err, or an error it wraps, is an
// ErrActorRecentlyUpdated.
func IsErrActorRecentlyUpdated(err error) bool { return errors.As(err, &ErrActorRecentlyUpdated{}) }
type SourceUpdater interface {
Source
// Update updates the given actor's state, though the implementation may
// decide not to do so every time.
Update(ctx context.Context, actor *Actor)
//
// Error can be ErrActorRecentlyUpdated if the actor was updated too recently.
Update(ctx context.Context, actor *Actor) error
}
type SourceSyncer interface {
@ -61,12 +76,6 @@ type SourceSyncer interface {
Sync(ctx context.Context) (int, error)
}
type SourceSingleSyncer interface {
Source
// SyncOne retrieves a single actor from this source and updates its cache.
SyncOne(ctx context.Context, token string) error
}
type Sources struct{ sources []Source }
func NewSources(sources ...Source) *Sources {
@ -153,23 +162,6 @@ func (s *Sources) SyncAll(ctx context.Context, logger log.Logger) error {
return nil
}
// SyncOne immediately runs a sync on the source implementing SourceSingleSyncer that can sync for a given token.
// Syncing is done sequentially, first error is returned - this mirrors the behaviour of Source.Get()
//
// By default, this is only used by "/v1/limits/refresh" endpoint.
func (s *Sources) SyncOne(ctx context.Context, token string) error {
for _, src := range s.sources {
if src, ok := src.(SourceSingleSyncer); ok {
err := src.SyncOne(ctx, token)
if err != nil {
return errors.Wrapf(err, "failed to sync %s", src.Name())
}
return nil
}
}
return errors.Newf("no source found for token %v", token[:4])
}
// Worker is a goroutine.BackgroundRoutine that runs any SourceSyncer implementations
// at a regular interval. It uses a redsync.Mutex to ensure only one worker is running
// at a time.

View File

@ -26,14 +26,8 @@ type mockSourceSyncer struct {
syncCount atomic.Int32
}
type mockSourceSingleSyncer struct {
mockSourceSyncer
}
var _ SourceSyncer = &mockSourceSyncer{}
var _ SourceSingleSyncer = &mockSourceSingleSyncer{}
func (m *mockSourceSyncer) Name() string { return "mock" }
func (m *mockSourceSyncer) Get(context.Context, string) (*Actor, error) {
@ -45,7 +39,13 @@ func (m *mockSourceSyncer) Sync(context.Context) (int, error) {
return 10, nil
}
func (m *mockSourceSingleSyncer) SyncOne(_ context.Context, _ string) error {
type mockSourceUpdater struct {
mockSourceSyncer
}
var _ SourceUpdater = &mockSourceUpdater{}
func (m *mockSourceUpdater) Update(context.Context, *Actor) error {
m.syncCount.Inc()
return nil
}
@ -141,29 +141,31 @@ func TestSourcesSyncAll(t *testing.T) {
assert.Equal(t, int32(2), s2.syncCount.Load())
}
func TestSourcesSyncOne(t *testing.T) {
func TestSourcesUpdate(t *testing.T) {
t.Parallel()
var s1 mockSourceSyncer
var s2 mockSourceSingleSyncer
var s3 mockSourceSingleSyncer
sources := NewSources(&s1, &s2, &s3)
err := sources.SyncOne(context.Background(), "sgd_qweqweqw")
require.NoError(t, err)
var s2 mockSourceUpdater
var s3 mockSourceUpdater
act := Actor{
Key: "sgd_qweqweqw",
Source: &s2, // belongs to s2 source only
}
err := act.Update(context.Background())
assert.NoError(t, err)
assert.Equal(t, int32(0), s1.syncCount.Load())
assert.Equal(t, int32(1), s2.syncCount.Load())
assert.Equal(t, int32(0), s3.syncCount.Load())
err = sources.SyncOne(context.Background(), "sgd_qweqweqw")
require.NoError(t, err)
err = act.Update(context.Background())
assert.NoError(t, err)
assert.Equal(t, int32(0), s1.syncCount.Load())
assert.Equal(t, int32(2), s2.syncCount.Load())
assert.Equal(t, int32(0), s3.syncCount.Load())
}
func TestIsErrNotFromSource(t *testing.T) {
var err error
err = ErrNotFromSource{Reason: "foo"}
var err error = ErrNotFromSource{Reason: "foo"}
assert.True(t, IsErrNotFromSource(err))
autogold.Expect("token not from source: foo").Equal(t, err.Error())
@ -174,3 +176,9 @@ func TestIsErrNotFromSource(t *testing.T) {
err = errors.New("foo")
assert.False(t, IsErrNotFromSource(err))
}
// TestErrActorRecentlyUpdated checks that an ErrActorRecentlyUpdated value is
// recognised by IsErrActorRecentlyUpdated and that its message reports the
// remaining wait until RetryAt.
func TestErrActorRecentlyUpdated(t *testing.T) {
var err error = ErrActorRecentlyUpdated{RetryAt: time.Now().Add(time.Minute)}
assert.True(t, IsErrActorRecentlyUpdated(err))
// RetryAt is one minute after construction, so slightly less than a minute
// remains when Error() runs; truncating to whole seconds yields "59s".
// NOTE(review): this relies on Error() being called within one second of
// constructing the error.
assert.Equal(t, "actor was recently updated - try again in 59s", err.Error())
}

View File

@ -11,7 +11,6 @@ go_library(
"//cmd/cody-gateway/internal/limiter",
"//cmd/cody-gateway/internal/notify",
"//cmd/cody-gateway/internal/response",
"//internal/authbearer",
"//internal/codygateway",
"//internal/completions/types",
"//internal/trace",

View File

@ -18,6 +18,7 @@ import (
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/response"
"github.com/sourcegraph/sourcegraph/internal/codygateway"
"github.com/sourcegraph/sourcegraph/internal/completions/types"
"github.com/sourcegraph/sourcegraph/internal/trace"
sgtrace "github.com/sourcegraph/sourcegraph/internal/trace"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
@ -211,20 +212,21 @@ func ListLimitsHandler(baseLogger log.Logger, redisStore limiter.RedisStore) htt
})
}
func RefreshLimitsHandler(baseLogger log.Logger, sources *actor.Sources) http.Handler {
func RefreshLimitsHandler(baseLogger log.Logger) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
act := actor.FromContext(r.Context())
logger := act.Logger(sgtrace.Logger(r.Context(), baseLogger))
token, err := authbearer.ExtractBearer(r.Header)
if err != nil {
response.JSONError(logger, w, http.StatusBadRequest, err)
if err := act.Update(r.Context()); err != nil {
logger := act.Logger(trace.Logger(r.Context(), baseLogger))
if actor.IsErrActorRecentlyUpdated(err) {
response.JSONError(logger, w, http.StatusTooManyRequests, err)
} else {
response.JSONError(logger, w, http.StatusInternalServerError, err)
}
return
}
err = sources.SyncOne(r.Context(), token)
if err != nil {
response.JSONError(logger, w, http.StatusBadRequest, err)
}
w.WriteHeader(http.StatusOK)
})
}

View File

@ -4,8 +4,6 @@ import (
"context"
"net/http"
"github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/actor"
"github.com/gorilla/mux"
"github.com/sourcegraph/log"
"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
@ -61,7 +59,6 @@ func NewHandler(
authr *auth.Authenticator,
promptRecorder completions.PromptRecorder,
config *Config,
sources *actor.Sources,
) (http.Handler, error) {
// Initialize metrics
counter, err := meter.Int64UpDownCounter("cody-gateway.concurrent_upstream_requests",
@ -227,7 +224,7 @@ func NewHandler(
authr.Middleware(
requestlogger.Middleware(
logger,
featurelimiter.RefreshLimitsHandler(logger, sources),
featurelimiter.RefreshLimitsHandler(logger),
),
),
otelhttp.WithPublicEndpoint(),

View File

@ -171,7 +171,7 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu
FireworksDisableSingleTenant: config.Fireworks.DisableSingleTenant,
EmbeddingsAllowedModels: config.AllowedEmbeddingsModels,
AutoFlushStreamingResponses: config.AutoFlushStreamingResponses,
}, sources)
})
if err != nil {
return errors.Wrap(err, "httpapi.NewHandler")
}