diff --git a/obp-api/src/main/scala/bootstrap/http4s/Http4sBoot.scala b/obp-api/src/main/scala/bootstrap/http4s/Http4sBoot.scala deleted file mode 100644 index 0a867ec4a..000000000 --- a/obp-api/src/main/scala/bootstrap/http4s/Http4sBoot.scala +++ /dev/null @@ -1,346 +0,0 @@ -/** -Open Bank Project - API -Copyright (C) 2011-2019, TESOBE GmbH. - -This program is free software: you can redistribute it and/or modify -it under the terms of the GNU Affero General Public License as published by -the Free Software Foundation, either version 3 of the License, or -(at your option) any later version. - -This program is distributed in the hope that it will be useful, -but WITHOUT ANY WARRANTY; without even the implied warranty of -MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -GNU Affero General Public License for more details. - -You should have received a copy of the GNU Affero General Public License -along with this program. If not, see . - -Email: contact@tesobe.com -TESOBE GmbH. -Osloer Strasse 16/17 -Berlin 13359, Germany - -This product includes software developed at -TESOBE (http://www.tesobe.com/) - - */ -package bootstrap.http4s - -import bootstrap.liftweb.ToSchemify -import code.api.Constant._ -import code.api.util.ApiRole.CanCreateEntitlementAtAnyBank -import code.api.util.ErrorMessages.MandatoryPropertyIsNotSet -import code.api.util._ -import code.api.util.migration.Migration -import code.api.util.migration.Migration.DbFunction -import code.entitlement.Entitlement -import code.model.dataAccess._ -import code.scheduler._ -import code.users._ -import code.util.Helper.MdcLoggable -import code.views.Views -import com.openbankproject.commons.util.Functions.Implicits._ -import net.liftweb.common.Box.tryo -import net.liftweb.common._ -import net.liftweb.db.{DB, DBLogEntry} -import net.liftweb.mapper.{DefaultConnectionIdentifier => _, _} -import net.liftweb.util._ - -import java.io.{File, FileInputStream} -import java.util.TimeZone - - - - -/** - * Http4s Boot class for initializing OBP-API core components - * This class handles database initialization, migrations, and system setup - * without Lift Web framework dependencies - */ -class Http4sBoot extends MdcLoggable { - - /** - * For the project scope, most early initiate logic should in this method. - */ - override protected def initiate(): Unit = { - val resourceDir = System.getProperty("props.resource.dir") ?: System.getenv("props.resource.dir") - val propsPath = tryo{Box.legacyNullTest(resourceDir)}.toList.flatten - - val propsDir = for { - propsPath <- propsPath - } yield { - Props.toTry.map { - f => { - val name = propsPath + f() + "props" - name -> { () => tryo{new FileInputStream(new File(name))} } - } - } - } - - Props.whereToLook = () => { - propsDir.flatten - } - - if (Props.mode == Props.RunModes.Development) logger.info("OBP-API Props all fields : \n" + Props.props.mkString("\n")) - logger.info("external props folder: " + propsPath) - TimeZone.setDefault(TimeZone.getTimeZone("UTC")) - logger.info("Current Project TimeZone: " + TimeZone.getDefault) - - - // set dynamic_code_sandbox_enable to System.properties, so com.openbankproject.commons.ExecutionContext can read this value - APIUtil.getPropsValue("dynamic_code_sandbox_enable") - .foreach(it => System.setProperty("dynamic_code_sandbox_enable", it)) - } - - - - def boot: Unit = { - implicit val formats = CustomJsonFormats.formats - - logger.info("Http4sBoot says: Hello from the Open Bank Project API. This is Http4sBoot.scala for Http4s runner. The gitCommit is : " + APIUtil.gitCommit) - - logger.debug("Boot says:Using database driver: " + APIUtil.driver) - - DB.defineConnectionManager(net.liftweb.util.DefaultConnectionIdentifier, APIUtil.vendor) - - /** - * Function that determines if foreign key constraints are - * created by Schemifier for the specified connection. - * - * Note: The chosen driver must also support foreign keys for - * creation to happen - * - * In case of PostgreSQL it works - */ - MapperRules.createForeignKeys_? = (_) => APIUtil.getPropsAsBoolValue("mapper_rules.create_foreign_keys", false) - - schemifyAll() - - logger.info("Mapper database info: " + Migration.DbFunction.mapperDatabaseInfo) - - DbFunction.tableExists(ResourceUser) match { - case true => // DB already exist - // Migration Scripts are used to update the model of OBP-API DB to a latest version. - // Please note that migration scripts are executed before Lift Mapper Schemifier - Migration.database.executeScripts(startedBeforeSchemifier = true) - logger.info("The Mapper database already exits. The scripts are executed BEFORE Lift Mapper Schemifier.") - case false => // DB is still not created. The scripts will be executed after Lift Mapper Schemifier - logger.info("The Mapper database is still not created. The scripts are going to be executed AFTER Lift Mapper Schemifier.") - } - - // Migration Scripts are used to update the model of OBP-API DB to a latest version. - - // Please note that migration scripts are executed after Lift Mapper Schemifier - Migration.database.executeScripts(startedBeforeSchemifier = false) - - if (APIUtil.getPropsAsBoolValue("create_system_views_at_boot", true)) { - // Create system views - val owner = Views.views.vend.getOrCreateSystemView(SYSTEM_OWNER_VIEW_ID).isDefined - val auditor = Views.views.vend.getOrCreateSystemView(SYSTEM_AUDITOR_VIEW_ID).isDefined - val accountant = Views.views.vend.getOrCreateSystemView(SYSTEM_ACCOUNTANT_VIEW_ID).isDefined - val standard = Views.views.vend.getOrCreateSystemView(SYSTEM_STANDARD_VIEW_ID).isDefined - val stageOne = Views.views.vend.getOrCreateSystemView(SYSTEM_STAGE_ONE_VIEW_ID).isDefined - val manageCustomViews = Views.views.vend.getOrCreateSystemView(SYSTEM_MANAGE_CUSTOM_VIEWS_VIEW_ID).isDefined - // Only create Firehose view if they are enabled at instance. - val accountFirehose = if (ApiPropsWithAlias.allowAccountFirehose) - Views.views.vend.getOrCreateSystemView(SYSTEM_FIREHOSE_VIEW_ID).isDefined - else Empty.isDefined - - APIUtil.getPropsValue("additional_system_views") match { - case Full(value) => - val additionalSystemViewsFromProps = value.split(",").map(_.trim).toList - val additionalSystemViews = List( - SYSTEM_READ_ACCOUNTS_BASIC_VIEW_ID, - SYSTEM_READ_ACCOUNTS_DETAIL_VIEW_ID, - SYSTEM_READ_BALANCES_VIEW_ID, - SYSTEM_READ_TRANSACTIONS_BASIC_VIEW_ID, - SYSTEM_READ_TRANSACTIONS_DEBITS_VIEW_ID, - SYSTEM_READ_TRANSACTIONS_DETAIL_VIEW_ID, - SYSTEM_READ_ACCOUNTS_BERLIN_GROUP_VIEW_ID, - SYSTEM_READ_BALANCES_BERLIN_GROUP_VIEW_ID, - SYSTEM_READ_TRANSACTIONS_BERLIN_GROUP_VIEW_ID, - SYSTEM_INITIATE_PAYMENTS_BERLIN_GROUP_VIEW_ID - ) - for { - systemView <- additionalSystemViewsFromProps - if additionalSystemViews.exists(_ == systemView) - } { - Views.views.vend.getOrCreateSystemView(systemView) - } - case _ => // Do nothing - } - - } - - ApiWarnings.logWarningsRegardingProperties() - ApiWarnings.customViewNamesCheck() - ApiWarnings.systemViewNamesCheck() - - //see the notes for this method: - createDefaultBankAndDefaultAccountsIfNotExisting() - - createBootstrapSuperUser() - - if (APIUtil.getPropsAsBoolValue("logging.database.queries.enable", false)) { - DB.addLogFunc - { - case (log, duration) => - { - logger.debug("Total query time : %d ms".format(duration)) - log.allEntries.foreach - { - case DBLogEntry(stmt, duration) => - logger.debug("The query : %s in %d ms".format(stmt, duration)) - } - } - } - } - - // start RabbitMq Adapter(using mapped connector as mockded CBS) - if (APIUtil.getPropsAsBoolValue("rabbitmq.adapter.enabled", false)) { - code.bankconnectors.rabbitmq.Adapter.startRabbitMqAdapter.main(Array("")) - } - - // ensure our relational database's tables are created/fit the schema - val connector = code.api.Constant.CONNECTOR.openOrThrowException(s"$MandatoryPropertyIsNotSet. The missing prop is `connector` ") - - logger.info(s"ApiPathZero (the bit before version) is $ApiPathZero") - logger.debug(s"If you can read this, logging level is debug") - - // API Metrics (logs of API calls) - // If set to true we will write each URL with params to a datastore / log file - if (APIUtil.getPropsAsBoolValue("write_metrics", false)) { - logger.info("writeMetrics is true. We will write API metrics") - } else { - logger.info("writeMetrics is false. We will NOT write API metrics") - } - - // API Metrics (logs of Connector calls) - // If set to true we will write each URL with params to a datastore / log file - if (APIUtil.getPropsAsBoolValue("write_connector_metrics", false)) { - logger.info("writeConnectorMetrics is true. We will write connector metrics") - } else { - logger.info("writeConnectorMetrics is false. We will NOT write connector metrics") - } - - - logger.info (s"props_identifier is : ${APIUtil.getPropsValue("props_identifier", "NONE-SET")}") - - val locale = I18NUtil.getDefaultLocale() - logger.info("Default Project Locale is :" + locale) - - } - - def schemifyAll() = { - Schemifier.schemify(true, Schemifier.infoF _, ToSchemify.models: _*) - } - - - /** - * there will be a default bank and two default accounts in obp mapped mode. - * These bank and accounts will be used for the payments. - * when we create transaction request over counterparty and if the counterparty do not link to an existing obp account - * then we will use the default accounts (incoming and outgoing) to keep the money. - */ - private def createDefaultBankAndDefaultAccountsIfNotExisting() ={ - val defaultBankId= APIUtil.defaultBankId - val incomingAccountId= INCOMING_SETTLEMENT_ACCOUNT_ID - val outgoingAccountId= OUTGOING_SETTLEMENT_ACCOUNT_ID - - MappedBank.find(By(MappedBank.permalink, defaultBankId)) match { - case Full(b) => - logger.debug(s"Bank(${defaultBankId}) is found.") - case _ => - MappedBank.create - .permalink(defaultBankId) - .fullBankName("OBP_DEFAULT_BANK") - .shortBankName("OBP") - .national_identifier("OBP") - .mBankRoutingScheme("OBP") - .mBankRoutingAddress("obp1") - .logoURL("") - .websiteURL("") - .saveMe() - logger.debug(s"creating Bank(${defaultBankId})") - } - - MappedBankAccount.find(By(MappedBankAccount.bank, defaultBankId), By(MappedBankAccount.theAccountId, incomingAccountId)) match { - case Full(b) => - logger.debug(s"BankAccount(${defaultBankId}, $incomingAccountId) is found.") - case _ => - MappedBankAccount.create - .bank(defaultBankId) - .theAccountId(incomingAccountId) - .accountCurrency("EUR") - .saveMe() - logger.debug(s"creating BankAccount(${defaultBankId}, $incomingAccountId).") - } - - MappedBankAccount.find(By(MappedBankAccount.bank, defaultBankId), By(MappedBankAccount.theAccountId, outgoingAccountId)) match { - case Full(b) => - logger.debug(s"BankAccount(${defaultBankId}, $outgoingAccountId) is found.") - case _ => - MappedBankAccount.create - .bank(defaultBankId) - .theAccountId(outgoingAccountId) - .accountCurrency("EUR") - .saveMe() - logger.debug(s"creating BankAccount(${defaultBankId}, $outgoingAccountId).") - } - } - - - /** - * Bootstrap Super User - * Given the following credentials, OBP will create a user *if it does not exist already*. - * This user's password will be valid for a limited amount of time. - * This user will be granted ONLY CanCreateEntitlementAtAnyBank - * This feature can also be used in a "Break Glass scenario" - */ - private def createBootstrapSuperUser() ={ - - val superAdminUsername = APIUtil.getPropsValue("super_admin_username","") - val superAdminInitalPassword = APIUtil.getPropsValue("super_admin_inital_password","") - val superAdminEmail = APIUtil.getPropsValue("super_admin_email","") - - val isPropsNotSetProperly = superAdminUsername==""||superAdminInitalPassword ==""||superAdminEmail=="" - - //This is the logic to check if an AuthUser exists for the `create sandbox` endpoint, AfterApiAuth, OpenIdConnect ,,, - val existingAuthUser = AuthUser.find(By(AuthUser.username, superAdminUsername)) - - if(isPropsNotSetProperly) { - //Nothing happens, props is not set - }else if(existingAuthUser.isDefined) { - logger.error(s"createBootstrapSuperUser- Errors: Existing AuthUser with username ${superAdminUsername} detected in data import where no ResourceUser was found") - } else { - val authUser = AuthUser.create - .email(superAdminEmail) - .firstName(superAdminUsername) - .lastName(superAdminUsername) - .username(superAdminUsername) - .password(superAdminInitalPassword) - .passwordShouldBeChanged(true) - .validated(true) - - val validationErrors = authUser.validate - - if(!validationErrors.isEmpty) - logger.error(s"createBootstrapSuperUser- Errors: ${validationErrors.map(_.msg)}") - else { - Full(authUser.save()) //this will create/update the resourceUser. - - val userBox = Users.users.vend.getUserByProviderAndUsername(authUser.getProvider(), authUser.username.get) - - val resultBox = userBox.map(user => Entitlement.entitlement.vend.addEntitlement("", user.userId, CanCreateEntitlementAtAnyBank.toString)) - - if(resultBox.isEmpty){ - logger.error(s"createBootstrapSuperUser- Errors: ${resultBox}") - } - } - - } - - } - - -} diff --git a/obp-api/src/main/scala/bootstrap/http4s/Http4sServer.scala b/obp-api/src/main/scala/bootstrap/http4s/Http4sServer.scala index 7f6584d00..6f1dc1529 100644 --- a/obp-api/src/main/scala/bootstrap/http4s/Http4sServer.scala +++ b/obp-api/src/main/scala/bootstrap/http4s/Http4sServer.scala @@ -3,6 +3,7 @@ package bootstrap.http4s import cats.data.{Kleisli, OptionT} import cats.effect._ import code.api.util.APIUtil +import code.api.util.http4s.Http4sLiftWebBridge import com.comcast.ip4s._ import org.http4s._ import org.http4s.ember.server._ @@ -11,19 +12,23 @@ import org.http4s.implicits._ import scala.language.higherKinds object Http4sServer extends IOApp { - //Start OBP relevant objects and settings; this step MUST be executed first - new bootstrap.http4s.Http4sBoot().boot + //Start OBP relevant objects and settings; this step MUST be executed first + // new bootstrap.http4s.Http4sBoot().boot + new bootstrap.liftweb.Boot().boot val port = APIUtil.getPropsAsIntValue("http4s.port",8086) val host = APIUtil.getPropsValue("http4s.host","127.0.0.1") type HttpF[A] = OptionT[IO, A] - val services: HttpRoutes[IO] = Kleisli[HttpF, Request[IO], Response[IO]] { req: Request[IO] => + private val baseServices: HttpRoutes[IO] = Kleisli[HttpF, Request[IO], Response[IO]] { req: Request[IO] => code.api.v5_0_0.Http4s500.wrappedRoutesV500Services.run(req) .orElse(code.api.v7_0_0.Http4s700.wrappedRoutesV700Services.run(req)) + .orElse(Http4sLiftWebBridge.routes.run(req)) } + val services: HttpRoutes[IO] = Http4sLiftWebBridge.withStandardHeaders(baseServices) + val httpApp: Kleisli[IO, Request[IO], Response[IO]] = (services).orNotFound override def run(args: List[String]): IO[ExitCode] = EmberServerBuilder diff --git a/obp-api/src/main/scala/code/api/util/http4s/Http4sLiftWebBridge.scala b/obp-api/src/main/scala/code/api/util/http4s/Http4sLiftWebBridge.scala new file mode 100644 index 000000000..571b88e09 --- /dev/null +++ b/obp-api/src/main/scala/code/api/util/http4s/Http4sLiftWebBridge.scala @@ -0,0 +1,309 @@ +package code.api.util.http4s + +import cats.data.{Kleisli, OptionT} +import cats.effect.IO +import code.api.{APIFailure, JsonResponseException, ResponseHeader} +import code.api.util.APIUtil +import code.util.Helper.MdcLoggable +import com.openbankproject.commons.util.ReflectUtils +import net.liftweb.actor.LAFuture +import net.liftweb.common.{Box, Empty, Failure, Full, ParamFailure} +import net.liftweb.http._ +import net.liftweb.http.provider.{HTTPContext, HTTPParam, HTTPProvider, HTTPRequest, HTTPSession, HTTPCookie, RetryState} +import org.http4s._ +import org.typelevel.ci.CIString + +import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} +import java.time.format.DateTimeFormatter +import java.time.{ZoneOffset, ZonedDateTime} +import java.util.{Locale, UUID} +import java.util.concurrent.ConcurrentHashMap +import scala.collection.JavaConverters._ + +object Http4sLiftWebBridge extends MdcLoggable { + type HttpF[A] = OptionT[IO, A] + + // Configurable timeout for continuation resolution (default: 60 seconds) + private lazy val continuationTimeoutMs: Long = + APIUtil.getPropsAsLongValue("http4s.continuation.timeout.ms", 60000L) + + def routes: HttpRoutes[IO] = HttpRoutes.of[IO] { + case req => dispatch(req) + } + + def withStandardHeaders(routes: HttpRoutes[IO]): HttpRoutes[IO] = { + Kleisli[HttpF, Request[IO], Response[IO]] { req: Request[IO] => + routes.run(req).map(resp => ensureStandardHeaders(req, resp)) + } + } + + def dispatch(req: Request[IO]): IO[Response[IO]] = { + val uri = req.uri.renderString + val method = req.method.name + logger.debug(s"Http4sLiftBridge dispatching: $method $uri") + for { + bodyBytes <- req.body.compile.to(Array) + liftReq = buildLiftReq(req, bodyBytes) + liftResp <- IO { + val session = LiftRules.statelessSession.vend.apply(liftReq) + S.init(Full(liftReq), session) { + try { + runLiftDispatch(liftReq) + } catch { + case JsonResponseException(jsonResponse) => jsonResponse + case e if e.getClass.getName == "net.liftweb.http.rest.ContinuationException" => + resolveContinuation(e) + } + } + } + http4sResponse <- liftResponseToHttp4s(liftResp) + } yield { + logger.debug(s"Http4sLiftBridge completed: $method $uri -> ${http4sResponse.status.code}") + ensureStandardHeaders(req, http4sResponse) + } + } + + private def runLiftDispatch(req: Req): LiftResponse = { + val handlers = LiftRules.statelessDispatch.toList ++ LiftRules.dispatch.toList + val handler = handlers.collectFirst { case pf if pf.isDefinedAt(req) => pf(req) } + handler match { + case Some(run) => + try { + run() match { + case Full(resp) => resp + case ParamFailure(_, _, _, apiFailure: APIFailure) => + APIUtil.errorJsonResponse(apiFailure.msg, apiFailure.responseCode) + case Failure(msg, _, _) => + APIUtil.errorJsonResponse(msg) + case Empty => + NotFoundResponse() + } + } catch { + case JsonResponseException(jsonResponse) => jsonResponse + case e if e.getClass.getName == "net.liftweb.http.rest.ContinuationException" => + resolveContinuation(e) + } + case None => NotFoundResponse() + } + } + + private def resolveContinuation(exception: Throwable): LiftResponse = { + logger.debug(s"Resolving ContinuationException for async Lift handler") + val func = + ReflectUtils + .getCallByNameValue(exception, "f") + .asInstanceOf[((=> LiftResponse) => Unit) => Unit] + val future = new LAFuture[LiftResponse] + val satisfy: (=> LiftResponse) => Unit = response => future.satisfy(response) + func(satisfy) + future.get(continuationTimeoutMs).openOr { + logger.warn(s"Continuation timeout after ${continuationTimeoutMs}ms, returning InternalServerError") + InternalServerErrorResponse() + } + } + + private def buildLiftReq(req: Request[IO], body: Array[Byte]): Req = { + val headers = http4sHeadersToParams(req.headers.headers) + val params = http4sParamsToParams(req.uri.query.multiParams.toList) + val httpRequest = new Http4sLiftRequest( + req = req, + body = body, + headerParams = headers, + queryParams = params + ) + Req( + httpRequest, + LiftRules.statelessRewrite.toList, + Nil, + LiftRules.statelessReqTest.toList, + System.nanoTime() + ) + } + + private def http4sHeadersToParams(headers: List[Header.Raw]): List[HTTPParam] = { + headers + .groupBy(_.name.toString) + .toList + .map { case (name, values) => + HTTPParam(name, values.map(_.value)) + } + } + + private def http4sParamsToParams(params: List[(String, collection.Seq[String])]): List[HTTPParam] = { + params.map { case (name, values) => + HTTPParam(name, values.toList) + } + } + + private def liftResponseToHttp4s(response: LiftResponse): IO[Response[IO]] = { + response.toResponse match { + case InMemoryResponse(data, headers, _, code) => + IO.pure(buildHttp4sResponse(code, data, headers)) + case StreamingResponse(data, onEnd, _, headers, _, code) => + IO { + try { + val bytes = readAllBytes(data.asInstanceOf[InputStream]) + buildHttp4sResponse(code, bytes, headers) + } finally { + onEnd() + } + } + case OutputStreamResponse(out, _, headers, _, code) => + IO { + val baos = new ByteArrayOutputStream() + out(baos) + buildHttp4sResponse(code, baos.toByteArray, headers) + } + case basic: BasicResponse => + IO.pure(buildHttp4sResponse(basic.code, Array.emptyByteArray, basic.headers)) + } + } + + private def buildHttp4sResponse(code: Int, body: Array[Byte], headers: List[(String, String)]): Response[IO] = { + val hasContentType = headers.exists { case (name, _) => name.equalsIgnoreCase("Content-Type") } + val normalizedHeaders = if (hasContentType) { + headers + } else { + ("Content-Type", "application/json; charset=utf-8") :: headers + } + val http4sHeaders = Headers( + normalizedHeaders.map { case (name, value) => Header.Raw(CIString(name), value) } + ) + Response[IO]( + status = org.http4s.Status.fromInt(code).getOrElse(org.http4s.Status.InternalServerError) + ).withEntity(body).withHeaders(http4sHeaders) + } + + private def readAllBytes(input: InputStream): Array[Byte] = { + val buffer = new ByteArrayOutputStream() + val chunk = new Array[Byte](4096) + var read = input.read(chunk) + while (read != -1) { + buffer.write(chunk, 0, read) + read = input.read(chunk) + } + buffer.toByteArray + } + + private def ensureStandardHeaders(req: Request[IO], resp: Response[IO]): Response[IO] = { + val now = ZonedDateTime.now(ZoneOffset.UTC).format(DateTimeFormatter.RFC_1123_DATE_TIME) + val existing = resp.headers.headers + def hasHeader(name: String): Boolean = + existing.exists(_.name.toString.equalsIgnoreCase(name)) + val existingCorrelationId = existing + .find(_.name.toString.equalsIgnoreCase(ResponseHeader.`Correlation-Id`)) + .map(_.value) + .getOrElse("") + val correlationId = + Option(existingCorrelationId).map(_.trim).filter(_.nonEmpty) + .orElse(req.headers.headers.find(_.name.toString.equalsIgnoreCase("X-Request-ID")).map(_.value)) + .getOrElse(UUID.randomUUID().toString) + val extraHeaders = List.newBuilder[Header.Raw] + if (existingCorrelationId.trim.isEmpty) { + extraHeaders += Header.Raw(CIString(ResponseHeader.`Correlation-Id`), correlationId) + } + if (!hasHeader("Cache-Control")) { + extraHeaders += Header.Raw(CIString("Cache-Control"), "no-cache, private, no-store") + } + if (!hasHeader("Pragma")) { + extraHeaders += Header.Raw(CIString("Pragma"), "no-cache") + } + if (!hasHeader("Expires")) { + extraHeaders += Header.Raw(CIString("Expires"), now) + } + if (!hasHeader("X-Frame-Options")) { + extraHeaders += Header.Raw(CIString("X-Frame-Options"), "DENY") + } + val headersToAdd = extraHeaders.result() + if (headersToAdd.isEmpty) resp + else { + val filtered = resp.headers.headers.filterNot(h => + h.name.toString.equalsIgnoreCase(ResponseHeader.`Correlation-Id`) && + h.value.trim.isEmpty + ) + resp.copy(headers = Headers(filtered) ++ Headers(headersToAdd)) + } + } + + private object Http4sLiftContext extends HTTPContext { + // Thread-safe attribute store using ConcurrentHashMap + private val attributesStore = new ConcurrentHashMap[String, Any]() + def path: String = "" + def resource(path: String): java.net.URL = null + def resourceAsStream(path: String): InputStream = null + def mimeType(path: String): net.liftweb.common.Box[String] = Empty + def initParam(name: String): net.liftweb.common.Box[String] = Empty + def initParams: List[(String, String)] = Nil + def attribute(name: String): net.liftweb.common.Box[Any] = Box(Option(attributesStore.get(name))) + def attributes: List[(String, Any)] = attributesStore.asScala.toList + def setAttribute(name: String, value: Any): Unit = attributesStore.put(name, value) + def removeAttribute(name: String): Unit = attributesStore.remove(name) + } + + private object Http4sLiftProvider extends HTTPProvider { + override protected def context: HTTPContext = Http4sLiftContext + } + + private final class Http4sLiftSession(val sessionId: String) extends HTTPSession { + private val attributesStore = scala.collection.mutable.Map.empty[String, Any] + private var maxInactive: Long = 0L + private val createdAt: Long = System.currentTimeMillis() + def link(liftSession: LiftSession): Unit = () + def unlink(liftSession: LiftSession): Unit = () + def maxInactiveInterval: Long = maxInactive + def setMaxInactiveInterval(interval: Long): Unit = { maxInactive = interval } + def lastAccessedTime: Long = createdAt + def setAttribute(name: String, value: Any): Unit = attributesStore.update(name, value) + def attribute(name: String): Any = attributesStore.getOrElse(name, null) + def removeAttribute(name: String): Unit = attributesStore.remove(name) + def terminate: Unit = () + } + + private final class Http4sLiftRequest( + req: Request[IO], + body: Array[Byte], + headerParams: List[HTTPParam], + queryParams: List[HTTPParam] + ) extends HTTPRequest { + private val sessionValue = new Http4sLiftSession(UUID.randomUUID().toString) + private val uriPath = req.uri.path.renderString + private val uriQuery = req.uri.query.renderString + private val remoteAddr = req.remoteAddr + def cookies: List[HTTPCookie] = Nil + def provider: HTTPProvider = Http4sLiftProvider + def authType: net.liftweb.common.Box[String] = Empty + def headers(name: String): List[String] = + headerParams.find(_.name.equalsIgnoreCase(name)).map(_.values).getOrElse(Nil) + def headers: List[HTTPParam] = headerParams + def contextPath: String = "" + def context: HTTPContext = Http4sLiftContext + def contentType: net.liftweb.common.Box[String] = req.contentType.map(_.mediaType.toString) + def uri: String = uriPath + def url: String = req.uri.renderString + def queryString: net.liftweb.common.Box[String] = if (uriQuery.nonEmpty) Full(uriQuery) else Empty + def param(name: String): List[String] = req.uri.query.multiParams.getOrElse(name, Nil).toList + def params: List[HTTPParam] = queryParams + def paramNames: List[String] = queryParams.map(_.name).distinct + def session: HTTPSession = sessionValue + def destroyServletSession(): Unit = () + def sessionId: net.liftweb.common.Box[String] = Full(sessionValue.sessionId) + def remoteAddress: String = remoteAddr.map(_.toUriString).getOrElse("") + def remotePort: Int = req.uri.port.getOrElse(0) + def remoteHost: String = remoteAddr.map(_.toUriString).getOrElse("") + def serverName: String = req.uri.host.map(_.value).getOrElse("localhost") + def scheme: String = req.uri.scheme.map(_.value).getOrElse("http") + def serverPort: Int = req.uri.port.getOrElse(0) + def method: String = req.method.name + def suspendResumeSupport_? : Boolean = false + def resumeInfo: Option[(Req, LiftResponse)] = None + def suspend(timeout: Long): RetryState.Value = RetryState.TIMED_OUT + def resume(what: (Req, LiftResponse)): Boolean = false + def inputStream: InputStream = new ByteArrayInputStream(body) + def multipartContent_? : Boolean = contentType.exists(_.toLowerCase.contains("multipart/")) + def extractFiles: List[net.liftweb.http.ParamHolder] = Nil + def locale: net.liftweb.common.Box[Locale] = Empty + def setCharacterEncoding(encoding: String): Unit = () + def snapshot: HTTPRequest = this + def userAgent: net.liftweb.common.Box[String] = header("User-Agent") + } +} diff --git a/obp-api/src/main/scala/code/api/v5_0_0/Http4s500.scala b/obp-api/src/main/scala/code/api/v5_0_0/Http4s500.scala index 3352204cc..8b293aead 100644 --- a/obp-api/src/main/scala/code/api/v5_0_0/Http4s500.scala +++ b/obp-api/src/main/scala/code/api/v5_0_0/Http4s500.scala @@ -4,33 +4,22 @@ import cats.data.{Kleisli, OptionT} import cats.effect._ import code.api.Constant._ import code.api.ResourceDocs1_4_0.SwaggerDefinitionsJSON._ -import code.api.util.APIUtil.{EmptyBody, OBPEndpoint, ResourceDoc, getProductsIsPublic} +import code.api.util.APIUtil.{EmptyBody, ResourceDoc, getProductsIsPublic} import code.api.util.ApiTag._ import code.api.util.ErrorMessages._ import code.api.util.http4s.Http4sRequestAttributes.{EndpointHelpers, RequestOps} import code.api.util.http4s.{ErrorResponseConverter, ResourceDocMiddleware} import code.api.util.{CustomJsonFormats, NewStyle} import code.api.v4_0_0.JSONFactory400 -import code.api.{JsonResponseException, ResponseHeader} import com.github.dwickern.macros.NameOf.nameOf import com.openbankproject.commons.ExecutionContext.Implicits.global import com.openbankproject.commons.dto.GetProductsParam import com.openbankproject.commons.model.{BankId, ProductCode} -import com.openbankproject.commons.util.{ApiVersion, ApiVersionStatus, ReflectUtils, ScannedApiVersion} -import net.liftweb.actor.LAFuture -import net.liftweb.common.{Box, Empty, Full} -import net.liftweb.http.provider._ -import net.liftweb.http.{BasicResponse, InMemoryResponse, InternalServerErrorResponse, LiftResponse, LiftRules, LiftSession, NotFoundResponse, OutputStreamResponse, Req, S, StreamingResponse} +import com.openbankproject.commons.util.{ApiVersion, ApiVersionStatus, ScannedApiVersion} import net.liftweb.json.JsonAST.prettyRender import net.liftweb.json.{Extraction, Formats} import org.http4s._ import org.http4s.dsl.io._ -import org.typelevel.ci.CIString - -import java.io.{ByteArrayInputStream, ByteArrayOutputStream, InputStream} -import java.time.format.DateTimeFormatter -import java.time.{ZoneOffset, ZonedDateTime} -import java.util.{Locale, UUID} import scala.collection.mutable.ArrayBuffer import scala.language.{higherKinds, implicitConversions} @@ -61,7 +50,6 @@ object Http4s500 { object Implementations5_0_0 { val prefixPath = Root / ApiPathZero.toString / implementedInApiVersion.toString - private val prefixPathString = s"/${ApiPathZero.toString}/${implementedInApiVersion.toString}" resourceDocs += ResourceDoc( null, @@ -227,238 +215,6 @@ object Http4s500 { } } - private lazy val liftHandlers: List[(OBPEndpoint, Option[ResourceDoc])] = { - val docs = OBPAPI5_0_0.allResourceDocs - OBPAPI5_0_0.routes.flatMap { route => - val routeDocs = docs.filter(_.partialFunction == route) - if (routeDocs.isEmpty) { - List(OBPAPI5_0_0.apiPrefix(route) -> None) - } else { - val (autoValidateDocs, otherDocs) = routeDocs.partition(OBPAPI5_0_0.isAutoValidate(_, autoValidateAll = true)) - val autoValidateHandlers = autoValidateDocs.toList.map { doc => - OBPAPI5_0_0.apiPrefix(doc.wrappedWithAuthCheck(route)) -> Some(doc) - } - val otherHandlers = otherDocs.headOption.toList.map { doc => - OBPAPI5_0_0.apiPrefix(route) -> Some(doc) - } - autoValidateHandlers ++ otherHandlers - } - } - } - - private def dispatchToLift(req: Request[IO]): IO[Response[IO]] = { - for { - bodyBytes <- req.body.compile.to(Array) - liftReq = buildLiftReq(req, bodyBytes) - liftResp <- IO { - val session = LiftRules.statelessSession.vend.apply(liftReq) - S.init(Full(liftReq), session) { - try { - val matchingHandler = liftHandlers.find { case (handler, _) => - if (liftReq.json_?) { - liftReq.json match { - case net.liftweb.common.Failure(_, _, _) => true - case _ => handler.isDefinedAt(liftReq) - } - } else { - handler.isDefinedAt(liftReq) - } - } - matchingHandler match { - case Some((handler, doc)) => - OBPAPI5_0_0.failIfBadAuthorizationHeader(doc) { - OBPAPI5_0_0.failIfBadJSON(liftReq, handler) - } - case None => - NotFoundResponse() - } - } catch { - case JsonResponseException(jsonResponse) => jsonResponse - case e if e.getClass.getName == "net.liftweb.http.rest.ContinuationException" => - resolveContinuation(e) - } - } - } - http4sResponse <- liftResponseToHttp4s(liftResp) - } yield http4sResponse - } - - private def resolveContinuation(exception: Throwable): LiftResponse = { - val func = - ReflectUtils - .getCallByNameValue(exception, "f") - .asInstanceOf[((=> LiftResponse) => Unit) => Unit] - val future = new LAFuture[LiftResponse] - val satisfy: (=> LiftResponse) => Unit = response => future.satisfy(response) - func(satisfy) - future.get(60 * 1000L).openOr(InternalServerErrorResponse()) - } - - private def buildLiftReq(req: Request[IO], body: Array[Byte]): Req = { - val headers = http4sHeadersToParams(req.headers.headers) - val params = http4sParamsToParams(req.uri.query.multiParams.toList) - val httpRequest = new Http4sLiftRequest( - req = req, - body = body, - headerParams = headers, - queryParams = params - ) - Req( - httpRequest, - LiftRules.statelessRewrite.toList, - Nil, - LiftRules.statelessReqTest.toList, - System.nanoTime() - ) - } - - private def http4sHeadersToParams(headers: List[Header.Raw]): List[HTTPParam] = { - headers - .groupBy(_.name.toString) - .toList - .map { case (name, values) => - HTTPParam(name, values.map(_.value)) - } - } - - private def http4sParamsToParams(params: List[(String, collection.Seq[String])]): List[HTTPParam] = { - params.map { case (name, values) => - HTTPParam(name, values.toList) - } - } - - private def liftResponseToHttp4s(response: LiftResponse): IO[Response[IO]] = { - response.toResponse match { - case InMemoryResponse(data, headers, _, code) => - IO.pure(buildHttp4sResponse(code, data, headers)) - case StreamingResponse(data, onEnd, _, headers, _, code) => - IO { - val bytes = readAllBytes(data.asInstanceOf[InputStream]) - onEnd() - buildHttp4sResponse(code, bytes, headers) - } - case OutputStreamResponse(out, _, headers, _, code) => - IO { - val baos = new ByteArrayOutputStream() - out(baos) - buildHttp4sResponse(code, baos.toByteArray, headers) - } - case basic: BasicResponse => - IO.pure(buildHttp4sResponse(basic.code, Array.emptyByteArray, basic.headers)) - } - } - - private def buildHttp4sResponse(code: Int, body: Array[Byte], headers: List[(String, String)]): Response[IO] = { - val contentTypeHeader = headers.find { case (name, _) => name.equalsIgnoreCase("Content-Type") } - val normalizedHeaders = contentTypeHeader match { - case Some(_) => headers - case None => ("Content-Type", "application/json; charset=utf-8") :: headers - } - val http4sHeaders = Headers( - normalizedHeaders.map { case (name, value) => Header.Raw(CIString(name), value) } - ) - Response[IO]( - status = org.http4s.Status.fromInt(code).getOrElse(org.http4s.Status.InternalServerError) - ).withEntity(body).withHeaders(http4sHeaders) - } - - private def readAllBytes(input: InputStream): Array[Byte] = { - val buffer = new ByteArrayOutputStream() - val chunk = new Array[Byte](4096) - var read = input.read(chunk) - while (read != -1) { - buffer.write(chunk, 0, read) - read = input.read(chunk) - } - buffer.toByteArray - } - - private object Http4sLiftContext extends HTTPContext { - private val attributesStore = scala.collection.mutable.Map.empty[String, Any] - def path: String = "" - def resource(path: String): java.net.URL = null - def resourceAsStream(path: String): InputStream = null - def mimeType(path: String): Box[String] = Empty - def initParam(name: String): Box[String] = Empty - def initParams: List[(String, String)] = Nil - def attribute(name: String): Box[Any] = Box(attributesStore.get(name)) - def attributes: List[(String, Any)] = attributesStore.toList - def setAttribute(name: String, value: Any): Unit = attributesStore.update(name, value) - def removeAttribute(name: String): Unit = attributesStore.remove(name) - } - - private object Http4sLiftProvider extends HTTPProvider { - override protected def context: HTTPContext = Http4sLiftContext - } - - private final class Http4sLiftSession(val sessionId: String) extends HTTPSession { - private val attributesStore = scala.collection.mutable.Map.empty[String, Any] - private var maxInactive: Long = 0L - private val createdAt: Long = System.currentTimeMillis() - def link(liftSession: LiftSession): Unit = () - def unlink(liftSession: LiftSession): Unit = () - def maxInactiveInterval: Long = maxInactive - def setMaxInactiveInterval(interval: Long): Unit = { maxInactive = interval } - def lastAccessedTime: Long = createdAt - def setAttribute(name: String, value: Any): Unit = attributesStore.update(name, value) - def attribute(name: String): Any = attributesStore.getOrElse(name, null) - def removeAttribute(name: String): Unit = attributesStore.remove(name) - def terminate: Unit = () - } - - private final class Http4sLiftRequest( - req: Request[IO], - body: Array[Byte], - headerParams: List[HTTPParam], - queryParams: List[HTTPParam] - ) extends HTTPRequest { - private val sessionValue = new Http4sLiftSession(UUID.randomUUID().toString) - private val uriPath = req.uri.path.renderString - private val uriQuery = req.uri.query.renderString - private val remoteAddr = req.remoteAddr - def cookies: List[HTTPCookie] = Nil - def provider: HTTPProvider = Http4sLiftProvider - def authType: Box[String] = Empty - def headers(name: String): List[String] = - headerParams.find(_.name.equalsIgnoreCase(name)).map(_.values).getOrElse(Nil) - def headers: List[HTTPParam] = headerParams - def contextPath: String = "" - def context: HTTPContext = Http4sLiftContext - def contentType: Box[String] = req.contentType.map(_.mediaType.toString) - def uri: String = uriPath - def url: String = req.uri.renderString - def queryString: Box[String] = if (uriQuery.nonEmpty) Full(uriQuery) else Empty - def param(name: String): List[String] = req.uri.query.multiParams.getOrElse(name, Nil).toList - def params: List[HTTPParam] = queryParams - def paramNames: List[String] = queryParams.map(_.name).distinct - def session: HTTPSession = sessionValue - def destroyServletSession(): Unit = () - def sessionId: Box[String] = Full(sessionValue.sessionId) - def remoteAddress: String = remoteAddr.map(_.toUriString).getOrElse("") - def remotePort: Int = req.uri.port.getOrElse(0) - def remoteHost: String = remoteAddr.map(_.toUriString).getOrElse("") - def serverName: String = req.uri.host.map(_.value).getOrElse("localhost") - def scheme: String = req.uri.scheme.map(_.value).getOrElse("http") - def serverPort: Int = req.uri.port.getOrElse(0) - def method: String = req.method.name - def suspendResumeSupport_? : Boolean = false - def resumeInfo: Option[(Req, LiftResponse)] = None - def suspend(timeout: Long): RetryState.Value = RetryState.TIMED_OUT - def resume(what: (Req, LiftResponse)): Boolean = false - def inputStream: InputStream = new ByteArrayInputStream(body) - def multipartContent_? : Boolean = contentType.exists(_.toLowerCase.contains("multipart/")) - def extractFiles: List[net.liftweb.http.ParamHolder] = Nil - def locale: Box[Locale] = Empty - def setCharacterEncoding(encoding: String): Unit = () - def snapshot: HTTPRequest = this - def userAgent: Box[String] = header("User-Agent") - } - - val proxy: HttpRoutes[IO] = HttpRoutes.of[IO] { - case req if req.uri.path.renderString.startsWith(prefixPathString) => - dispatchToLift(req) - } - val allRoutes: HttpRoutes[IO] = Kleisli[HttpF, Request[IO], Response[IO]] { req: Request[IO] => root(req) @@ -466,55 +222,10 @@ object Http4s500 { .orElse(getBank(req)) .orElse(getProducts(req)) .orElse(getProduct(req)) - .orElse(proxy(req)) } - private def ensureStandardHeaders(routes: HttpRoutes[IO]): HttpRoutes[IO] = { - Kleisli[HttpF, Request[IO], Response[IO]] { req: Request[IO] => - routes.run(req).map { resp => - val now = ZonedDateTime.now(ZoneOffset.UTC).format(DateTimeFormatter.RFC_1123_DATE_TIME) - val existing = resp.headers.headers - def hasHeader(name: String): Boolean = - existing.exists(_.name.toString.equalsIgnoreCase(name)) - val existingCorrelationId = existing - .find(_.name.toString.equalsIgnoreCase(ResponseHeader.`Correlation-Id`)) - .map(_.value) - .getOrElse("") - val correlationId = - Option(existingCorrelationId).map(_.trim).filter(_.nonEmpty) - .orElse(req.headers.get(CIString("X-Request-ID")).map(_.head.value)) - .getOrElse(UUID.randomUUID().toString) - val extraHeaders = List.newBuilder[Header.Raw] - if (existingCorrelationId.trim.isEmpty) { - extraHeaders += Header.Raw(CIString(ResponseHeader.`Correlation-Id`), correlationId) - } - if (!hasHeader("Cache-Control")) { - extraHeaders += Header.Raw(CIString("Cache-Control"), "no-cache, private, no-store") - } - if (!hasHeader("Pragma")) { - extraHeaders += Header.Raw(CIString("Pragma"), "no-cache") - } - if (!hasHeader("Expires")) { - extraHeaders += Header.Raw(CIString("Expires"), now) - } - if (!hasHeader("X-Frame-Options")) { - extraHeaders += Header.Raw(CIString("X-Frame-Options"), "DENY") - } - val headersToAdd = extraHeaders.result() - if (headersToAdd.isEmpty) resp - else { - val filtered = resp.headers.headers.filterNot(h => - h.name.toString.equalsIgnoreCase(ResponseHeader.`Correlation-Id`) && - h.value.trim.isEmpty - ) - resp.copy(headers = Headers(filtered) ++ Headers(headersToAdd)) - } - } - } - } - val allRoutesWithMiddleware: HttpRoutes[IO] = - ensureStandardHeaders(ResourceDocMiddleware.apply(resourceDocs)(allRoutes)) + ResourceDocMiddleware.apply(resourceDocs)(allRoutes) } val wrappedRoutesV500Services: HttpRoutes[IO] = Implementations5_0_0.allRoutesWithMiddleware diff --git a/obp-api/src/test/scala/code/api/v5_0_0/Http4sLiftBridgeParityTest.scala b/obp-api/src/test/scala/code/api/v5_0_0/Http4sLiftBridgeParityTest.scala new file mode 100644 index 000000000..5b07cf266 --- /dev/null +++ b/obp-api/src/test/scala/code/api/v5_0_0/Http4sLiftBridgeParityTest.scala @@ -0,0 +1,108 @@ +package code.api.v5_0_0 + +import cats.effect.IO +import cats.effect.unsafe.implicits.global +import code.api.ResponseHeader +import code.api.berlin.group.ConstantsBG +import code.api.util.APIUtil.OAuth._ +import code.api.util.http4s.Http4sLiftWebBridge +import net.liftweb.json.JValue +import net.liftweb.json.JsonAST.JObject +import net.liftweb.json.JsonParser.parse +import org.http4s.{Header, Headers, Method, Request, Status, Uri} +import org.scalatest.Tag +import org.typelevel.ci.CIString + +class Http4sLiftBridgeParityTest extends V500ServerSetup { + + object Http4sLiftBridgeParityTag extends Tag("Http4sLiftBridgeParity") + + private val http4sRoutes = Http4sLiftWebBridge.withStandardHeaders(Http4sLiftWebBridge.routes).orNotFound + + private def toHttp4sRequest(reqData: ReqData): Request[IO] = { + val method = Method.fromString(reqData.method).getOrElse(Method.GET) + val base = Request[IO](method = method, uri = Uri.unsafeFromString(reqData.url)) + val withHeaders = reqData.headers.foldLeft(base) { case (req, (key, value)) => + req.putHeaders(Header.Raw(CIString(key), value)) + } + if (reqData.body.trim.nonEmpty) withHeaders.withEntity(reqData.body) else withHeaders + } + + private def runHttp4s(reqData: ReqData): (Status, JValue, Headers) = { + val response = http4sRoutes.run(toHttp4sRequest(reqData)).unsafeRunSync() + val body = response.as[String].unsafeRunSync() + val json = if (body.trim.isEmpty) JObject(Nil) else parse(body) + (response.status, json, response.headers) + } + + private def hasField(json: JValue, key: String): Boolean = { + json match { + case JObject(fields) => fields.exists(_.name == key) + case _ => false + } + } + + private def jsonKeys(json: JValue): Set[String] = { + json match { + case JObject(fields) => fields.map(_.name).toSet + case _ => Set.empty + } + } + + private def jsonKeysLower(json: JValue): Set[String] = { + jsonKeys(json).map(_.toLowerCase) + } + + private def assertCorrelationId(headers: Headers): Unit = { + val header = headers.headers.find(_.name.toString.equalsIgnoreCase(ResponseHeader.`Correlation-Id`)) + header.isDefined shouldBe true + header.map(_.value.trim.nonEmpty).getOrElse(false) shouldBe true + } + + feature("Http4s Lift bridge parity across versions and auth") { + + scenario("legacy v2.0.0 banks parity", Http4sLiftBridgeParityTag) { + val liftResponse = makeGetRequest((baseRequest / "obp" / "v2.0.0" / "banks").GET) + val reqData = extractParamsAndHeaders((baseRequest / "obp" / "v2.0.0" / "banks").GET, "", "") + val (http4sStatus, http4sJson, http4sHeaders) = runHttp4s(reqData) + + liftResponse.code should equal(http4sStatus.code) + hasField(http4sJson, "banks") shouldBe true + assertCorrelationId(http4sHeaders) + } + + scenario("UK Open Banking accounts parity", Http4sLiftBridgeParityTag) { + val liftReq = (baseRequest / "open-banking" / "v2.0" / "accounts").GET <@(user1) + val liftResponse = makeGetRequest(liftReq) + val reqData = extractParamsAndHeaders(liftReq, "", "") + val (http4sStatus, http4sJson, http4sHeaders) = runHttp4s(reqData) + + liftResponse.code should equal(http4sStatus.code) + assertCorrelationId(http4sHeaders) + } + + scenario("Berlin Group accounts parity", Http4sLiftBridgeParityTag) { + val berlinPath = ConstantsBG.berlinGroupVersion1.apiShortVersion.split("/").toList + val base = berlinPath.foldLeft(baseRequest) { case (req, part) => req / part } + val liftReq = (base / "accounts").GET <@(user1) + val liftResponse = makeGetRequest(liftReq) + val reqData = extractParamsAndHeaders(liftReq, "", "") + val (http4sStatus, http4sJson, http4sHeaders) = runHttp4s(reqData) + + liftResponse.code should equal(http4sStatus.code) + // Berlin Group responses can differ in top-level keys while still being valid. + assertCorrelationId(http4sHeaders) + } + + scenario("DirectLogin parity", Http4sLiftBridgeParityTag) { + val liftReq = (baseRequest / "my" / "logins" / "direct").POST + val liftResponse = makePostRequest(liftReq, "") + val reqData = extractParamsAndHeaders(liftReq, "", "") + val (http4sStatus, http4sJson, http4sHeaders) = runHttp4s(reqData) + + liftResponse.code should equal(http4sStatus.code) + (hasField(http4sJson, "error") || hasField(http4sJson, "message")) shouldBe true + assertCorrelationId(http4sHeaders) + } + } +}