diff --git a/obp-api/src/main/scala/code/api/util/RateLimitingUtil.scala b/obp-api/src/main/scala/code/api/util/RateLimitingUtil.scala index 37a167258..370a145c5 100644 --- a/obp-api/src/main/scala/code/api/util/RateLimitingUtil.scala +++ b/obp-api/src/main/scala/code/api/util/RateLimitingUtil.scala @@ -76,6 +76,13 @@ object RateLimitingJson { object RateLimitingUtil extends MdcLoggable { import code.api.util.RateLimitingPeriod._ + /** State of a rate limiting counter from Redis */ + case class RateLimitCounterState( + calls: Option[Long], // Current counter value + ttl: Option[Long], // Time to live in seconds + status: String // ACTIVE, NO_COUNTER, EXPIRED, REDIS_UNAVAILABLE + ) + def useConsumerLimits = APIUtil.getPropsAsBoolValue("use_consumer_limits", false) /** Get system default rate limits from properties. Used when no RateLimiting records exist for a consumer. @@ -143,38 +150,75 @@ object RateLimitingUtil extends MdcLoggable { } } - private def createUniqueKey(consumerKey: String, period: LimitCallPeriod) = consumerKey + "_" + RateLimitingPeriod.toString(period) + /** + * Single source of truth for reading rate limit counter state from Redis. + * All rate limiting functions should call this instead of accessing Redis directly. + * + * @param consumerKey The consumer ID + * @param period The time period (PER_SECOND, PER_MINUTE, etc.) + * @return RateLimitCounterState with calls, ttl, and status + */ + private def getCounterState(consumerKey: String, period: LimitCallPeriod): RateLimitCounterState = { + val key = createUniqueKey(consumerKey, period) + + // Read TTL and value from Redis (2 operations) + val ttlOpt: Option[Long] = Redis.use(JedisMethod.TTL, key).map(_.toLong) + val valueOpt: Option[Long] = Redis.use(JedisMethod.GET, key).map(_.toLong) + + // Determine status based on Redis TTL response + val status = ttlOpt match { + case Some(ttl) if ttl > 0 => "ACTIVE" // Counter running with time remaining + case Some(-2) => "NO_COUNTER" // Key does not exist, never been set + case Some(ttl) if ttl <= 0 => "EXPIRED" // Key expired (TTL=0) or no expiry (TTL=-1) + case None => "REDIS_UNAVAILABLE" // Redis connection failed + } + + // Normalize calls value + val calls = ttlOpt match { + case Some(-2) => Some(0L) // Key doesn't exist -> 0 calls + case Some(ttl) if ttl <= 0 => Some(0L) // Expired or invalid -> 0 calls + case Some(_) => valueOpt.orElse(Some(0L)) // Active key -> return value or 0 + case None => Some(0L) // Redis unavailable -> 0 calls + } + + // Normalize TTL value + val normalizedTtl = ttlOpt match { + case Some(-2) => Some(0L) // Key doesn't exist -> 0 TTL + case Some(ttl) if ttl <= 0 => Some(0L) // Expired -> 0 TTL + case Some(ttl) => Some(ttl) // Active -> actual TTL + case None => Some(0L) // Redis unavailable -> 0 TTL + } + + RateLimitCounterState(calls, normalizedTtl, status) + } + private def createUniqueKey(consumerKey: String, period: LimitCallPeriod) = consumerKey + "_" + RateLimitingPeriod.toString(period) private def underConsumerLimits(consumerKey: String, period: LimitCallPeriod, limit: Long): Boolean = { + if (useConsumerLimits) { - try { - (limit) match { - case l if l > 0 => // Redis is available and limit is set - val key = createUniqueKey(consumerKey, period) - // TODO: Check if we can remove redundant EXISTS check. GET returns None when key does not exist. - // Check This would reduce Redis operations from 2 to 1 (25% reduction per request). - // Simplified code: - // val currentValue = Redis.use(JedisMethod.GET, key) - // currentValue match { - // case Some(value) => value.toLong + 1 <= limit - // case None => true // Key does not exist, first call - // } - val exists = Redis.use(JedisMethod.EXISTS,key).map(_.toBoolean).get - exists match { - case true => - val underLimit = Redis.use(JedisMethod.GET,key).get.toLong + 1 <= limit // +1 means we count the current call as well. We increment later i.e after successful call. - underLimit - case false => // In case that key does not exist we return successful result - true - } - case _ => - // Rate Limiting for a Consumer <= 0 implies successful result - // Or any other unhandled case implies successful result - true - } - } catch { - case e : Throwable => - logger.error(s"Redis issue: $e") + (limit) match { + case l if l > 0 => // Limit is set, check against Redis counter + val state = getCounterState(consumerKey, period) + state.status match { + case "ACTIVE" => + // Counter is active, check if we're under limit + // +1 means we count the current call as well. We increment later i.e after successful call. + state.calls.getOrElse(0L) + 1 <= limit + case "NO_COUNTER" | "EXPIRED" => + // No counter or expired -> allow (first call or period expired) + true + case "REDIS_UNAVAILABLE" => + // Redis unavailable -> fail open (allow request) + logger.warn(s"Redis unavailable when checking rate limit for consumer $consumerKey, period $period - allowing request") + true + case _ => + // Unknown status -> fail open (allow request) + logger.warn(s"Unknown status '${state.status}' when checking rate limit for consumer $consumerKey, period $period - allowing request") + true + } + case _ => + // Rate Limiting for a Consumer <= 0 implies successful result + // Or any other unhandled case implies successful result true } } else { @@ -227,45 +271,9 @@ object RateLimitingUtil extends MdcLoggable { def consumerRateLimitState(consumerKey: String): immutable.Seq[((Option[Long], Option[Long], String), LimitCallPeriod)] = { - def getCallCounterForPeriod(consumerKey: String, period: LimitCallPeriod): ((Option[Long], Option[Long], String), LimitCallPeriod) = { - val key = createUniqueKey(consumerKey, period) - - // get TTL - val ttlOpt: Option[Long] = Redis.use(JedisMethod.TTL, key).map(_.toLong) - - // get value (assuming string storage) - // TODO: Why do we assume string for a counter that we INCR? - val valueOpt: Option[Long] = Redis.use(JedisMethod.GET, key).map(_.toLong) - - // TTL meanings: - // -2: Key does not exist - // -1: Key exists with no expiry (shouldn't happen in our rate limiting) - // >0: Seconds until key expires - val calls = ttlOpt match { - case Some(-2) => Some(0L) // Key doesn't exist -> 0 calls - case Some(ttl) if ttl <= 0 => Some(0L) // Expired or invalid -> 0 calls - case Some(_) => valueOpt.orElse(Some(0L)) // Active key -> return value or 0 - case None => Some(0L) // Redis unavailable -> 0 calls - } - - val normalizedTtl = ttlOpt match { - case Some(-2) => Some(0L) // Key doesn't exist -> 0 TTL - case Some(ttl) if ttl <= 0 => Some(0L) // Expired -> 0 TTL - case Some(ttl) => Some(ttl) // Active -> actual TTL - case None => Some(0L) // Redis unavailable -> 0 TTL - } - - - // Calculate status based on Redis TTL response - val status = ttlOpt match { - case Some(ttl) if ttl > 0 => "ACTIVE" // Counter running with time remaining - case Some(-2) => "NO_COUNTER" // Key does not exist, never been set - case Some(ttl) if ttl <= 0 => "EXPIRED" // Key expired (TTL=0) or no expiry (TTL=-1) - case None => "REDIS_UNAVAILABLE" // Redis connection failed - } - - ((calls, normalizedTtl, status), period) + val state = getCounterState(consumerKey, period) + ((state.calls, state.ttl, state.status), period) } getCallCounterForPeriod(consumerKey, RateLimitingPeriod.PER_SECOND) ::