mirror of
https://github.com/OpenBankProject/OBP-API.git
synced 2026-02-06 16:16:47 +00:00
rate-limits refactor for single point of truth
This commit is contained in:
parent
47d6f97d89
commit
a9a7384088
@ -93,67 +93,18 @@ object AfterApiAuth extends MdcLoggable{
|
||||
|
||||
/**
|
||||
* This block of code needs to update Call Context with Rate Limiting
|
||||
* Please note that first source is the table RateLimiting and second is the table Consumer
|
||||
* Uses RateLimitingUtil.getActiveRateLimits as the SINGLE SOURCE OF TRUTH
|
||||
*/
|
||||
def checkRateLimiting(userIsLockedOrDeleted: Future[(Box[User], Option[CallContext])]): Future[(Box[User], Option[CallContext])] = {
|
||||
def getActiveRateLimitings(consumerId: String): Future[List[RateLimiting]] = {
|
||||
RateLimitingUtil.useConsumerLimits match {
|
||||
case true => RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, new Date())
|
||||
case false => Future(List.empty)
|
||||
}
|
||||
}
|
||||
|
||||
def aggregateLimits(limits: List[RateLimiting], consumerId: String): CallLimit = {
|
||||
def sumLimits(values: List[Long]): Long = {
|
||||
val positiveValues = values.filter(_ > 0)
|
||||
if (positiveValues.isEmpty) -1 else positiveValues.sum
|
||||
}
|
||||
|
||||
if (limits.nonEmpty) {
|
||||
CallLimit(
|
||||
consumerId,
|
||||
limits.find(_.apiName.isDefined).flatMap(_.apiName),
|
||||
limits.find(_.apiVersion.isDefined).flatMap(_.apiVersion),
|
||||
limits.find(_.bankId.isDefined).flatMap(_.bankId),
|
||||
sumLimits(limits.map(_.perSecondCallLimit)),
|
||||
sumLimits(limits.map(_.perMinuteCallLimit)),
|
||||
sumLimits(limits.map(_.perHourCallLimit)),
|
||||
sumLimits(limits.map(_.perDayCallLimit)),
|
||||
sumLimits(limits.map(_.perWeekCallLimit)),
|
||||
sumLimits(limits.map(_.perMonthCallLimit))
|
||||
)
|
||||
} else {
|
||||
CallLimit(consumerId, None, None, None, -1, -1, -1, -1, -1, -1)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
(user, cc) <- userIsLockedOrDeleted
|
||||
consumer = cc.flatMap(_.consumer)
|
||||
consumerId = consumer.map(_.consumerId.get).getOrElse("")
|
||||
rateLimitings <- getActiveRateLimitings(consumerId)
|
||||
rateLimit <- RateLimitingUtil.getActiveRateLimits(consumerId, new Date())
|
||||
} yield {
|
||||
val limit: Option[CallLimit] = rateLimitings match {
|
||||
case Nil => // No rate limiting records found, use consumer defaults
|
||||
Some(CallLimit(
|
||||
consumerId,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
consumer.map(_.perSecondCallLimit.get).getOrElse(-1),
|
||||
consumer.map(_.perMinuteCallLimit.get).getOrElse(-1),
|
||||
consumer.map(_.perHourCallLimit.get).getOrElse(-1),
|
||||
consumer.map(_.perDayCallLimit.get).getOrElse(-1),
|
||||
consumer.map(_.perWeekCallLimit.get).getOrElse(-1),
|
||||
consumer.map(_.perMonthCallLimit.get).getOrElse(-1)
|
||||
))
|
||||
case activeLimits => // Aggregate multiple rate limiting records
|
||||
Some(aggregateLimits(activeLimits, consumerId))
|
||||
}
|
||||
(user, cc.map(_.copy(rateLimiting = limit)))
|
||||
(user, cc.map(_.copy(rateLimiting = Some(rateLimit))))
|
||||
}
|
||||
}
|
||||
|
||||
private def sofitInitAction(user: AuthUser): Boolean = applyAction("sofit.logon_init_action.enabled") {
|
||||
def getOrCreateBankAccount(bank: Bank, accountId: String, label: String, accountType: String = ""): Box[BankAccount] = {
|
||||
MappedBankAccount.find(
|
||||
|
||||
@ -1,5 +1,9 @@
|
||||
package code.api.util
|
||||
|
||||
import java.util.Date
|
||||
import code.ratelimiting.{RateLimiting, RateLimitingDI}
|
||||
import scala.concurrent.Future
|
||||
import scala.concurrent.ExecutionContext.Implicits.global
|
||||
import code.api.{APIFailureNewStyle, JedisMethod}
|
||||
import code.api.cache.Redis
|
||||
import code.api.util.APIUtil.fullBoxOrException
|
||||
@ -74,6 +78,83 @@ object RateLimitingUtil extends MdcLoggable {
|
||||
|
||||
def useConsumerLimits = APIUtil.getPropsAsBoolValue("use_consumer_limits", false)
|
||||
|
||||
/** Get system default rate limits from properties. Used when no RateLimiting records exist for a consumer.
|
||||
* @param consumerId The consumer ID
|
||||
* @return RateLimit with system property defaults (default to -1 if not set)
|
||||
*/
|
||||
def getSystemDefaultRateLimits(consumerId: String): CallLimit = {
|
||||
RateLimitingJson.CallLimit(
|
||||
consumerId,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
APIUtil.getPropsAsLongValue("rate_limiting_per_second", -1),
|
||||
APIUtil.getPropsAsLongValue("rate_limiting_per_minute", -1),
|
||||
APIUtil.getPropsAsLongValue("rate_limiting_per_hour", -1),
|
||||
APIUtil.getPropsAsLongValue("rate_limiting_per_day", -1),
|
||||
APIUtil.getPropsAsLongValue("rate_limiting_per_week", -1),
|
||||
APIUtil.getPropsAsLongValue("rate_limiting_per_month", -1)
|
||||
)
|
||||
}
|
||||
|
||||
/** Aggregate multiple rate limiting records into a single CallLimit. This is the SINGLE SOURCE OF TRUTH for aggregation logic.
|
||||
* Rules:
|
||||
* - Only positive values (> 0) are summed
|
||||
* - If no positive values exist for a period, return -1 (unlimited)
|
||||
* - Multiple overlapping records have their limits added together
|
||||
* @param rateLimitRecords List of RateLimiting records to aggregate
|
||||
* @param consumerId The consumer ID
|
||||
* @return Aggregated CallLimit
|
||||
*/
|
||||
def aggregateRateLimits(rateLimitRecords: List[RateLimiting], consumerId: String): CallLimit = {
|
||||
def sumLimits(values: List[Long]): Long = {
|
||||
val positiveValues = values.filter(_ > 0)
|
||||
if (positiveValues.isEmpty) -1 else positiveValues.sum
|
||||
}
|
||||
|
||||
if (rateLimitRecords.nonEmpty) {
|
||||
RateLimitingJson.CallLimit(
|
||||
consumerId,
|
||||
rateLimitRecords.find(_.apiName.isDefined).flatMap(_.apiName),
|
||||
rateLimitRecords.find(_.apiVersion.isDefined).flatMap(_.apiVersion),
|
||||
rateLimitRecords.find(_.bankId.isDefined).flatMap(_.bankId),
|
||||
sumLimits(rateLimitRecords.map(_.perSecondCallLimit)),
|
||||
sumLimits(rateLimitRecords.map(_.perMinuteCallLimit)),
|
||||
sumLimits(rateLimitRecords.map(_.perHourCallLimit)),
|
||||
sumLimits(rateLimitRecords.map(_.perDayCallLimit)),
|
||||
sumLimits(rateLimitRecords.map(_.perWeekCallLimit)),
|
||||
sumLimits(rateLimitRecords.map(_.perMonthCallLimit))
|
||||
)
|
||||
} else {
|
||||
RateLimitingJson.CallLimit(consumerId, None, None, None, -1, -1, -1, -1, -1, -1)
|
||||
}
|
||||
}
|
||||
|
||||
/** Get the active rate limits for a consumer at a specific date. This is the SINGLE SOURCE OF TRUTH for rate limit calculation used by both:
|
||||
* - The enforcement system (AfterApiAuth.checkRateLimiting)
|
||||
* - The API endpoint (GET /consumer/rate-limits/active-at-date/{DATE})
|
||||
* @param consumerId The consumer ID
|
||||
* @param date The date to check active limits for
|
||||
* @return Future containing the aggregated CallLimit that will be enforced
|
||||
*/
|
||||
def getActiveRateLimits(consumerId: String, date: Date): Future[CallLimit] = {
|
||||
def getActiveRateLimitings(consumerId: String): Future[List[RateLimiting]] = {
|
||||
useConsumerLimits match {
|
||||
case true => RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, date)
|
||||
case false => Future(List.empty)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
rateLimitRecords <- getActiveRateLimitings(consumerId)
|
||||
} yield {
|
||||
rateLimitRecords match {
|
||||
case Nil => getSystemDefaultRateLimits(consumerId)
|
||||
case records => aggregateRateLimits(records, consumerId)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
private def createUniqueKey(consumerKey: String, period: LimitCallPeriod) = consumerKey + "_" + RateLimitingPeriod.toString(period)
|
||||
|
||||
private def underConsumerLimits(consumerKey: String, period: LimitCallPeriod, limit: Long): Boolean = {
|
||||
|
||||
@ -493,9 +493,9 @@ trait APIMethods600 {
|
||||
val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'")
|
||||
format.parse(dateString)
|
||||
}
|
||||
activeCallLimits <- RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, date)
|
||||
rateLimit <- RateLimitingUtil.getActiveRateLimits(consumerId, date)
|
||||
} yield {
|
||||
(createActiveCallLimitsJsonV600(activeCallLimits, date), HttpCode.`200`(callContext))
|
||||
(JSONFactory600.createActiveCallLimitsJsonV600FromCallLimit(rateLimit, date), HttpCode.`200`(callContext))
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@ -587,6 +587,22 @@ object JSONFactory600 extends CustomJsonFormats with MdcLoggable {
|
||||
)
|
||||
}
|
||||
|
||||
def createActiveCallLimitsJsonV600FromCallLimit(
|
||||
rateLimit: code.api.util.RateLimitingJson.CallLimit,
|
||||
activeDate: java.util.Date
|
||||
): ActiveCallLimitsJsonV600 = {
|
||||
ActiveCallLimitsJsonV600(
|
||||
call_limits = List.empty,
|
||||
active_at_date = activeDate,
|
||||
total_per_second_call_limit = rateLimit.per_second,
|
||||
total_per_minute_call_limit = rateLimit.per_minute,
|
||||
total_per_hour_call_limit = rateLimit.per_hour,
|
||||
total_per_day_call_limit = rateLimit.per_day,
|
||||
total_per_week_call_limit = rateLimit.per_week,
|
||||
total_per_month_call_limit = rateLimit.per_month
|
||||
)
|
||||
}
|
||||
|
||||
def createTokenJSON(token: String): TokenJSON = {
|
||||
TokenJSON(token)
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user