diff --git a/cmd/cody-gateway/shared/BUILD.bazel b/cmd/cody-gateway/shared/BUILD.bazel index f7f40c57bae..580d071afb0 100644 --- a/cmd/cody-gateway/shared/BUILD.bazel +++ b/cmd/cody-gateway/shared/BUILD.bazel @@ -5,7 +5,6 @@ go_library( srcs = [ "main.go", "metrics.go", - "redis.go", "service.go", "tracing.go", ], @@ -33,7 +32,6 @@ go_library( "//internal/goroutine", "//internal/httpcli", "//internal/httpserver", - "//internal/lazyregexp", "//internal/observation", "//internal/rcache", "//internal/redispool", diff --git a/cmd/cody-gateway/shared/main.go b/cmd/cody-gateway/shared/main.go index e742511116d..ba1d8ad7910 100644 --- a/cmd/cody-gateway/shared/main.go +++ b/cmd/cody-gateway/shared/main.go @@ -98,8 +98,6 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu } } - redisPool := connectToRedis(cfg.RedisEndpoint) - // Create an uncached external doer, we never want to cache any responses. // Not only is the cache hit rate going to be really low and requests large-ish, // but also do we not want to retain any data. @@ -115,7 +113,10 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu return errors.Wrap(err, "init metric 'redis_latency'") } - redisCache := redispool.RedisKeyValue(redisPool).WithLatencyRecorder(func(call string, latency time.Duration, err error) { + redisCache := redispool.NewKeyValue(cfg.RedisEndpoint, &redis.Pool{ + MaxIdle: 10, + IdleTimeout: 240 * time.Second, + }).WithLatencyRecorder(func(call string, latency time.Duration, err error) { redisLatency.Record(context.Background(), latency.Milliseconds(), metric.WithAttributeSet(attribute.NewSet( attribute.Bool("error", err != nil), attribute.String("command", call)))) @@ -240,7 +241,7 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu return errors.Wrap(err, "httpapi.NewHandler") } // Diagnostic and Maintenance layers, exposing additional APIs and endpoints. - handler = httpapi.NewDiagnosticsHandler(obctx.Logger, handler, redisPool, cfg.DiagnosticsSecret, sources) + handler = httpapi.NewDiagnosticsHandler(obctx.Logger, handler, redisCache.Pool(), cfg.DiagnosticsSecret, sources) handler = httpapi.NewMaintenanceHandler(obctx.Logger, handler, cfg, redisCache) // Collect request client for downstream handlers. Outside of dev, we always set up @@ -257,7 +258,7 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu }) // Set up redis-based distributed mutex for the source syncer worker - sourceWorkerMutex := redsync.New(redigo.NewPool(redisPool)).NewMutex("source-syncer-worker", + sourceWorkerMutex := redsync.New(redigo.NewPool(redisCache.Pool())).NewMutex("source-syncer-worker", // Do not retry endlessly becuase it's very likely that someone else has // a long-standing hold on the mutex. We will try again on the next periodic // goroutine run. diff --git a/cmd/cody-gateway/shared/redis.go b/cmd/cody-gateway/shared/redis.go deleted file mode 100644 index a57c5fa59a6..00000000000 --- a/cmd/cody-gateway/shared/redis.go +++ /dev/null @@ -1,41 +0,0 @@ -package shared - -import ( - "strings" - "time" - - "github.com/gomodule/redigo/redis" - - "github.com/sourcegraph/sourcegraph/internal/lazyregexp" - "github.com/sourcegraph/sourcegraph/lib/errors" -) - -var schemeMatcher = lazyregexp.New(`^[A-Za-z][A-Za-z0-9\+\-\.]*://`) - -// connectToRedis connects to Redis given the raw endpoint string. -// Cody Gateway maintains its own pool of Redis connections, it should not be dependent -// on the sourcegraph deployment Redis dualism. -// -// The string can have two formats: -// 1. If there is a HTTP scheme, it should be either be "redis://" or "rediss://" and the URL -// must be of the format specified in https://www.iana.org/assignments/uri-schemes/prov/redis. -// 2. Otherwise, it is assumed to be of the format $HOSTNAME:$PORT. -func connectToRedis(endpoint string) *redis.Pool { - return &redis.Pool{ - MaxIdle: 10, - IdleTimeout: 240 * time.Second, - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, - Dial: func() (redis.Conn, error) { - if schemeMatcher.MatchString(endpoint) { // expect "redis://" - return redis.DialURL(endpoint) - } - if strings.Contains(endpoint, "/") { - return nil, errors.New("Redis endpoint without scheme should not contain '/'") - } - return redis.Dial("tcp", endpoint) - }, - } -} diff --git a/cmd/worker/internal/ratelimit/BUILD.bazel b/cmd/worker/internal/ratelimit/BUILD.bazel index ad708e6deb8..0ed2cf034c1 100644 --- a/cmd/worker/internal/ratelimit/BUILD.bazel +++ b/cmd/worker/internal/ratelimit/BUILD.bazel @@ -47,7 +47,6 @@ go_test( "//internal/types", "//lib/pointers", "//schema", - "@com_github_gomodule_redigo//redis", "@com_github_google_go_cmp//cmp", "@com_github_sourcegraph_log//logtest", "@com_github_stretchr_testify//assert", diff --git a/cmd/worker/internal/ratelimit/handler_test.go b/cmd/worker/internal/ratelimit/handler_test.go index 6fc04f2fbc9..a9aba5529aa 100644 --- a/cmd/worker/internal/ratelimit/handler_test.go +++ b/cmd/worker/internal/ratelimit/handler_test.go @@ -5,7 +5,6 @@ import ( "testing" "time" - "github.com/gomodule/redigo/redis" "github.com/google/go-cmp/cmp" "github.com/sourcegraph/log/logtest" "github.com/stretchr/testify/assert" @@ -28,22 +27,10 @@ func TestHandler_Handle(t *testing.T) { db := database.NewDB(logger, dbtest.NewDB(t)) prefix := "__test__" + t.Name() - redisHost := "127.0.0.1:6379" - - pool := &redis.Pool{ - MaxIdle: 3, - IdleTimeout: 240 * time.Second, - Dial: func() (redis.Conn, error) { - return redis.Dial("tcp", redisHost) - }, - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, - } + kv := redispool.NewTestKeyValue() t.Cleanup(func() { - if err := redispool.DeleteAllKeysWithPrefix(redispool.RedisKeyValue(pool), prefix); err != nil { + if err := redispool.DeleteAllKeysWithPrefix(kv, prefix); err != nil { t.Logf("Failed to clear redis: %+v\n", err) } }) @@ -69,14 +56,14 @@ func TestHandler_Handle(t *testing.T) { h := handler{ externalServiceStore: db.ExternalServices(), newRateLimiterFunc: func(bucketName string) ratelimit.GlobalLimiter { - return ratelimit.NewTestGlobalRateLimiter(pool, prefix, bucketName) + return ratelimit.NewTestGlobalRateLimiter(kv.Pool(), prefix, bucketName) }, logger: logger, } err = h.Handle(ctx) assert.NoError(t, err) - info, err := ratelimit.GetGlobalLimiterStateFromPool(ctx, pool, prefix) + info, err := ratelimit.GetGlobalLimiterStateFromStore(kv, prefix) require.NoError(t, err) if diff := cmp.Diff(map[string]ratelimit.GlobalLimiterInfo{ diff --git a/internal/featureflag/BUILD.bazel b/internal/featureflag/BUILD.bazel index d8ca6f76a92..23c7cdf6ac6 100644 --- a/internal/featureflag/BUILD.bazel +++ b/internal/featureflag/BUILD.bazel @@ -35,7 +35,6 @@ go_test( "//internal/redispool", "//lib/errors", "@com_github_derision_test_go_mockgen_v2//testutil/require", - "@com_github_gomodule_redigo//redis", "@com_github_google_go_cmp//cmp", "@com_github_rafaeljusto_redigomock_v3//:redigomock", "@com_github_stretchr_testify//require", diff --git a/internal/featureflag/middleware_test.go b/internal/featureflag/middleware_test.go index 07530f795e2..7c9f0fe0333 100644 --- a/internal/featureflag/middleware_test.go +++ b/internal/featureflag/middleware_test.go @@ -7,7 +7,6 @@ import ( "testing" mockrequire "github.com/derision-test/go-mockgen/v2/testutil/require" - "github.com/gomodule/redigo/redis" "github.com/rafaeljusto/redigomock/v3" "github.com/stretchr/testify/require" @@ -173,19 +172,17 @@ func setupRedisTest(t *testing.T) { t.Cleanup(func() { mockConn.Clear(); mockConn.Close() }) - mockConn.GenericCommand("HSET").Handle(func(args []interface{}) (interface{}, error) { - cache[args[0].(string)] = []byte(args[2].(string)) - return nil, nil + mockStore := redispool.NewMockKeyValue() + mockStore.HSetFunc.SetDefaultHook(func(key string, field string, value any) error { + cache[key] = []byte(value.(string)) + return nil }) - - mockConn.GenericCommand("HGET").Handle(func(args []interface{}) (interface{}, error) { - return cache[args[0].(string)], nil + mockStore.HGetFunc.SetDefaultHook(func(key string, field string) redispool.Value { + return redispool.NewValue(cache[key], nil) }) - - mockConn.GenericCommand("DEL").Handle(func(args []interface{}) (interface{}, error) { - delete(cache, args[0].(string)) - return nil, nil + mockStore.DelFunc.SetDefaultHook(func(key string) error { + delete(cache, key) + return nil }) - - evalStore = redispool.RedisKeyValue(&redis.Pool{Dial: func() (redis.Conn, error) { return mockConn, nil }, MaxIdle: 10}) + evalStore = mockStore } diff --git a/internal/ratelimit/globallimiter.go b/internal/ratelimit/globallimiter.go index ef7c8fdf846..c0df5597398 100644 --- a/internal/ratelimit/globallimiter.go +++ b/internal/ratelimit/globallimiter.go @@ -379,38 +379,21 @@ type GlobalLimiterInfo struct { // GetGlobalLimiterState reports how all the existing rate limiters are configured, // keyed by bucket name. func GetGlobalLimiterState(ctx context.Context) (map[string]GlobalLimiterInfo, error) { - return GetGlobalLimiterStateFromPool(ctx, kv().Pool(), tokenBucketGlobalPrefix) + return GetGlobalLimiterStateFromStore(kv(), tokenBucketGlobalPrefix) } -func GetGlobalLimiterStateFromPool(ctx context.Context, pool *redis.Pool, prefix string) (map[string]GlobalLimiterInfo, error) { - conn, err := pool.GetContext(ctx) - if err != nil { - return nil, errors.Wrap(err, "failed to get connection") - } - defer conn.Close() - +func GetGlobalLimiterStateFromStore(rstore redispool.KeyValue, prefix string) (map[string]GlobalLimiterInfo, error) { // First, find all known limiters in redis. - resp, err := conn.Do("KEYS", fmt.Sprintf("%s:*:%s", prefix, bucketAllowedBurstKeySuffix)) + keys, err := rstore.Keys(fmt.Sprintf("%s:*:%s", prefix, bucketAllowedBurstKeySuffix)) if err != nil { return nil, errors.Wrap(err, "failed to list keys") } - keys, ok := resp.([]interface{}) - if !ok { - return nil, errors.Newf("invalid response from redis keys command, expected []interface{}, got %T", resp) - } m := make(map[string]GlobalLimiterInfo, len(keys)) - for _, k := range keys { - kchars, ok := k.([]uint8) - if !ok { - return nil, errors.Newf("invalid response from redis keys command, expected string, got %T", k) - } - key := string(kchars) + for _, key := range keys { limiterName := strings.TrimSuffix(strings.TrimPrefix(key, prefix+":"), ":"+bucketAllowedBurstKeySuffix) rlKeys := getRateLimiterKeys(prefix, limiterName) - rstore := redispool.RedisKeyValue(pool) - currentCapacity, err := rstore.Get(rlKeys.BucketKey).Int() if err != nil && err != redis.ErrNil { return nil, errors.Wrap(err, "failed to read current capacity") @@ -515,20 +498,10 @@ type TB interface { func SetupForTest(t TB) { t.Helper() - pool := &redis.Pool{ - MaxIdle: 3, - IdleTimeout: 240 * time.Second, - Dial: func() (redis.Conn, error) { - return redis.Dial("tcp", "127.0.0.1:6379") - }, - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, - } + kvMock = redispool.NewTestKeyValue() tokenBucketGlobalPrefix = "__test__" + t.Name() - c := pool.Get() + c := kvMock.Pool().Get() defer c.Close() // If we are not on CI, skip the test if our redis connection fails. @@ -539,7 +512,6 @@ func SetupForTest(t TB) { } } - kvMock = redispool.RedisKeyValue(pool) if err := redispool.DeleteAllKeysWithPrefix(kvMock, tokenBucketGlobalPrefix); err != nil { t.Fatalf("could not clear test prefix: &v", err) } diff --git a/internal/ratelimit/globallimiter_test.go b/internal/ratelimit/globallimiter_test.go index 2814be3cea3..e259d7bb0ee 100644 --- a/internal/ratelimit/globallimiter_test.go +++ b/internal/ratelimit/globallimiter_test.go @@ -24,8 +24,8 @@ func TestGlobalRateLimiter(t *testing.T) { // This test is verifying the basic functionality of the rate limiter. // We should be able to get a token once the token bucket config is set. prefix := "__test__" + t.Name() - pool := redisPoolForTest(t, prefix) - rl := getTestRateLimiter(prefix, pool, testBucketName) + kv := redisKeyValueForTest(t, prefix) + rl := getTestRateLimiter(prefix, kv.Pool(), testBucketName) clock := glock.NewMockClock() rl.nowFunc = clock.Now @@ -135,8 +135,8 @@ func TestGlobalRateLimiter_TimeToWaitExceedsLimit(t *testing.T) { // This test is verifying that if the amount of time needed to wait for a token // exceeds the context deadline, a TokenGrantExceedsLimitError is returned. prefix := "__test__" + t.Name() - pool := redisPoolForTest(t, prefix) - rl := getTestRateLimiter(prefix, pool, testBucketName) + kv := redisKeyValueForTest(t, prefix) + rl := getTestRateLimiter(prefix, kv.Pool(), testBucketName) clock := glock.NewMockClock() rl.nowFunc = clock.Now @@ -176,8 +176,8 @@ func TestGlobalRateLimiter_TimeToWaitExceedsLimit(t *testing.T) { func TestGlobalRateLimiter_AllBlockedError(t *testing.T) { // Verify that a limit of 0 means "block all". prefix := "__test__" + t.Name() - pool := redisPoolForTest(t, prefix) - rl := getTestRateLimiter(prefix, pool, testBucketName) + kv := redisKeyValueForTest(t, prefix) + rl := getTestRateLimiter(prefix, kv.Pool(), testBucketName) clock := glock.NewMockClock() rl.nowFunc = clock.Now @@ -205,8 +205,8 @@ func TestGlobalRateLimiter_AllBlockedError(t *testing.T) { func TestGlobalRateLimiter_Inf(t *testing.T) { // Verify that a rate of -1 means inf. prefix := "__test__" + t.Name() - pool := redisPoolForTest(t, prefix) - rl := getTestRateLimiter(prefix, pool, testBucketName) + kv := redisKeyValueForTest(t, prefix) + rl := getTestRateLimiter(prefix, kv.Pool(), testBucketName) clock := glock.NewMockClock() rl.nowFunc = clock.Now @@ -235,8 +235,8 @@ func TestGlobalRateLimiter_UnconfiguredLimiter(t *testing.T) { // This test is verifying the basic functionality of the rate limiter. // We should be able to get a token once the token bucket config is set. prefix := "__test__" + t.Name() - pool := redisPoolForTest(t, prefix) - rl := getTestRateLimiter(prefix, pool, testBucketName) + kv := redisKeyValueForTest(t, prefix) + rl := getTestRateLimiter(prefix, kv.Pool(), testBucketName) clock := glock.NewMockClock() rl.nowFunc = clock.Now @@ -303,32 +303,23 @@ func getTestRateLimiter(prefix string, pool *redis.Pool, bucketName string) glob // Mostly copy-pasta from rache. Will clean up later as the relationship // between the two packages becomes cleaner. -func redisPoolForTest(t *testing.T, prefix string) *redis.Pool { +func redisKeyValueForTest(t *testing.T, prefix string) redispool.KeyValue { t.Helper() - pool := &redis.Pool{ - MaxIdle: 3, - IdleTimeout: 240 * time.Second, - Dial: func() (redis.Conn, error) { - return redis.Dial("tcp", "127.0.0.1:6379") - }, - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, - } - - if err := redispool.DeleteAllKeysWithPrefix(redispool.RedisKeyValue(pool), prefix); err != nil { + store := redispool.NewTestKeyValue() + if err := redispool.DeleteAllKeysWithPrefix(store, prefix); err != nil { t.Logf("Could not clear test prefix name=%q prefix=%q error=%v", t.Name(), prefix, err) } - return pool + return store } func TestLimitInfo(t *testing.T) { ctx := context.Background() prefix := "__test__" + t.Name() - pool := redisPoolForTest(t, prefix) + + store := redisKeyValueForTest(t, prefix) + pool := store.Pool() r1 := getTestRateLimiter(prefix, pool, "extsvc:github:1") // 1/s allowed. @@ -340,7 +331,7 @@ func TestLimitInfo(t *testing.T) { // No requests allowed. require.NoError(t, r3.SetTokenBucketConfig(ctx, 0, time.Hour)) - info, err := GetGlobalLimiterStateFromPool(ctx, pool, prefix) + info, err := GetGlobalLimiterStateFromStore(store, prefix) require.NoError(t, err) if diff := cmp.Diff(map[string]GlobalLimiterInfo{ @@ -372,7 +363,7 @@ func TestLimitInfo(t *testing.T) { // Now claim 3 tokens from the limiter. require.NoError(t, r1.WaitN(ctx, 3)) - info, err = GetGlobalLimiterStateFromPool(ctx, pool, prefix) + info, err = GetGlobalLimiterStateFromStore(store, prefix) require.NoError(t, err) if diff := cmp.Diff(map[string]GlobalLimiterInfo{ diff --git a/internal/rcache/rcache.go b/internal/rcache/rcache.go index 10f0eb13974..d7b957c6a85 100644 --- a/internal/rcache/rcache.go +++ b/internal/rcache/rcache.go @@ -243,25 +243,14 @@ const testAddr = "127.0.0.1:6379" func SetupForTest(t testing.TB) redispool.KeyValue { t.Helper() - pool := &redis.Pool{ - MaxIdle: 3, - IdleTimeout: 240 * time.Second, - Dial: func() (redis.Conn, error) { - return redis.Dial("tcp", testAddr) - }, - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, - } - kvMock = redispool.RedisKeyValue(pool) + kvMock = redispool.NewTestKeyValue() t.Cleanup(func() { - pool.Close() + kvMock.Pool().Close() kvMock = nil }) globalPrefix = "__test__" + t.Name() - c := pool.Get() + c := kvMock.Pool().Get() defer c.Close() // If we are not on CI, skip the test if our redis connection fails. diff --git a/internal/redispool/keyvalue.go b/internal/redispool/keyvalue.go index 4b1fff0ed3b..f389521bb24 100644 --- a/internal/redispool/keyvalue.go +++ b/internal/redispool/keyvalue.go @@ -122,13 +122,6 @@ func (v Values) StringMap() (map[string]string, error) { type LatencyRecorder func(call string, latency time.Duration, err error) -type redisKeyValue struct { - pool *redis.Pool - ctx context.Context - prefix string - recorder *LatencyRecorder -} - // NewKeyValue returns a KeyValue for addr. // // poolOpts is a required argument which sets defaults in the case we connect @@ -141,20 +134,39 @@ func NewKeyValue(addr string, poolOpts *redis.Pool) KeyValue { poolOpts.Dial = func() (redis.Conn, error) { return dialRedis(addr) } - return RedisKeyValue(poolOpts) + return &redisKeyValue{pool: poolOpts} } -// RedisKeyValue returns a KeyValue backed by pool. +// NewTestKeyValue returns a KeyValue connected to a local Redis server for integration tests. +func NewTestKeyValue() KeyValue { + pool := &redis.Pool{ + MaxIdle: 3, + IdleTimeout: 240 * time.Second, + Dial: func() (redis.Conn, error) { + return redis.Dial("tcp", "127.0.0.1:6379") + }, + TestOnBorrow: func(c redis.Conn, t time.Time) error { + _, err := c.Do("PING") + return err + }, + } + return &redisKeyValue{pool: pool} +} + +// redisKeyValue is a KeyValue backed by pool // -// Note: RedisKeyValue additionally implements +// Note: redisKeyValue additionally implements // // interface { // // WithPrefix wraps r to return a RedisKeyValue that prefixes all keys with // // prefix + ":". // WithPrefix(prefix string) KeyValue // } -func RedisKeyValue(pool *redis.Pool) KeyValue { - return &redisKeyValue{pool: pool} +type redisKeyValue struct { + pool *redis.Pool + ctx context.Context + prefix string + recorder *LatencyRecorder } func (r *redisKeyValue) Get(key string) Value { diff --git a/internal/redispool/keyvalue_test.go b/internal/redispool/keyvalue_test.go index 09e9d5c3354..b80c609125e 100644 --- a/internal/redispool/keyvalue_test.go +++ b/internal/redispool/keyvalue_test.go @@ -396,20 +396,10 @@ func TestKeyValueWithPrefix(t *testing.T) { func redisKeyValueForTest(t *testing.T) redispool.KeyValue { t.Helper() - pool := &redis.Pool{ - MaxIdle: 3, - IdleTimeout: 240 * time.Second, - Dial: func() (redis.Conn, error) { - return redis.Dial("tcp", "127.0.0.1:6379") - }, - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, - } - + kv := redispool.NewTestKeyValue() prefix := "__test__" + t.Name() - c := pool.Get() + + c := kv.Pool().Get() defer c.Close() // If we are not on CI, skip the test if our redis connection fails. @@ -420,7 +410,6 @@ func redisKeyValueForTest(t *testing.T) redispool.KeyValue { } } - kv := redispool.RedisKeyValue(pool) if err := redispool.DeleteAllKeysWithPrefix(kv, prefix); err != nil { t.Logf("Could not clear test prefix name=%q prefix=%q error=%v", t.Name(), prefix, err) } diff --git a/internal/redispool/redispool_test.go b/internal/redispool/redispool_test.go index 1cfdb4d3f27..9b6f7c1789c 100644 --- a/internal/redispool/redispool_test.go +++ b/internal/redispool/redispool_test.go @@ -6,9 +6,7 @@ import ( "reflect" "strconv" "testing" - "time" - "github.com/gomodule/redigo/redis" "github.com/sourcegraph/log/logtest" ) @@ -40,19 +38,9 @@ func TestMain(m *testing.M) { func TestDeleteAllKeysWithPrefix(t *testing.T) { t.Helper() - pool := &redis.Pool{ - MaxIdle: 3, - IdleTimeout: 240 * time.Second, - Dial: func() (redis.Conn, error) { - return redis.Dial("tcp", "127.0.0.1:6379") - }, - TestOnBorrow: func(c redis.Conn, t time.Time) error { - _, err := c.Do("PING") - return err - }, - } + kv := NewTestKeyValue() - c := pool.Get() + c := kv.Pool().Get() defer c.Close() // If we are not on CI, skip the test if our redis connection fails. @@ -63,7 +51,6 @@ func TestDeleteAllKeysWithPrefix(t *testing.T) { } } - kv := RedisKeyValue(pool) var aKeys, bKeys []string var key string for i := range 10 {