diff --git a/cmd/cody-gateway/internal/httpapi/BUILD.bazel b/cmd/cody-gateway/internal/httpapi/BUILD.bazel index 273d9b80c05..c42fa359d85 100644 --- a/cmd/cody-gateway/internal/httpapi/BUILD.bazel +++ b/cmd/cody-gateway/internal/httpapi/BUILD.bazel @@ -36,7 +36,6 @@ go_library( "//internal/version", "//lib/errors", "//lib/pointers", - "@com_github_gomodule_redigo//redis", "@com_github_gorilla_mux//:mux", "@com_github_khan_genqlient//graphql", "@com_github_sourcegraph_log//:log", diff --git a/cmd/cody-gateway/internal/httpapi/diagnostics.go b/cmd/cody-gateway/internal/httpapi/diagnostics.go index c6289be72e4..ff1ccfe039d 100644 --- a/cmd/cody-gateway/internal/httpapi/diagnostics.go +++ b/cmd/cody-gateway/internal/httpapi/diagnostics.go @@ -6,7 +6,6 @@ import ( "net/http" "strings" - "github.com/gomodule/redigo/redis" "github.com/sourcegraph/log" "github.com/sourcegraph/log/hook" "github.com/sourcegraph/log/output" @@ -17,16 +16,16 @@ import ( "github.com/sourcegraph/sourcegraph/cmd/cody-gateway/internal/response" "github.com/sourcegraph/sourcegraph/internal/authbearer" "github.com/sourcegraph/sourcegraph/internal/instrumentation" + "github.com/sourcegraph/sourcegraph/internal/redispool" sgtrace "github.com/sourcegraph/sourcegraph/internal/trace" "github.com/sourcegraph/sourcegraph/internal/version" - "github.com/sourcegraph/sourcegraph/lib/errors" ) // NewDiagnosticsHandler creates a handler for diagnostic endpoints typically served // on "/-/..." paths. It should be placed before any authentication middleware, since // we do a simple auth on a static secret instead that is uniquely generated per // deployment. -func NewDiagnosticsHandler(baseLogger log.Logger, next http.Handler, redisPool *redis.Pool, secret string, sources *actor.Sources) http.Handler { +func NewDiagnosticsHandler(baseLogger log.Logger, next http.Handler, redisCache redispool.KeyValue, secret string, sources *actor.Sources) http.Handler { baseLogger = baseLogger.Scoped("diagnostics") hasValidSecret := func(l log.Logger, w http.ResponseWriter, r *http.Request) (yes bool) { @@ -58,7 +57,7 @@ func NewDiagnosticsHandler(baseLogger log.Logger, next http.Handler, redisPool * return } - if err := healthz(r.Context(), redisPool); err != nil { + if err := healthz(r.Context(), redisCache); err != nil { logger.Error("check failed", log.Error(err)) w.WriteHeader(http.StatusInternalServerError) @@ -110,21 +109,6 @@ func NewDiagnosticsHandler(baseLogger log.Logger, next http.Handler, redisPool * }) } -func healthz(ctx context.Context, rpool *redis.Pool) error { - // Check redis health - rconn, err := rpool.GetContext(ctx) - if err != nil { - return errors.Wrap(err, "redis: failed to get conn") - } - defer rconn.Close() - - data, err := rconn.Do("PING") - if err != nil { - return errors.Wrap(err, "redis: failed to ping") - } - if data != "PONG" { - return errors.New("redis: failed to ping: no pong received") - } - - return nil +func healthz(ctx context.Context, cache redispool.KeyValue) error { + return cache.WithContext(ctx).Ping() } diff --git a/cmd/cody-gateway/shared/main.go b/cmd/cody-gateway/shared/main.go index ba1d8ad7910..5e7b5590ca4 100644 --- a/cmd/cody-gateway/shared/main.go +++ b/cmd/cody-gateway/shared/main.go @@ -241,7 +241,7 @@ func Main(ctx context.Context, obctx *observation.Context, ready service.ReadyFu return errors.Wrap(err, "httpapi.NewHandler") } // Diagnostic and Maintenance layers, exposing additional APIs and endpoints. - handler = httpapi.NewDiagnosticsHandler(obctx.Logger, handler, redisCache.Pool(), cfg.DiagnosticsSecret, sources) + handler = httpapi.NewDiagnosticsHandler(obctx.Logger, handler, redisCache, cfg.DiagnosticsSecret, sources) handler = httpapi.NewMaintenanceHandler(obctx.Logger, handler, cfg, redisCache) // Collect request client for downstream handlers. Outside of dev, we always set up diff --git a/cmd/frontend/internal/cli/serve_cmd.go b/cmd/frontend/internal/cli/serve_cmd.go index bc41e301e29..9bbf351520e 100644 --- a/cmd/frontend/internal/cli/serve_cmd.go +++ b/cmd/frontend/internal/cli/serve_cmd.go @@ -450,19 +450,6 @@ func GetInternalAddr() string { return httpAddrInternal } -func pingRedis(kv redispool.KeyValue) error { - conn := kv.Pool().Get() - defer conn.Close() - data, err := conn.Do("PING") - if err != nil { - return err - } - if data != "PONG" { - return errors.New("no pong received") - } - return nil -} - // waitForRedis waits up to a certain timeout for Redis to become reachable, to reduce the // likelihood of the HTTP handlers starting to serve requests while Redis (and therefore session // data) is still unavailable. After the timeout has elapsed, if Redis is still unreachable, it @@ -473,7 +460,7 @@ func waitForRedis(logger sglog.Logger, kv redispool.KeyValue) { var err error for { time.Sleep(150 * time.Millisecond) - err = pingRedis(kv) + err = kv.Ping() if err == nil { return } diff --git a/internal/metrics/store/BUILD.bazel b/internal/metrics/store/BUILD.bazel index 36f46cafdac..ad81c14d233 100644 --- a/internal/metrics/store/BUILD.bazel +++ b/internal/metrics/store/BUILD.bazel @@ -12,7 +12,6 @@ go_library( deps = [ "//internal/redispool", "//lib/errors", - "@com_github_gomodule_redigo//redis", "@com_github_prometheus_client_golang//prometheus", "@com_github_prometheus_client_model//go", "@com_github_prometheus_common//expfmt", diff --git a/internal/metrics/store/store.go b/internal/metrics/store/store.go index 3870a462fba..7cb364d391d 100644 --- a/internal/metrics/store/store.go +++ b/internal/metrics/store/store.go @@ -5,7 +5,6 @@ import ( "io" "strings" - "github.com/gomodule/redigo/redis" "github.com/prometheus/client_golang/prometheus" dto "github.com/prometheus/client_model/go" "github.com/prometheus/common/expfmt" @@ -20,16 +19,6 @@ type Store interface { prometheus.Gatherer } -func NewDefaultStore() Store { - return &defaultStore{} -} - -type defaultStore struct{} - -func (*defaultStore) Gather() ([]*dto.MetricFamily, error) { - return prometheus.DefaultGatherer.Gather() -} - type DistributedStore interface { Store Ingest(instance string, mfs []*dto.MetricFamily) error @@ -48,13 +37,10 @@ type distributedStore struct { } func (d *distributedStore) Gather() ([]*dto.MetricFamily, error) { - pool := redispool.Cache.Pool() - - reConn := pool.Get() - defer reConn.Close() + cache := redispool.Cache // First, list all the keys for which we hold metrics. - keys, err := redis.Values(reConn.Do("KEYS", d.prefix+"*")) + keys, err := cache.Keys(d.prefix + "*") if err != nil { return nil, errors.Wrap(err, "listing entries from redis") } @@ -64,7 +50,7 @@ func (d *distributedStore) Gather() ([]*dto.MetricFamily, error) { } // Then bulk retrieve all the metrics blobs for all the instances. - encodedMetrics, err := redis.Strings(reConn.Do("MGET", keys...)) + encodedMetrics, err := cache.MGet(keys).Strings() if err != nil { return nil, errors.Wrap(err, "retrieving blobs from redis") } @@ -92,7 +78,7 @@ func (d *distributedStore) Gather() ([]*dto.MetricFamily, error) { } func (d *distributedStore) Ingest(instance string, mfs []*dto.MetricFamily) error { - pool := redispool.Cache.Pool() + cache := redispool.Cache // First, encode the metrics to text format so we can store them. var enc bytes.Buffer @@ -106,13 +92,10 @@ func (d *distributedStore) Ingest(instance string, mfs []*dto.MetricFamily) erro encodedMetrics := enc.String() - reConn := pool.Get() - defer reConn.Close() - // Store the metrics and set an expiry on the key, if we haven't retrieved // an updated set of metric data, we consider the host down and prune it // from the gatherer. - err := reConn.Send("SETEX", d.prefix+instance, d.expiry, encodedMetrics) + err := cache.SetEx(d.prefix+instance, d.expiry, encodedMetrics) if err != nil { return errors.Wrap(err, "writing metrics blob to redis") } diff --git a/internal/ratelimit/globallimiter.go b/internal/ratelimit/globallimiter.go index c0df5597398..1b1a01d7766 100644 --- a/internal/ratelimit/globallimiter.go +++ b/internal/ratelimit/globallimiter.go @@ -499,15 +499,11 @@ func SetupForTest(t TB) { t.Helper() kvMock = redispool.NewTestKeyValue() - tokenBucketGlobalPrefix = "__test__" + t.Name() - c := kvMock.Pool().Get() - defer c.Close() // If we are not on CI, skip the test if our redis connection fails. if os.Getenv("CI") == "" { - _, err := c.Do("PING") - if err != nil { + if err := kvMock.Ping(); err != nil { t.Skip("could not connect to redis", err) } } diff --git a/internal/rcache/rcache.go b/internal/rcache/rcache.go index d7b957c6a85..b1a6b4dde74 100644 --- a/internal/rcache/rcache.go +++ b/internal/rcache/rcache.go @@ -250,13 +250,10 @@ func SetupForTest(t testing.TB) redispool.KeyValue { }) globalPrefix = "__test__" + t.Name() - c := kvMock.Pool().Get() - defer c.Close() // If we are not on CI, skip the test if our redis connection fails. if os.Getenv("CI") == "" { - _, err := c.Do("PING") - if err != nil { + if err := kvMock.Ping(); err != nil { t.Skip("could not connect to redis", err) } } diff --git a/internal/redispool/keyvalue.go b/internal/redispool/keyvalue.go index f389521bb24..1fa3097d8f5 100644 --- a/internal/redispool/keyvalue.go +++ b/internal/redispool/keyvalue.go @@ -20,6 +20,8 @@ import ( type KeyValue interface { Get(key string) Value GetSet(key string, value any) Value + MGet(keys []string) Values + Set(key string, value any) error SetEx(key string, ttlSeconds int, value any) error SetNx(key string, value any) (bool, error) @@ -42,6 +44,9 @@ type KeyValue interface { LLen(key string) (int, error) LRange(key string, start, stop int) Values + // Ping checks the connection to the redis server. + Ping() error + // Keys returns all keys matching the glob pattern. NOTE: this command takes time // linear in the number of keys, and should not be run over large keyspaces. Keys(pattern string) ([]string, error) @@ -169,24 +174,33 @@ type redisKeyValue struct { recorder *LatencyRecorder } +func (r *redisKeyValue) Ping() error { + // The 'ping' command takes no arguments + return r.do("PING", []string{}, []any{}).err +} + func (r *redisKeyValue) Get(key string) Value { - return r.do("GET", key) + return r.doSimple("GET", key) } func (r *redisKeyValue) GetSet(key string, val any) Value { - return r.do("GETSET", key, val) + return r.doSimple("GETSET", key, val) +} + +func (r *redisKeyValue) MGet(keys []string) Values { + return Values(r.do("MGET", keys, []any{})) } func (r *redisKeyValue) Set(key string, val any) error { - return r.do("SET", key, val).err + return r.doSimple("SET", key, val).err } func (r *redisKeyValue) SetEx(key string, ttlSeconds int, val any) error { - return r.do("SETEX", key, ttlSeconds, val).err + return r.doSimple("SETEX", key, ttlSeconds, val).err } func (r *redisKeyValue) SetNx(key string, val any) (bool, error) { - _, err := r.do("SET", key, val, "NX").String() + _, err := r.doSimple("SET", key, val, "NX").String() if err == redis.ErrNil { return false, nil } @@ -194,61 +208,61 @@ func (r *redisKeyValue) SetNx(key string, val any) (bool, error) { } func (r *redisKeyValue) Incr(key string) (int, error) { - return r.do("INCR", key).Int() + return r.doSimple("INCR", key).Int() } func (r *redisKeyValue) Incrby(key string, value int) (int, error) { - return r.do("INCRBY", key, value).Int() + return r.doSimple("INCRBY", key, value).Int() } func (r *redisKeyValue) IncrByInt64(key string, value int64) (int64, error) { - return r.do("INCRBY", key, value).Int64() + return r.doSimple("INCRBY", key, value).Int64() } func (r *redisKeyValue) DecrByInt64(key string, value int64) (int64, error) { - return r.do("DECRBY", key, value).Int64() + return r.doSimple("DECRBY", key, value).Int64() } func (r *redisKeyValue) Del(key string) error { - return r.do("DEL", key).err + return r.doSimple("DEL", key).err } func (r *redisKeyValue) TTL(key string) (int, error) { - return r.do("TTL", key).Int() + return r.doSimple("TTL", key).Int() } func (r *redisKeyValue) Expire(key string, ttlSeconds int) error { - return r.do("EXPIRE", key, ttlSeconds).err + return r.doSimple("EXPIRE", key, ttlSeconds).err } func (r *redisKeyValue) HGet(key, field string) Value { - return r.do("HGET", key, field) + return r.doSimple("HGET", key, field) } func (r *redisKeyValue) HGetAll(key string) Values { - return Values(r.do("HGETALL", key)) + return Values(r.doSimple("HGETALL", key)) } func (r *redisKeyValue) HSet(key, field string, val any) error { - return r.do("HSET", key, field, val).err + return r.doSimple("HSET", key, field, val).err } func (r *redisKeyValue) HDel(key, field string) Value { - return r.do("HDEL", key, field) + return r.doSimple("HDEL", key, field) } func (r *redisKeyValue) LPush(key string, value any) error { - return r.do("LPUSH", key, value).err + return r.doSimple("LPUSH", key, value).err } func (r *redisKeyValue) LTrim(key string, start, stop int) error { - return r.do("LTRIM", key, start, stop).err + return r.doSimple("LTRIM", key, start, stop).err } func (r *redisKeyValue) LLen(key string) (int, error) { - raw := r.do("LLEN", key) + raw := r.doSimple("LLEN", key) return redis.Int(raw.reply, raw.err) } func (r *redisKeyValue) LRange(key string, start, stop int) Values { - return Values(r.do("LRANGE", key, start, stop)) + return Values(r.doSimple("LRANGE", key, start, stop)) } func (r *redisKeyValue) WithContext(ctx context.Context) KeyValue { @@ -278,14 +292,18 @@ func (r *redisKeyValue) WithPrefix(prefix string) KeyValue { } func (r *redisKeyValue) Keys(pattern string) ([]string, error) { - return Values(r.do("KEYS", pattern)).Strings() + return Values(r.doSimple("KEYS", pattern)).Strings() } func (r *redisKeyValue) Pool() *redis.Pool { return r.pool } -func (r *redisKeyValue) do(commandName string, key string, args ...any) Value { +func (r *redisKeyValue) doSimple(commandName string, key string, args ...any) Value { + return r.do(commandName, []string{key}, args) +} + +func (r *redisKeyValue) do(commandName string, keys []string, args []any) Value { var c redis.Conn if r.ctx != nil { var err error @@ -293,19 +311,22 @@ func (r *redisKeyValue) do(commandName string, key string, args ...any) Value { if err != nil { return Value{err: err} } - defer c.Close() } else { c = r.pool.Get() - defer c.Close() } + defer c.Close() + var start time.Time if r.recorder != nil { start = time.Now() } - prefixedKey := r.prefix + key - args = append([]any{prefixedKey}, args...) - reply, err := c.Do(commandName, args...) + prefixedKeys := make([]any, len(keys)) + for i, key := range keys { + prefixedKeys[i] = r.prefix + key + } + + reply, err := c.Do(commandName, append(prefixedKeys, args...)...) if r.recorder != nil { elapsed := time.Since(start) diff --git a/internal/redispool/keyvalue_test.go b/internal/redispool/keyvalue_test.go index b80c609125e..32a796eb1d3 100644 --- a/internal/redispool/keyvalue_test.go +++ b/internal/redispool/keyvalue_test.go @@ -189,7 +189,7 @@ func TestKeyValue(t *testing.T) { require.Equal(kv.Get("empty-string"), "") require.Equal(kv.Get("empty-bytes"), "") - // List group. Once empty we should be able to do a Get without a + // List group. Once empty we should be able to doSimple a Get without a // wrongtype error. require.Works(kv.LPush("empty-list", "here today gone tomorrow")) require.Equal(kv.Get("empty-list"), errWrongType) @@ -353,6 +353,20 @@ func TestKeyValue(t *testing.T) { require.Equal(kv.Get(k), "2") } }) + + t.Run("ping", func(t *testing.T) { + t.Parallel() + require := require{TB: t} + require.Works(kv.Ping()) + + brokenKv := redispool.NewKeyValue("nonexistent-redis-server:6379", &redis.Pool{ + MaxIdle: 3, + IdleTimeout: 5 * time.Second, + }) + if brokenKv.Ping() == nil { + t.Fatalf("ping: expected error, but did not receive one") + } + }) } func TestKeyValueWithPrefix(t *testing.T) { @@ -378,16 +392,22 @@ func TestKeyValueWithPrefix(t *testing.T) { require.Works(kv1.Set("other", "a")) + mget1, err := kv1.MGet([]string{"simple", "other"}).Strings() + require.Works(err) + if !reflect.DeepEqual(mget1, []string{"1", "a"}) { + t.Fatalf("mget mismatch: expected [1 a], got %v", mget1) + } + keys1, err := kv1.Keys("*") require.Works(err) if len(keys1) != 2 { - t.Fatalf("expected 2 keys, got %v", keys1) + t.Fatalf("keys mismatch: expected 2 keys, got %v", keys1) } - keys2, err := kv2.Keys("*") + keys2, err := kv2.Keys("s*") require.Works(err) if len(keys2) != 1 { - t.Fatalf("expected 1 key, got %v", keys1) + t.Fatalf("keys mismatch: expected 1 key, got %v", keys1) } } @@ -399,13 +419,9 @@ func redisKeyValueForTest(t *testing.T) redispool.KeyValue { kv := redispool.NewTestKeyValue() prefix := "__test__" + t.Name() - c := kv.Pool().Get() - defer c.Close() - // If we are not on CI, skip the test if our redis connection fails. if os.Getenv("CI") == "" { - _, err := c.Do("PING") - if err != nil { + if err := kv.Ping(); err != nil { t.Skip("could not connect to redis", err) } } diff --git a/internal/redispool/mocks.go b/internal/redispool/mocks.go index 6332dbb5611..63d97298f35 100644 --- a/internal/redispool/mocks.go +++ b/internal/redispool/mocks.go @@ -68,6 +68,12 @@ type MockKeyValue struct { // LTrimFunc is an instance of a mock function object controlling the // behavior of the method LTrim. LTrimFunc *KeyValueLTrimFunc + // MGetFunc is an instance of a mock function object controlling the + // behavior of the method MGet. + MGetFunc *KeyValueMGetFunc + // PingFunc is an instance of a mock function object controlling the + // behavior of the method Ping. + PingFunc *KeyValuePingFunc // PoolFunc is an instance of a mock function object controlling the // behavior of the method Pool. PoolFunc *KeyValuePoolFunc @@ -180,6 +186,16 @@ func NewMockKeyValue() *MockKeyValue { return }, }, + MGetFunc: &KeyValueMGetFunc{ + defaultHook: func([]string) (r0 Values) { + return + }, + }, + PingFunc: &KeyValuePingFunc{ + defaultHook: func() (r0 error) { + return + }, + }, PoolFunc: &KeyValuePoolFunc{ defaultHook: func() (r0 *redis.Pool) { return @@ -307,6 +323,16 @@ func NewStrictMockKeyValue() *MockKeyValue { panic("unexpected invocation of MockKeyValue.LTrim") }, }, + MGetFunc: &KeyValueMGetFunc{ + defaultHook: func([]string) Values { + panic("unexpected invocation of MockKeyValue.MGet") + }, + }, + PingFunc: &KeyValuePingFunc{ + defaultHook: func() error { + panic("unexpected invocation of MockKeyValue.Ping") + }, + }, PoolFunc: &KeyValuePoolFunc{ defaultHook: func() *redis.Pool { panic("unexpected invocation of MockKeyValue.Pool") @@ -400,6 +426,12 @@ func NewMockKeyValueFrom(i KeyValue) *MockKeyValue { LTrimFunc: &KeyValueLTrimFunc{ defaultHook: i.LTrim, }, + MGetFunc: &KeyValueMGetFunc{ + defaultHook: i.MGet, + }, + PingFunc: &KeyValuePingFunc{ + defaultHook: i.Ping, + }, PoolFunc: &KeyValuePoolFunc{ defaultHook: i.Pool, }, @@ -2203,6 +2235,205 @@ func (c KeyValueLTrimFuncCall) Results() []interface{} { return []interface{}{c.Result0} } +// KeyValueMGetFunc describes the behavior when the MGet method of the +// parent MockKeyValue instance is invoked. +type KeyValueMGetFunc struct { + defaultHook func([]string) Values + hooks []func([]string) Values + history []KeyValueMGetFuncCall + mutex sync.Mutex +} + +// MGet delegates to the next hook function in the queue and stores the +// parameter and result values of this invocation. +func (m *MockKeyValue) MGet(v0 []string) Values { + r0 := m.MGetFunc.nextHook()(v0) + m.MGetFunc.appendCall(KeyValueMGetFuncCall{v0, r0}) + return r0 +} + +// SetDefaultHook sets function that is called when the MGet method of the +// parent MockKeyValue instance is invoked and the hook queue is empty. +func (f *KeyValueMGetFunc) SetDefaultHook(hook func([]string) Values) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// MGet method of the parent MockKeyValue instance invokes the hook at the +// front of the queue and discards it. After the queue is empty, the default +// hook function is invoked for any future action. +func (f *KeyValueMGetFunc) PushHook(hook func([]string) Values) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *KeyValueMGetFunc) SetDefaultReturn(r0 Values) { + f.SetDefaultHook(func([]string) Values { + return r0 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *KeyValueMGetFunc) PushReturn(r0 Values) { + f.PushHook(func([]string) Values { + return r0 + }) +} + +func (f *KeyValueMGetFunc) nextHook() func([]string) Values { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *KeyValueMGetFunc) appendCall(r0 KeyValueMGetFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of KeyValueMGetFuncCall objects describing the +// invocations of this function. +func (f *KeyValueMGetFunc) History() []KeyValueMGetFuncCall { + f.mutex.Lock() + history := make([]KeyValueMGetFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// KeyValueMGetFuncCall is an object that describes an invocation of method +// MGet on an instance of MockKeyValue. +type KeyValueMGetFuncCall struct { + // Arg0 is the value of the 1st argument passed to this method + // invocation. + Arg0 []string + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 Values +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c KeyValueMGetFuncCall) Args() []interface{} { + return []interface{}{c.Arg0} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c KeyValueMGetFuncCall) Results() []interface{} { + return []interface{}{c.Result0} +} + +// KeyValuePingFunc describes the behavior when the Ping method of the +// parent MockKeyValue instance is invoked. +type KeyValuePingFunc struct { + defaultHook func() error + hooks []func() error + history []KeyValuePingFuncCall + mutex sync.Mutex +} + +// Ping delegates to the next hook function in the queue and stores the +// parameter and result values of this invocation. +func (m *MockKeyValue) Ping() error { + r0 := m.PingFunc.nextHook()() + m.PingFunc.appendCall(KeyValuePingFuncCall{r0}) + return r0 +} + +// SetDefaultHook sets function that is called when the Ping method of the +// parent MockKeyValue instance is invoked and the hook queue is empty. +func (f *KeyValuePingFunc) SetDefaultHook(hook func() error) { + f.defaultHook = hook +} + +// PushHook adds a function to the end of hook queue. Each invocation of the +// Ping method of the parent MockKeyValue instance invokes the hook at the +// front of the queue and discards it. After the queue is empty, the default +// hook function is invoked for any future action. +func (f *KeyValuePingFunc) PushHook(hook func() error) { + f.mutex.Lock() + f.hooks = append(f.hooks, hook) + f.mutex.Unlock() +} + +// SetDefaultReturn calls SetDefaultHook with a function that returns the +// given values. +func (f *KeyValuePingFunc) SetDefaultReturn(r0 error) { + f.SetDefaultHook(func() error { + return r0 + }) +} + +// PushReturn calls PushHook with a function that returns the given values. +func (f *KeyValuePingFunc) PushReturn(r0 error) { + f.PushHook(func() error { + return r0 + }) +} + +func (f *KeyValuePingFunc) nextHook() func() error { + f.mutex.Lock() + defer f.mutex.Unlock() + + if len(f.hooks) == 0 { + return f.defaultHook + } + + hook := f.hooks[0] + f.hooks = f.hooks[1:] + return hook +} + +func (f *KeyValuePingFunc) appendCall(r0 KeyValuePingFuncCall) { + f.mutex.Lock() + f.history = append(f.history, r0) + f.mutex.Unlock() +} + +// History returns a sequence of KeyValuePingFuncCall objects describing the +// invocations of this function. +func (f *KeyValuePingFunc) History() []KeyValuePingFuncCall { + f.mutex.Lock() + history := make([]KeyValuePingFuncCall, len(f.history)) + copy(history, f.history) + f.mutex.Unlock() + + return history +} + +// KeyValuePingFuncCall is an object that describes an invocation of method +// Ping on an instance of MockKeyValue. +type KeyValuePingFuncCall struct { + // Result0 is the value of the 1st result returned from this method + // invocation. + Result0 error +} + +// Args returns an interface slice containing the arguments of this +// invocation. +func (c KeyValuePingFuncCall) Args() []interface{} { + return []interface{}{} +} + +// Results returns an interface slice containing the results of this +// invocation. +func (c KeyValuePingFuncCall) Results() []interface{} { + return []interface{}{c.Result0} +} + // KeyValuePoolFunc describes the behavior when the Pool method of the // parent MockKeyValue instance is invoked. type KeyValuePoolFunc struct { diff --git a/internal/redispool/redispool_test.go b/internal/redispool/redispool_test.go index 9b6f7c1789c..b6b82e2bbcd 100644 --- a/internal/redispool/redispool_test.go +++ b/internal/redispool/redispool_test.go @@ -40,13 +40,9 @@ func TestDeleteAllKeysWithPrefix(t *testing.T) { kv := NewTestKeyValue() - c := kv.Pool().Get() - defer c.Close() - // If we are not on CI, skip the test if our redis connection fails. if os.Getenv("CI") == "" { - _, err := c.Do("PING") - if err != nil { + if err := kv.Ping(); err != nil { t.Skip("could not connect to redis", err) } }