rate-limits refactor for single point of truth 3

This commit is contained in:
simonredfern 2025-12-23 23:21:27 +01:00
parent 1eaaa50d8f
commit 7b44672a35
3 changed files with 48 additions and 62 deletions

View File

@ -93,14 +93,14 @@ object AfterApiAuth extends MdcLoggable{
/**
* This block of code needs to update Call Context with Rate Limiting
* Uses RateLimitingUtil.getActiveRateLimits as the SINGLE SOURCE OF TRUTH
* Uses RateLimitingUtil.getActiveRateLimitsWithIds as the SINGLE SOURCE OF TRUTH
*/
def checkRateLimiting(userIsLockedOrDeleted: Future[(Box[User], Option[CallContext])]): Future[(Box[User], Option[CallContext])] = {
for {
(user, cc) <- userIsLockedOrDeleted
consumer = cc.flatMap(_.consumer)
consumerId = consumer.map(_.consumerId.get).getOrElse("")
rateLimit <- RateLimitingUtil.getActiveRateLimits(consumerId, new Date())
(rateLimit, _) <- RateLimitingUtil.getActiveRateLimitsWithIds(consumerId, new Date())
} yield {
(user, cc.map(_.copy(rateLimiting = Some(rateLimit))))
}

View File

@ -82,62 +82,15 @@ object RateLimitingUtil extends MdcLoggable {
* @param consumerId The consumer ID
* @return RateLimit with system property defaults (default to -1 if not set)
*/
def getSystemDefaultRateLimits(consumerId: String): CallLimit = {
RateLimitingJson.CallLimit(
consumerId,
None,
None,
None,
APIUtil.getPropsAsLongValue("rate_limiting_per_second", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_minute", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_hour", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_day", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_week", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_month", -1)
)
}
/** Aggregate multiple rate limiting records into a single CallLimit. This is the SINGLE SOURCE OF TRUTH for aggregation logic.
* Rules:
* - Only positive values (> 0) are summed
* - If no positive values exist for a period, return -1 (unlimited)
* - Multiple overlapping records have their limits added together
* @param rateLimitRecords List of RateLimiting records to aggregate
* @param consumerId The consumer ID
* @return Aggregated CallLimit
*/
def aggregateRateLimits(rateLimitRecords: List[RateLimiting], consumerId: String): CallLimit = {
def sumLimits(values: List[Long]): Long = {
val positiveValues = values.filter(_ > 0)
if (positiveValues.isEmpty) -1 else positiveValues.sum
}
if (rateLimitRecords.nonEmpty) {
RateLimitingJson.CallLimit(
consumerId,
rateLimitRecords.find(_.apiName.isDefined).flatMap(_.apiName),
rateLimitRecords.find(_.apiVersion.isDefined).flatMap(_.apiVersion),
rateLimitRecords.find(_.bankId.isDefined).flatMap(_.bankId),
sumLimits(rateLimitRecords.map(_.perSecondCallLimit)),
sumLimits(rateLimitRecords.map(_.perMinuteCallLimit)),
sumLimits(rateLimitRecords.map(_.perHourCallLimit)),
sumLimits(rateLimitRecords.map(_.perDayCallLimit)),
sumLimits(rateLimitRecords.map(_.perWeekCallLimit)),
sumLimits(rateLimitRecords.map(_.perMonthCallLimit))
)
} else {
RateLimitingJson.CallLimit(consumerId, None, None, None, -1, -1, -1, -1, -1, -1)
}
}
/** Get the active rate limits for a consumer at a specific date. This is the SINGLE SOURCE OF TRUTH for rate limit calculation used by both:
* - The enforcement system (AfterApiAuth.checkRateLimiting)
* - The API endpoint (GET /consumer/rate-limits/active-at-date/{DATE})
/** THE SINGLE SOURCE OF TRUTH for active rate limits.
* This is the ONLY function that should be called to get active rate limits.
* Used by BOTH enforcement (AfterApiAuth) and API reporting (APIMethods600).
*
* @param consumerId The consumer ID
* @param date The date to check active limits for
* @return Future containing the aggregated CallLimit that will be enforced
* @return Future containing (aggregated CallLimit, List of rate_limiting_ids that contributed)
*/
def getActiveRateLimits(consumerId: String, date: Date): Future[CallLimit] = {
def getActiveRateLimitsWithIds(consumerId: String, date: Date): Future[(CallLimit, List[String])] = {
def getActiveRateLimitings(consumerId: String): Future[List[RateLimiting]] = {
useConsumerLimits match {
case true => RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, date)
@ -145,13 +98,48 @@ object RateLimitingUtil extends MdcLoggable {
}
}
def aggregateRateLimits(rateLimitRecords: List[RateLimiting]): CallLimit = {
def sumLimits(values: List[Long]): Long = {
val positiveValues = values.filter(_ > 0)
if (positiveValues.isEmpty) -1 else positiveValues.sum
}
if (rateLimitRecords.nonEmpty) {
RateLimitingJson.CallLimit(
consumerId,
rateLimitRecords.find(_.apiName.isDefined).flatMap(_.apiName),
rateLimitRecords.find(_.apiVersion.isDefined).flatMap(_.apiVersion),
rateLimitRecords.find(_.bankId.isDefined).flatMap(_.bankId),
sumLimits(rateLimitRecords.map(_.perSecondCallLimit)),
sumLimits(rateLimitRecords.map(_.perMinuteCallLimit)),
sumLimits(rateLimitRecords.map(_.perHourCallLimit)),
sumLimits(rateLimitRecords.map(_.perDayCallLimit)),
sumLimits(rateLimitRecords.map(_.perWeekCallLimit)),
sumLimits(rateLimitRecords.map(_.perMonthCallLimit))
)
} else {
// No records found - return system defaults
RateLimitingJson.CallLimit(
consumerId,
None,
None,
None,
APIUtil.getPropsAsLongValue("rate_limiting_per_second", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_minute", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_hour", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_day", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_week", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_month", -1)
)
}
}
for {
rateLimitRecords <- getActiveRateLimitings(consumerId)
} yield {
rateLimitRecords match {
case Nil => getSystemDefaultRateLimits(consumerId)
case records => aggregateRateLimits(records, consumerId)
}
val callLimit = aggregateRateLimits(rateLimitRecords)
val ids = rateLimitRecords.map(_.rateLimitingId)
(callLimit, ids)
}
}

View File

@ -493,9 +493,7 @@ trait APIMethods600 {
val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'")
format.parse(dateString)
}
rateLimit <- RateLimitingUtil.getActiveRateLimits(consumerId, date)
rateLimitRecords <- RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, date)
rateLimitIds = rateLimitRecords.map(_.rateLimitingId)
(rateLimit, rateLimitIds) <- RateLimitingUtil.getActiveRateLimitsWithIds(consumerId, date)
} yield {
(JSONFactory600.createActiveCallLimitsJsonV600FromCallLimit(rateLimit, rateLimitIds, date), HttpCode.`200`(callContext))
}