Redis: remove RedisKeyValue constructor (#64442)

This PR removes the `redispool.RedisKeyValue` constructor in favor of
the `New...KeyValue` methods, which do not take a pool directly. This
way callers won't create a `Pool` reference, allowing us to track all
direct pool usage through `KeyValue.Pool()`.

This also simplifies a few things:
* Tests now use `NewTestKeyValue` instead of dialing up localhost
directly
* We can remove duplicated Redis connection logic in Cody Gateway
This commit is contained in:
Julie Tibshirani 2024-08-14 11:24:32 +03:00 committed by GitHub
parent 34ff925ed8
commit ca6e72fe18
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 77 additions and 197 deletions

View File

@ -5,7 +5,6 @@ go_library(
srcs = [
"main.go",
"metrics.go",
"redis.go",
"service.go",
"tracing.go",
],
@ -33,7 +32,6 @@ go_library(
"//internal/goroutine",
"//internal/httpcli",
"//internal/httpserver",
"//internal/lazyregexp",
"//internal/observation",
"//internal/rcache",
"//internal/redispool",

View File

@ -98,8 +98,6 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu
}
}
redisPool := connectToRedis(cfg.RedisEndpoint)
// Create an uncached external doer, we never want to cache any responses.
// Not only is the cache hit rate going to be really low and requests large-ish,
// but also do we not want to retain any data.
@ -115,7 +113,10 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu
return errors.Wrap(err, "init metric 'redis_latency'")
}
redisCache := redispool.RedisKeyValue(redisPool).WithLatencyRecorder(func(call string, latency time.Duration, err error) {
redisCache := redispool.NewKeyValue(cfg.RedisEndpoint, &redis.Pool{
MaxIdle: 10,
IdleTimeout: 240 * time.Second,
}).WithLatencyRecorder(func(call string, latency time.Duration, err error) {
redisLatency.Record(context.Background(), latency.Milliseconds(), metric.WithAttributeSet(attribute.NewSet(
attribute.Bool("error", err != nil),
attribute.String("command", call))))
@ -240,7 +241,7 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu
return errors.Wrap(err, "httpapi.NewHandler")
}
// Diagnostic and Maintenance layers, exposing additional APIs and endpoints.
handler = httpapi.NewDiagnosticsHandler(obctx.Logger, handler, redisPool, cfg.DiagnosticsSecret, sources)
handler = httpapi.NewDiagnosticsHandler(obctx.Logger, handler, redisCache.Pool(), cfg.DiagnosticsSecret, sources)
handler = httpapi.NewMaintenanceHandler(obctx.Logger, handler, cfg, redisCache)
// Collect request client for downstream handlers. Outside of dev, we always set up
@ -257,7 +258,7 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu
})
// Set up redis-based distributed mutex for the source syncer worker
sourceWorkerMutex := redsync.New(redigo.NewPool(redisPool)).NewMutex("source-syncer-worker",
sourceWorkerMutex := redsync.New(redigo.NewPool(redisCache.Pool())).NewMutex("source-syncer-worker",
// Do not retry endlessly becuase it's very likely that someone else has
// a long-standing hold on the mutex. We will try again on the next periodic
// goroutine run.

View File

@ -1,41 +0,0 @@
package shared
import (
"strings"
"time"
"github.com/gomodule/redigo/redis"
"github.com/sourcegraph/sourcegraph/internal/lazyregexp"
"github.com/sourcegraph/sourcegraph/lib/errors"
)
var schemeMatcher = lazyregexp.New(`^[A-Za-z][A-Za-z0-9\+\-\.]*://`)
// connectToRedis connects to Redis given the raw endpoint string.
// Cody Gateway maintains its own pool of Redis connections, it should not be dependent
// on the sourcegraph deployment Redis dualism.
//
// The string can have two formats:
// 1. If there is a HTTP scheme, it should be either be "redis://" or "rediss://" and the URL
// must be of the format specified in https://www.iana.org/assignments/uri-schemes/prov/redis.
// 2. Otherwise, it is assumed to be of the format $HOSTNAME:$PORT.
func connectToRedis(endpoint string) *redis.Pool {
return &redis.Pool{
MaxIdle: 10,
IdleTimeout: 240 * time.Second,
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
Dial: func() (redis.Conn, error) {
if schemeMatcher.MatchString(endpoint) { // expect "redis://"
return redis.DialURL(endpoint)
}
if strings.Contains(endpoint, "/") {
return nil, errors.New("Redis endpoint without scheme should not contain '/'")
}
return redis.Dial("tcp", endpoint)
},
}
}

View File

@ -47,7 +47,6 @@ go_test(
"//internal/types",
"//lib/pointers",
"//schema",
"@com_github_gomodule_redigo//redis",
"@com_github_google_go_cmp//cmp",
"@com_github_sourcegraph_log//logtest",
"@com_github_stretchr_testify//assert",

View File

@ -5,7 +5,6 @@ import (
"testing"
"time"
"github.com/gomodule/redigo/redis"
"github.com/google/go-cmp/cmp"
"github.com/sourcegraph/log/logtest"
"github.com/stretchr/testify/assert"
@ -28,22 +27,10 @@ func TestHandler_Handle(t *testing.T) {
db := database.NewDB(logger, dbtest.NewDB(t))
prefix := "__test__" + t.Name()
redisHost := "127.0.0.1:6379"
pool := &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Dial: func() (redis.Conn, error) {
return redis.Dial("tcp", redisHost)
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
}
kv := redispool.NewTestKeyValue()
t.Cleanup(func() {
if err := redispool.DeleteAllKeysWithPrefix(redispool.RedisKeyValue(pool), prefix); err != nil {
if err := redispool.DeleteAllKeysWithPrefix(kv, prefix); err != nil {
t.Logf("Failed to clear redis: %+v\n", err)
}
})
@ -69,14 +56,14 @@ func TestHandler_Handle(t *testing.T) {
h := handler{
externalServiceStore: db.ExternalServices(),
newRateLimiterFunc: func(bucketName string) ratelimit.GlobalLimiter {
return ratelimit.NewTestGlobalRateLimiter(pool, prefix, bucketName)
return ratelimit.NewTestGlobalRateLimiter(kv.Pool(), prefix, bucketName)
},
logger: logger,
}
err = h.Handle(ctx)
assert.NoError(t, err)
info, err := ratelimit.GetGlobalLimiterStateFromPool(ctx, pool, prefix)
info, err := ratelimit.GetGlobalLimiterStateFromStore(kv, prefix)
require.NoError(t, err)
if diff := cmp.Diff(map[string]ratelimit.GlobalLimiterInfo{

View File

@ -35,7 +35,6 @@ go_test(
"//internal/redispool",
"//lib/errors",
"@com_github_derision_test_go_mockgen_v2//testutil/require",
"@com_github_gomodule_redigo//redis",
"@com_github_google_go_cmp//cmp",
"@com_github_rafaeljusto_redigomock_v3//:redigomock",
"@com_github_stretchr_testify//require",

View File

@ -7,7 +7,6 @@ import (
"testing"
mockrequire "github.com/derision-test/go-mockgen/v2/testutil/require"
"github.com/gomodule/redigo/redis"
"github.com/rafaeljusto/redigomock/v3"
"github.com/stretchr/testify/require"
@ -173,19 +172,17 @@ func setupRedisTest(t *testing.T) {
t.Cleanup(func() { mockConn.Clear(); mockConn.Close() })
mockConn.GenericCommand("HSET").Handle(func(args []interface{}) (interface{}, error) {
cache[args[0].(string)] = []byte(args[2].(string))
return nil, nil
mockStore := redispool.NewMockKeyValue()
mockStore.HSetFunc.SetDefaultHook(func(key string, field string, value any) error {
cache[key] = []byte(value.(string))
return nil
})
mockConn.GenericCommand("HGET").Handle(func(args []interface{}) (interface{}, error) {
return cache[args[0].(string)], nil
mockStore.HGetFunc.SetDefaultHook(func(key string, field string) redispool.Value {
return redispool.NewValue(cache[key], nil)
})
mockConn.GenericCommand("DEL").Handle(func(args []interface{}) (interface{}, error) {
delete(cache, args[0].(string))
return nil, nil
mockStore.DelFunc.SetDefaultHook(func(key string) error {
delete(cache, key)
return nil
})
evalStore = redispool.RedisKeyValue(&redis.Pool{Dial: func() (redis.Conn, error) { return mockConn, nil }, MaxIdle: 10})
evalStore = mockStore
}

View File

@ -379,38 +379,21 @@ type GlobalLimiterInfo struct {
// GetGlobalLimiterState reports how all the existing rate limiters are configured,
// keyed by bucket name.
func GetGlobalLimiterState(ctx context.Context) (map[string]GlobalLimiterInfo, error) {
return GetGlobalLimiterStateFromPool(ctx, kv().Pool(), tokenBucketGlobalPrefix)
return GetGlobalLimiterStateFromStore(kv(), tokenBucketGlobalPrefix)
}
func GetGlobalLimiterStateFromPool(ctx context.Context, pool *redis.Pool, prefix string) (map[string]GlobalLimiterInfo, error) {
conn, err := pool.GetContext(ctx)
if err != nil {
return nil, errors.Wrap(err, "failed to get connection")
}
defer conn.Close()
func GetGlobalLimiterStateFromStore(rstore redispool.KeyValue, prefix string) (map[string]GlobalLimiterInfo, error) {
// First, find all known limiters in redis.
resp, err := conn.Do("KEYS", fmt.Sprintf("%s:*:%s", prefix, bucketAllowedBurstKeySuffix))
keys, err := rstore.Keys(fmt.Sprintf("%s:*:%s", prefix, bucketAllowedBurstKeySuffix))
if err != nil {
return nil, errors.Wrap(err, "failed to list keys")
}
keys, ok := resp.([]interface{})
if !ok {
return nil, errors.Newf("invalid response from redis keys command, expected []interface{}, got %T", resp)
}
m := make(map[string]GlobalLimiterInfo, len(keys))
for _, k := range keys {
kchars, ok := k.([]uint8)
if !ok {
return nil, errors.Newf("invalid response from redis keys command, expected string, got %T", k)
}
key := string(kchars)
for _, key := range keys {
limiterName := strings.TrimSuffix(strings.TrimPrefix(key, prefix+":"), ":"+bucketAllowedBurstKeySuffix)
rlKeys := getRateLimiterKeys(prefix, limiterName)
rstore := redispool.RedisKeyValue(pool)
currentCapacity, err := rstore.Get(rlKeys.BucketKey).Int()
if err != nil && err != redis.ErrNil {
return nil, errors.Wrap(err, "failed to read current capacity")
@ -515,20 +498,10 @@ type TB interface {
func SetupForTest(t TB) {
t.Helper()
pool := &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Dial: func() (redis.Conn, error) {
return redis.Dial("tcp", "127.0.0.1:6379")
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
}
kvMock = redispool.NewTestKeyValue()
tokenBucketGlobalPrefix = "__test__" + t.Name()
c := pool.Get()
c := kvMock.Pool().Get()
defer c.Close()
// If we are not on CI, skip the test if our redis connection fails.
@ -539,7 +512,6 @@ func SetupForTest(t TB) {
}
}
kvMock = redispool.RedisKeyValue(pool)
if err := redispool.DeleteAllKeysWithPrefix(kvMock, tokenBucketGlobalPrefix); err != nil {
t.Fatalf("could not clear test prefix: &v", err)
}

View File

@ -24,8 +24,8 @@ func TestGlobalRateLimiter(t *testing.T) {
// This test is verifying the basic functionality of the rate limiter.
// We should be able to get a token once the token bucket config is set.
prefix := "__test__" + t.Name()
pool := redisPoolForTest(t, prefix)
rl := getTestRateLimiter(prefix, pool, testBucketName)
kv := redisKeyValueForTest(t, prefix)
rl := getTestRateLimiter(prefix, kv.Pool(), testBucketName)
clock := glock.NewMockClock()
rl.nowFunc = clock.Now
@ -135,8 +135,8 @@ func TestGlobalRateLimiter_TimeToWaitExceedsLimit(t *testing.T) {
// This test is verifying that if the amount of time needed to wait for a token
// exceeds the context deadline, a TokenGrantExceedsLimitError is returned.
prefix := "__test__" + t.Name()
pool := redisPoolForTest(t, prefix)
rl := getTestRateLimiter(prefix, pool, testBucketName)
kv := redisKeyValueForTest(t, prefix)
rl := getTestRateLimiter(prefix, kv.Pool(), testBucketName)
clock := glock.NewMockClock()
rl.nowFunc = clock.Now
@ -176,8 +176,8 @@ func TestGlobalRateLimiter_TimeToWaitExceedsLimit(t *testing.T) {
func TestGlobalRateLimiter_AllBlockedError(t *testing.T) {
// Verify that a limit of 0 means "block all".
prefix := "__test__" + t.Name()
pool := redisPoolForTest(t, prefix)
rl := getTestRateLimiter(prefix, pool, testBucketName)
kv := redisKeyValueForTest(t, prefix)
rl := getTestRateLimiter(prefix, kv.Pool(), testBucketName)
clock := glock.NewMockClock()
rl.nowFunc = clock.Now
@ -205,8 +205,8 @@ func TestGlobalRateLimiter_AllBlockedError(t *testing.T) {
func TestGlobalRateLimiter_Inf(t *testing.T) {
// Verify that a rate of -1 means inf.
prefix := "__test__" + t.Name()
pool := redisPoolForTest(t, prefix)
rl := getTestRateLimiter(prefix, pool, testBucketName)
kv := redisKeyValueForTest(t, prefix)
rl := getTestRateLimiter(prefix, kv.Pool(), testBucketName)
clock := glock.NewMockClock()
rl.nowFunc = clock.Now
@ -235,8 +235,8 @@ func TestGlobalRateLimiter_UnconfiguredLimiter(t *testing.T) {
// This test is verifying the basic functionality of the rate limiter.
// We should be able to get a token once the token bucket config is set.
prefix := "__test__" + t.Name()
pool := redisPoolForTest(t, prefix)
rl := getTestRateLimiter(prefix, pool, testBucketName)
kv := redisKeyValueForTest(t, prefix)
rl := getTestRateLimiter(prefix, kv.Pool(), testBucketName)
clock := glock.NewMockClock()
rl.nowFunc = clock.Now
@ -303,32 +303,23 @@ func getTestRateLimiter(prefix string, pool *redis.Pool, bucketName string) glob
// Mostly copy-pasta from rache. Will clean up later as the relationship
// between the two packages becomes cleaner.
func redisPoolForTest(t *testing.T, prefix string) *redis.Pool {
func redisKeyValueForTest(t *testing.T, prefix string) redispool.KeyValue {
t.Helper()
pool := &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Dial: func() (redis.Conn, error) {
return redis.Dial("tcp", "127.0.0.1:6379")
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
}
if err := redispool.DeleteAllKeysWithPrefix(redispool.RedisKeyValue(pool), prefix); err != nil {
store := redispool.NewTestKeyValue()
if err := redispool.DeleteAllKeysWithPrefix(store, prefix); err != nil {
t.Logf("Could not clear test prefix name=%q prefix=%q error=%v", t.Name(), prefix, err)
}
return pool
return store
}
func TestLimitInfo(t *testing.T) {
ctx := context.Background()
prefix := "__test__" + t.Name()
pool := redisPoolForTest(t, prefix)
store := redisKeyValueForTest(t, prefix)
pool := store.Pool()
r1 := getTestRateLimiter(prefix, pool, "extsvc:github:1")
// 1/s allowed.
@ -340,7 +331,7 @@ func TestLimitInfo(t *testing.T) {
// No requests allowed.
require.NoError(t, r3.SetTokenBucketConfig(ctx, 0, time.Hour))
info, err := GetGlobalLimiterStateFromPool(ctx, pool, prefix)
info, err := GetGlobalLimiterStateFromStore(store, prefix)
require.NoError(t, err)
if diff := cmp.Diff(map[string]GlobalLimiterInfo{
@ -372,7 +363,7 @@ func TestLimitInfo(t *testing.T) {
// Now claim 3 tokens from the limiter.
require.NoError(t, r1.WaitN(ctx, 3))
info, err = GetGlobalLimiterStateFromPool(ctx, pool, prefix)
info, err = GetGlobalLimiterStateFromStore(store, prefix)
require.NoError(t, err)
if diff := cmp.Diff(map[string]GlobalLimiterInfo{

View File

@ -243,25 +243,14 @@ const testAddr = "127.0.0.1:6379"
func SetupForTest(t testing.TB) redispool.KeyValue {
t.Helper()
pool := &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Dial: func() (redis.Conn, error) {
return redis.Dial("tcp", testAddr)
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
}
kvMock = redispool.RedisKeyValue(pool)
kvMock = redispool.NewTestKeyValue()
t.Cleanup(func() {
pool.Close()
kvMock.Pool().Close()
kvMock = nil
})
globalPrefix = "__test__" + t.Name()
c := pool.Get()
c := kvMock.Pool().Get()
defer c.Close()
// If we are not on CI, skip the test if our redis connection fails.

View File

@ -122,13 +122,6 @@ func (v Values) StringMap() (map[string]string, error) {
type LatencyRecorder func(call string, latency time.Duration, err error)
type redisKeyValue struct {
pool *redis.Pool
ctx context.Context
prefix string
recorder *LatencyRecorder
}
// NewKeyValue returns a KeyValue for addr.
//
// poolOpts is a required argument which sets defaults in the case we connect
@ -141,20 +134,39 @@ func NewKeyValue(addr string, poolOpts *redis.Pool) KeyValue {
poolOpts.Dial = func() (redis.Conn, error) {
return dialRedis(addr)
}
return RedisKeyValue(poolOpts)
return &redisKeyValue{pool: poolOpts}
}
// RedisKeyValue returns a KeyValue backed by pool.
// NewTestKeyValue returns a KeyValue connected to a local Redis server for integration tests.
func NewTestKeyValue() KeyValue {
pool := &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Dial: func() (redis.Conn, error) {
return redis.Dial("tcp", "127.0.0.1:6379")
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
}
return &redisKeyValue{pool: pool}
}
// redisKeyValue is a KeyValue backed by pool
//
// Note: RedisKeyValue additionally implements
// Note: redisKeyValue additionally implements
//
// interface {
// // WithPrefix wraps r to return a RedisKeyValue that prefixes all keys with
// // prefix + ":".
// WithPrefix(prefix string) KeyValue
// }
func RedisKeyValue(pool *redis.Pool) KeyValue {
return &redisKeyValue{pool: pool}
type redisKeyValue struct {
pool *redis.Pool
ctx context.Context
prefix string
recorder *LatencyRecorder
}
func (r *redisKeyValue) Get(key string) Value {

View File

@ -396,20 +396,10 @@ func TestKeyValueWithPrefix(t *testing.T) {
func redisKeyValueForTest(t *testing.T) redispool.KeyValue {
t.Helper()
pool := &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Dial: func() (redis.Conn, error) {
return redis.Dial("tcp", "127.0.0.1:6379")
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
}
kv := redispool.NewTestKeyValue()
prefix := "__test__" + t.Name()
c := pool.Get()
c := kv.Pool().Get()
defer c.Close()
// If we are not on CI, skip the test if our redis connection fails.
@ -420,7 +410,6 @@ func redisKeyValueForTest(t *testing.T) redispool.KeyValue {
}
}
kv := redispool.RedisKeyValue(pool)
if err := redispool.DeleteAllKeysWithPrefix(kv, prefix); err != nil {
t.Logf("Could not clear test prefix name=%q prefix=%q error=%v", t.Name(), prefix, err)
}

View File

@ -6,9 +6,7 @@ import (
"reflect"
"strconv"
"testing"
"time"
"github.com/gomodule/redigo/redis"
"github.com/sourcegraph/log/logtest"
)
@ -40,19 +38,9 @@ func TestMain(m *testing.M) {
func TestDeleteAllKeysWithPrefix(t *testing.T) {
t.Helper()
pool := &redis.Pool{
MaxIdle: 3,
IdleTimeout: 240 * time.Second,
Dial: func() (redis.Conn, error) {
return redis.Dial("tcp", "127.0.0.1:6379")
},
TestOnBorrow: func(c redis.Conn, t time.Time) error {
_, err := c.Do("PING")
return err
},
}
kv := NewTestKeyValue()
c := pool.Get()
c := kv.Pool().Get()
defer c.Close()
// If we are not on CI, skip the test if our redis connection fails.
@ -63,7 +51,6 @@ func TestDeleteAllKeysWithPrefix(t *testing.T) {
}
}
kv := RedisKeyValue(pool)
var aKeys, bKeys []string
var key string
for i := range 10 {