rate-limits refactor for single point of truth

This commit is contained in:
simonredfern 2025-12-23 22:05:14 +01:00
parent 47d6f97d89
commit a9a7384088
4 changed files with 102 additions and 54 deletions

View File

@ -93,67 +93,18 @@ object AfterApiAuth extends MdcLoggable{
/**
* This block of code needs to update Call Context with Rate Limiting
* Please note that first source is the table RateLimiting and second is the table Consumer
* Uses RateLimitingUtil.getActiveRateLimits as the SINGLE SOURCE OF TRUTH
*/
def checkRateLimiting(userIsLockedOrDeleted: Future[(Box[User], Option[CallContext])]): Future[(Box[User], Option[CallContext])] = {
def getActiveRateLimitings(consumerId: String): Future[List[RateLimiting]] = {
RateLimitingUtil.useConsumerLimits match {
case true => RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, new Date())
case false => Future(List.empty)
}
}
def aggregateLimits(limits: List[RateLimiting], consumerId: String): CallLimit = {
def sumLimits(values: List[Long]): Long = {
val positiveValues = values.filter(_ > 0)
if (positiveValues.isEmpty) -1 else positiveValues.sum
}
if (limits.nonEmpty) {
CallLimit(
consumerId,
limits.find(_.apiName.isDefined).flatMap(_.apiName),
limits.find(_.apiVersion.isDefined).flatMap(_.apiVersion),
limits.find(_.bankId.isDefined).flatMap(_.bankId),
sumLimits(limits.map(_.perSecondCallLimit)),
sumLimits(limits.map(_.perMinuteCallLimit)),
sumLimits(limits.map(_.perHourCallLimit)),
sumLimits(limits.map(_.perDayCallLimit)),
sumLimits(limits.map(_.perWeekCallLimit)),
sumLimits(limits.map(_.perMonthCallLimit))
)
} else {
CallLimit(consumerId, None, None, None, -1, -1, -1, -1, -1, -1)
}
}
for {
(user, cc) <- userIsLockedOrDeleted
consumer = cc.flatMap(_.consumer)
consumerId = consumer.map(_.consumerId.get).getOrElse("")
rateLimitings <- getActiveRateLimitings(consumerId)
rateLimit <- RateLimitingUtil.getActiveRateLimits(consumerId, new Date())
} yield {
val limit: Option[CallLimit] = rateLimitings match {
case Nil => // No rate limiting records found, use consumer defaults
Some(CallLimit(
consumerId,
None,
None,
None,
consumer.map(_.perSecondCallLimit.get).getOrElse(-1),
consumer.map(_.perMinuteCallLimit.get).getOrElse(-1),
consumer.map(_.perHourCallLimit.get).getOrElse(-1),
consumer.map(_.perDayCallLimit.get).getOrElse(-1),
consumer.map(_.perWeekCallLimit.get).getOrElse(-1),
consumer.map(_.perMonthCallLimit.get).getOrElse(-1)
))
case activeLimits => // Aggregate multiple rate limiting records
Some(aggregateLimits(activeLimits, consumerId))
}
(user, cc.map(_.copy(rateLimiting = limit)))
(user, cc.map(_.copy(rateLimiting = Some(rateLimit))))
}
}
private def sofitInitAction(user: AuthUser): Boolean = applyAction("sofit.logon_init_action.enabled") {
def getOrCreateBankAccount(bank: Bank, accountId: String, label: String, accountType: String = ""): Box[BankAccount] = {
MappedBankAccount.find(

View File

@ -1,5 +1,9 @@
package code.api.util
import java.util.Date
import code.ratelimiting.{RateLimiting, RateLimitingDI}
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global
import code.api.{APIFailureNewStyle, JedisMethod}
import code.api.cache.Redis
import code.api.util.APIUtil.fullBoxOrException
@ -74,6 +78,83 @@ object RateLimitingUtil extends MdcLoggable {
def useConsumerLimits = APIUtil.getPropsAsBoolValue("use_consumer_limits", false)
/** Get system default rate limits from properties. Used when no RateLimiting records exist for a consumer.
* @param consumerId The consumer ID
* @return RateLimit with system property defaults (default to -1 if not set)
*/
def getSystemDefaultRateLimits(consumerId: String): CallLimit = {
RateLimitingJson.CallLimit(
consumerId,
None,
None,
None,
APIUtil.getPropsAsLongValue("rate_limiting_per_second", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_minute", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_hour", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_day", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_week", -1),
APIUtil.getPropsAsLongValue("rate_limiting_per_month", -1)
)
}
/** Aggregate multiple rate limiting records into a single CallLimit. This is the SINGLE SOURCE OF TRUTH for aggregation logic.
* Rules:
* - Only positive values (> 0) are summed
* - If no positive values exist for a period, return -1 (unlimited)
* - Multiple overlapping records have their limits added together
* @param rateLimitRecords List of RateLimiting records to aggregate
* @param consumerId The consumer ID
* @return Aggregated CallLimit
*/
def aggregateRateLimits(rateLimitRecords: List[RateLimiting], consumerId: String): CallLimit = {
def sumLimits(values: List[Long]): Long = {
val positiveValues = values.filter(_ > 0)
if (positiveValues.isEmpty) -1 else positiveValues.sum
}
if (rateLimitRecords.nonEmpty) {
RateLimitingJson.CallLimit(
consumerId,
rateLimitRecords.find(_.apiName.isDefined).flatMap(_.apiName),
rateLimitRecords.find(_.apiVersion.isDefined).flatMap(_.apiVersion),
rateLimitRecords.find(_.bankId.isDefined).flatMap(_.bankId),
sumLimits(rateLimitRecords.map(_.perSecondCallLimit)),
sumLimits(rateLimitRecords.map(_.perMinuteCallLimit)),
sumLimits(rateLimitRecords.map(_.perHourCallLimit)),
sumLimits(rateLimitRecords.map(_.perDayCallLimit)),
sumLimits(rateLimitRecords.map(_.perWeekCallLimit)),
sumLimits(rateLimitRecords.map(_.perMonthCallLimit))
)
} else {
RateLimitingJson.CallLimit(consumerId, None, None, None, -1, -1, -1, -1, -1, -1)
}
}
/** Get the active rate limits for a consumer at a specific date. This is the SINGLE SOURCE OF TRUTH for rate limit calculation used by both:
* - The enforcement system (AfterApiAuth.checkRateLimiting)
* - The API endpoint (GET /consumer/rate-limits/active-at-date/{DATE})
* @param consumerId The consumer ID
* @param date The date to check active limits for
* @return Future containing the aggregated CallLimit that will be enforced
*/
def getActiveRateLimits(consumerId: String, date: Date): Future[CallLimit] = {
def getActiveRateLimitings(consumerId: String): Future[List[RateLimiting]] = {
useConsumerLimits match {
case true => RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, date)
case false => Future(List.empty)
}
}
for {
rateLimitRecords <- getActiveRateLimitings(consumerId)
} yield {
rateLimitRecords match {
case Nil => getSystemDefaultRateLimits(consumerId)
case records => aggregateRateLimits(records, consumerId)
}
}
}
private def createUniqueKey(consumerKey: String, period: LimitCallPeriod) = consumerKey + "_" + RateLimitingPeriod.toString(period)
private def underConsumerLimits(consumerKey: String, period: LimitCallPeriod, limit: Long): Boolean = {

View File

@ -493,9 +493,9 @@ trait APIMethods600 {
val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'")
format.parse(dateString)
}
activeCallLimits <- RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, date)
rateLimit <- RateLimitingUtil.getActiveRateLimits(consumerId, date)
} yield {
(createActiveCallLimitsJsonV600(activeCallLimits, date), HttpCode.`200`(callContext))
(JSONFactory600.createActiveCallLimitsJsonV600FromCallLimit(rateLimit, date), HttpCode.`200`(callContext))
}
}

View File

@ -587,6 +587,22 @@ object JSONFactory600 extends CustomJsonFormats with MdcLoggable {
)
}
def createActiveCallLimitsJsonV600FromCallLimit(
rateLimit: code.api.util.RateLimitingJson.CallLimit,
activeDate: java.util.Date
): ActiveCallLimitsJsonV600 = {
ActiveCallLimitsJsonV600(
call_limits = List.empty,
active_at_date = activeDate,
total_per_second_call_limit = rateLimit.per_second,
total_per_minute_call_limit = rateLimit.per_minute,
total_per_hour_call_limit = rateLimit.per_hour,
total_per_day_call_limit = rateLimit.per_day,
total_per_week_call_limit = rateLimit.per_week,
total_per_month_call_limit = rateLimit.per_month
)
}
def createTokenJSON(token: String): TokenJSON = {
TokenJSON(token)
}