RateLimitingUtil single point of entry to Redis part 1

This commit is contained in:
simonredfern 2025-12-27 07:26:40 +01:00
parent cd52665f35
commit c647eb145f

View File

@ -76,6 +76,13 @@ object RateLimitingJson {
object RateLimitingUtil extends MdcLoggable {
import code.api.util.RateLimitingPeriod._
/** State of a rate limiting counter from Redis */
case class RateLimitCounterState(
calls: Option[Long], // Current counter value
ttl: Option[Long], // Time to live in seconds
status: String // ACTIVE, NO_COUNTER, EXPIRED, REDIS_UNAVAILABLE
)
def useConsumerLimits = APIUtil.getPropsAsBoolValue("use_consumer_limits", false)
/** Get system default rate limits from properties. Used when no RateLimiting records exist for a consumer.
@ -143,38 +150,75 @@ object RateLimitingUtil extends MdcLoggable {
}
}
private def createUniqueKey(consumerKey: String, period: LimitCallPeriod) = consumerKey + "_" + RateLimitingPeriod.toString(period)
/**
* Single source of truth for reading rate limit counter state from Redis.
* All rate limiting functions should call this instead of accessing Redis directly.
*
* @param consumerKey The consumer ID
* @param period The time period (PER_SECOND, PER_MINUTE, etc.)
* @return RateLimitCounterState with calls, ttl, and status
*/
private def getCounterState(consumerKey: String, period: LimitCallPeriod): RateLimitCounterState = {
val key = createUniqueKey(consumerKey, period)
// Read TTL and value from Redis (2 operations)
val ttlOpt: Option[Long] = Redis.use(JedisMethod.TTL, key).map(_.toLong)
val valueOpt: Option[Long] = Redis.use(JedisMethod.GET, key).map(_.toLong)
// Determine status based on Redis TTL response
val status = ttlOpt match {
case Some(ttl) if ttl > 0 => "ACTIVE" // Counter running with time remaining
case Some(-2) => "NO_COUNTER" // Key does not exist, never been set
case Some(ttl) if ttl <= 0 => "EXPIRED" // Key expired (TTL=0) or no expiry (TTL=-1)
case None => "REDIS_UNAVAILABLE" // Redis connection failed
}
// Normalize calls value
val calls = ttlOpt match {
case Some(-2) => Some(0L) // Key doesn't exist -> 0 calls
case Some(ttl) if ttl <= 0 => Some(0L) // Expired or invalid -> 0 calls
case Some(_) => valueOpt.orElse(Some(0L)) // Active key -> return value or 0
case None => Some(0L) // Redis unavailable -> 0 calls
}
// Normalize TTL value
val normalizedTtl = ttlOpt match {
case Some(-2) => Some(0L) // Key doesn't exist -> 0 TTL
case Some(ttl) if ttl <= 0 => Some(0L) // Expired -> 0 TTL
case Some(ttl) => Some(ttl) // Active -> actual TTL
case None => Some(0L) // Redis unavailable -> 0 TTL
}
RateLimitCounterState(calls, normalizedTtl, status)
}
private def createUniqueKey(consumerKey: String, period: LimitCallPeriod) = consumerKey + "_" + RateLimitingPeriod.toString(period)
private def underConsumerLimits(consumerKey: String, period: LimitCallPeriod, limit: Long): Boolean = {
if (useConsumerLimits) {
try {
(limit) match {
case l if l > 0 => // Redis is available and limit is set
val key = createUniqueKey(consumerKey, period)
// TODO: Check if we can remove redundant EXISTS check. GET returns None when key does not exist.
// Check This would reduce Redis operations from 2 to 1 (25% reduction per request).
// Simplified code:
// val currentValue = Redis.use(JedisMethod.GET, key)
// currentValue match {
// case Some(value) => value.toLong + 1 <= limit
// case None => true // Key does not exist, first call
// }
val exists = Redis.use(JedisMethod.EXISTS,key).map(_.toBoolean).get
exists match {
case true =>
val underLimit = Redis.use(JedisMethod.GET,key).get.toLong + 1 <= limit // +1 means we count the current call as well. We increment later i.e after successful call.
underLimit
case false => // In case that key does not exist we return successful result
true
}
case _ =>
// Rate Limiting for a Consumer <= 0 implies successful result
// Or any other unhandled case implies successful result
true
}
} catch {
case e : Throwable =>
logger.error(s"Redis issue: $e")
(limit) match {
case l if l > 0 => // Limit is set, check against Redis counter
val state = getCounterState(consumerKey, period)
state.status match {
case "ACTIVE" =>
// Counter is active, check if we're under limit
// +1 means we count the current call as well. We increment later i.e after successful call.
state.calls.getOrElse(0L) + 1 <= limit
case "NO_COUNTER" | "EXPIRED" =>
// No counter or expired -> allow (first call or period expired)
true
case "REDIS_UNAVAILABLE" =>
// Redis unavailable -> fail open (allow request)
logger.warn(s"Redis unavailable when checking rate limit for consumer $consumerKey, period $period - allowing request")
true
case _ =>
// Unknown status -> fail open (allow request)
logger.warn(s"Unknown status '${state.status}' when checking rate limit for consumer $consumerKey, period $period - allowing request")
true
}
case _ =>
// Rate Limiting for a Consumer <= 0 implies successful result
// Or any other unhandled case implies successful result
true
}
} else {
@ -227,45 +271,9 @@ object RateLimitingUtil extends MdcLoggable {
def consumerRateLimitState(consumerKey: String): immutable.Seq[((Option[Long], Option[Long], String), LimitCallPeriod)] = {
def getCallCounterForPeriod(consumerKey: String, period: LimitCallPeriod): ((Option[Long], Option[Long], String), LimitCallPeriod) = {
val key = createUniqueKey(consumerKey, period)
// get TTL
val ttlOpt: Option[Long] = Redis.use(JedisMethod.TTL, key).map(_.toLong)
// get value (assuming string storage)
// TODO: Why do we assume string for a counter that we INCR?
val valueOpt: Option[Long] = Redis.use(JedisMethod.GET, key).map(_.toLong)
// TTL meanings:
// -2: Key does not exist
// -1: Key exists with no expiry (shouldn't happen in our rate limiting)
// >0: Seconds until key expires
val calls = ttlOpt match {
case Some(-2) => Some(0L) // Key doesn't exist -> 0 calls
case Some(ttl) if ttl <= 0 => Some(0L) // Expired or invalid -> 0 calls
case Some(_) => valueOpt.orElse(Some(0L)) // Active key -> return value or 0
case None => Some(0L) // Redis unavailable -> 0 calls
}
val normalizedTtl = ttlOpt match {
case Some(-2) => Some(0L) // Key doesn't exist -> 0 TTL
case Some(ttl) if ttl <= 0 => Some(0L) // Expired -> 0 TTL
case Some(ttl) => Some(ttl) // Active -> actual TTL
case None => Some(0L) // Redis unavailable -> 0 TTL
}
// Calculate status based on Redis TTL response
val status = ttlOpt match {
case Some(ttl) if ttl > 0 => "ACTIVE" // Counter running with time remaining
case Some(-2) => "NO_COUNTER" // Key does not exist, never been set
case Some(ttl) if ttl <= 0 => "EXPIRED" // Key expired (TTL=0) or no expiry (TTL=-1)
case None => "REDIS_UNAVAILABLE" // Redis connection failed
}
((calls, normalizedTtl, status), period)
val state = getCounterState(consumerKey, period)
((state.calls, state.ttl, state.status), period)
}
getCallCounterForPeriod(consumerKey, RateLimitingPeriod.PER_SECOND) ::