diff --git a/obp-api/src/main/scala/code/api/util/AfterApiAuth.scala b/obp-api/src/main/scala/code/api/util/AfterApiAuth.scala index 13eae4fc4..1652a9da5 100644 --- a/obp-api/src/main/scala/code/api/util/AfterApiAuth.scala +++ b/obp-api/src/main/scala/code/api/util/AfterApiAuth.scala @@ -93,67 +93,18 @@ object AfterApiAuth extends MdcLoggable{ /** * This block of code needs to update Call Context with Rate Limiting - * Please note that first source is the table RateLimiting and second is the table Consumer + * Uses RateLimitingUtil.getActiveRateLimits as the SINGLE SOURCE OF TRUTH */ def checkRateLimiting(userIsLockedOrDeleted: Future[(Box[User], Option[CallContext])]): Future[(Box[User], Option[CallContext])] = { - def getActiveRateLimitings(consumerId: String): Future[List[RateLimiting]] = { - RateLimitingUtil.useConsumerLimits match { - case true => RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, new Date()) - case false => Future(List.empty) - } - } - - def aggregateLimits(limits: List[RateLimiting], consumerId: String): CallLimit = { - def sumLimits(values: List[Long]): Long = { - val positiveValues = values.filter(_ > 0) - if (positiveValues.isEmpty) -1 else positiveValues.sum - } - - if (limits.nonEmpty) { - CallLimit( - consumerId, - limits.find(_.apiName.isDefined).flatMap(_.apiName), - limits.find(_.apiVersion.isDefined).flatMap(_.apiVersion), - limits.find(_.bankId.isDefined).flatMap(_.bankId), - sumLimits(limits.map(_.perSecondCallLimit)), - sumLimits(limits.map(_.perMinuteCallLimit)), - sumLimits(limits.map(_.perHourCallLimit)), - sumLimits(limits.map(_.perDayCallLimit)), - sumLimits(limits.map(_.perWeekCallLimit)), - sumLimits(limits.map(_.perMonthCallLimit)) - ) - } else { - CallLimit(consumerId, None, None, None, -1, -1, -1, -1, -1, -1) - } - } - for { (user, cc) <- userIsLockedOrDeleted consumer = cc.flatMap(_.consumer) consumerId = consumer.map(_.consumerId.get).getOrElse("") - rateLimitings <- getActiveRateLimitings(consumerId) + rateLimit <- RateLimitingUtil.getActiveRateLimits(consumerId, new Date()) } yield { - val limit: Option[CallLimit] = rateLimitings match { - case Nil => // No rate limiting records found, use consumer defaults - Some(CallLimit( - consumerId, - None, - None, - None, - consumer.map(_.perSecondCallLimit.get).getOrElse(-1), - consumer.map(_.perMinuteCallLimit.get).getOrElse(-1), - consumer.map(_.perHourCallLimit.get).getOrElse(-1), - consumer.map(_.perDayCallLimit.get).getOrElse(-1), - consumer.map(_.perWeekCallLimit.get).getOrElse(-1), - consumer.map(_.perMonthCallLimit.get).getOrElse(-1) - )) - case activeLimits => // Aggregate multiple rate limiting records - Some(aggregateLimits(activeLimits, consumerId)) - } - (user, cc.map(_.copy(rateLimiting = limit))) + (user, cc.map(_.copy(rateLimiting = Some(rateLimit)))) } } - private def sofitInitAction(user: AuthUser): Boolean = applyAction("sofit.logon_init_action.enabled") { def getOrCreateBankAccount(bank: Bank, accountId: String, label: String, accountType: String = ""): Box[BankAccount] = { MappedBankAccount.find( diff --git a/obp-api/src/main/scala/code/api/util/RateLimitingUtil.scala b/obp-api/src/main/scala/code/api/util/RateLimitingUtil.scala index 90f77f9a4..0ebe38dbe 100644 --- a/obp-api/src/main/scala/code/api/util/RateLimitingUtil.scala +++ b/obp-api/src/main/scala/code/api/util/RateLimitingUtil.scala @@ -1,5 +1,9 @@ package code.api.util +import java.util.Date +import code.ratelimiting.{RateLimiting, RateLimitingDI} +import scala.concurrent.Future +import scala.concurrent.ExecutionContext.Implicits.global import code.api.{APIFailureNewStyle, JedisMethod} import code.api.cache.Redis import code.api.util.APIUtil.fullBoxOrException @@ -74,6 +78,83 @@ object RateLimitingUtil extends MdcLoggable { def useConsumerLimits = APIUtil.getPropsAsBoolValue("use_consumer_limits", false) + /** Get system default rate limits from properties. Used when no RateLimiting records exist for a consumer. + * @param consumerId The consumer ID + * @return RateLimit with system property defaults (default to -1 if not set) + */ + def getSystemDefaultRateLimits(consumerId: String): CallLimit = { + RateLimitingJson.CallLimit( + consumerId, + None, + None, + None, + APIUtil.getPropsAsLongValue("rate_limiting_per_second", -1), + APIUtil.getPropsAsLongValue("rate_limiting_per_minute", -1), + APIUtil.getPropsAsLongValue("rate_limiting_per_hour", -1), + APIUtil.getPropsAsLongValue("rate_limiting_per_day", -1), + APIUtil.getPropsAsLongValue("rate_limiting_per_week", -1), + APIUtil.getPropsAsLongValue("rate_limiting_per_month", -1) + ) + } + + /** Aggregate multiple rate limiting records into a single CallLimit. This is the SINGLE SOURCE OF TRUTH for aggregation logic. + * Rules: + * - Only positive values (> 0) are summed + * - If no positive values exist for a period, return -1 (unlimited) + * - Multiple overlapping records have their limits added together + * @param rateLimitRecords List of RateLimiting records to aggregate + * @param consumerId The consumer ID + * @return Aggregated CallLimit + */ + def aggregateRateLimits(rateLimitRecords: List[RateLimiting], consumerId: String): CallLimit = { + def sumLimits(values: List[Long]): Long = { + val positiveValues = values.filter(_ > 0) + if (positiveValues.isEmpty) -1 else positiveValues.sum + } + + if (rateLimitRecords.nonEmpty) { + RateLimitingJson.CallLimit( + consumerId, + rateLimitRecords.find(_.apiName.isDefined).flatMap(_.apiName), + rateLimitRecords.find(_.apiVersion.isDefined).flatMap(_.apiVersion), + rateLimitRecords.find(_.bankId.isDefined).flatMap(_.bankId), + sumLimits(rateLimitRecords.map(_.perSecondCallLimit)), + sumLimits(rateLimitRecords.map(_.perMinuteCallLimit)), + sumLimits(rateLimitRecords.map(_.perHourCallLimit)), + sumLimits(rateLimitRecords.map(_.perDayCallLimit)), + sumLimits(rateLimitRecords.map(_.perWeekCallLimit)), + sumLimits(rateLimitRecords.map(_.perMonthCallLimit)) + ) + } else { + RateLimitingJson.CallLimit(consumerId, None, None, None, -1, -1, -1, -1, -1, -1) + } + } + + /** Get the active rate limits for a consumer at a specific date. This is the SINGLE SOURCE OF TRUTH for rate limit calculation used by both: + * - The enforcement system (AfterApiAuth.checkRateLimiting) + * - The API endpoint (GET /consumer/rate-limits/active-at-date/{DATE}) + * @param consumerId The consumer ID + * @param date The date to check active limits for + * @return Future containing the aggregated CallLimit that will be enforced + */ + def getActiveRateLimits(consumerId: String, date: Date): Future[CallLimit] = { + def getActiveRateLimitings(consumerId: String): Future[List[RateLimiting]] = { + useConsumerLimits match { + case true => RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, date) + case false => Future(List.empty) + } + } + + for { + rateLimitRecords <- getActiveRateLimitings(consumerId) + } yield { + rateLimitRecords match { + case Nil => getSystemDefaultRateLimits(consumerId) + case records => aggregateRateLimits(records, consumerId) + } + } + } + private def createUniqueKey(consumerKey: String, period: LimitCallPeriod) = consumerKey + "_" + RateLimitingPeriod.toString(period) private def underConsumerLimits(consumerKey: String, period: LimitCallPeriod, limit: Long): Boolean = { diff --git a/obp-api/src/main/scala/code/api/v6_0_0/APIMethods600.scala b/obp-api/src/main/scala/code/api/v6_0_0/APIMethods600.scala index 68d7bc88b..52b7eb5e8 100644 --- a/obp-api/src/main/scala/code/api/v6_0_0/APIMethods600.scala +++ b/obp-api/src/main/scala/code/api/v6_0_0/APIMethods600.scala @@ -493,9 +493,9 @@ trait APIMethods600 { val format = new SimpleDateFormat("yyyy-MM-dd'T'HH:mm:ss'Z'") format.parse(dateString) } - activeCallLimits <- RateLimitingDI.rateLimiting.vend.getActiveCallLimitsByConsumerIdAtDate(consumerId, date) + rateLimit <- RateLimitingUtil.getActiveRateLimits(consumerId, date) } yield { - (createActiveCallLimitsJsonV600(activeCallLimits, date), HttpCode.`200`(callContext)) + (JSONFactory600.createActiveCallLimitsJsonV600FromCallLimit(rateLimit, date), HttpCode.`200`(callContext)) } } diff --git a/obp-api/src/main/scala/code/api/v6_0_0/JSONFactory6.0.0.scala b/obp-api/src/main/scala/code/api/v6_0_0/JSONFactory6.0.0.scala index 52c13d187..ba882414c 100644 --- a/obp-api/src/main/scala/code/api/v6_0_0/JSONFactory6.0.0.scala +++ b/obp-api/src/main/scala/code/api/v6_0_0/JSONFactory6.0.0.scala @@ -587,6 +587,22 @@ object JSONFactory600 extends CustomJsonFormats with MdcLoggable { ) } + def createActiveCallLimitsJsonV600FromCallLimit( + rateLimit: code.api.util.RateLimitingJson.CallLimit, + activeDate: java.util.Date + ): ActiveCallLimitsJsonV600 = { + ActiveCallLimitsJsonV600( + call_limits = List.empty, + active_at_date = activeDate, + total_per_second_call_limit = rateLimit.per_second, + total_per_minute_call_limit = rateLimit.per_minute, + total_per_hour_call_limit = rateLimit.per_hour, + total_per_day_call_limit = rateLimit.per_day, + total_per_week_call_limit = rateLimit.per_week, + total_per_month_call_limit = rateLimit.per_month + ) + } + def createTokenJSON(token: String): TokenJSON = { TokenJSON(token) }