diff --git a/src/main/scala/code/api/OBPRestHelper.scala b/src/main/scala/code/api/OBPRestHelper.scala index 93327887d..170c81c7a 100644 --- a/src/main/scala/code/api/OBPRestHelper.scala +++ b/src/main/scala/code/api/OBPRestHelper.scala @@ -130,7 +130,7 @@ trait OBPRestHelper extends RestHelper with Loggable { } } - def failIfBadOauth(fn: (Box[User]) => Box[JsonResponse]) : JsonResponse = { + def failIfBadAuthorizationHeader(fn: (Box[User]) => Box[JsonResponse]) : JsonResponse = { if (isThereAnOAuthHeader) { getUser match { case Full(u) => fn(Full(u)) @@ -141,7 +141,10 @@ trait OBPRestHelper extends RestHelper with Loggable { } else if (Props.getBool("allow_direct_login", true) && isThereDirectLoginHeader) { DirectLogin.getUser match { case Full(u) => fn(Full(u)) - case _ => errorJsonResponse("directlogin error") + case _ => { + var (httpCode, message, directLoginParameters) = DirectLogin.validator("protectedResource", DirectLogin.getHttpMethod) + errorJsonResponse(message, httpCode) + } } } else { fn(Empty) @@ -189,7 +192,7 @@ trait OBPRestHelper extends RestHelper with Loggable { //if request is correct json //if request matches PartialFunction cases for each defined url //if request has correct oauth headers - failIfBadOauth { + failIfBadAuthorizationHeader { failIfBadJSON(r, handler) } } diff --git a/src/main/scala/code/api/directlogin.scala b/src/main/scala/code/api/directlogin.scala index 2d4d45688..2b5f7f416 100644 --- a/src/main/scala/code/api/directlogin.scala +++ b/src/main/scala/code/api/directlogin.scala @@ -37,11 +37,11 @@ import net.liftweb.http._ import net.liftweb.http.rest.RestHelper import net.liftweb.json.Extraction import net.liftweb.mapper.By -import net.liftweb.util.{Props, Helpers} +import net.liftweb.util.{Helpers, Props} import net.liftweb.util.Helpers._ import scala.compat.Platform -import code.api.util.ErrorMessages +import code.api.util.{APIUtil, ErrorMessages} /** * This object provides the API calls necessary to @@ -94,7 +94,7 @@ object DirectLogin extends RestHelper with Loggable { var (httpCode, message, directLoginParameters) = validator("authorizationToken", getHttpMethod) if (httpCode == 200) { - val userId = getUser(directLoginParameters) + val userId:Long = (for {id <- getUserId(directLoginParameters)} yield id).getOrElse(0) if (userId == 0) { message = ErrorMessages.InvalidLoginCredentials @@ -129,14 +129,6 @@ object DirectLogin extends RestHelper with Loggable { } - // TODO remove duplication with OAuth1.0a version of this? - def registeredApplication(consumerKey: String): Boolean = { - Consumer.find(By(Consumer.key, consumerKey)) match { - case Full(application) => application.isActive - case _ => false - } - } - //Check if the request (access token or request token) is valid and return a tuple def validator(requestType : String, httpMethod : String) : (Int, String, Map[String,String]) = { //return a Map containing the directLogin parameters : prameter -> value @@ -182,11 +174,8 @@ object DirectLogin extends RestHelper with Loggable { } case _ => Map("error" -> "request incorrect") } - } - - def validAccessToken(tokenKey: String) = { Token.find(By(Token.key, tokenKey), By(Token.tokenType, TokenType.Access)) match { case Full(token) => token.isValid @@ -219,18 +208,18 @@ object DirectLogin extends RestHelper with Loggable { } else if ( requestType == "protectedResource" && - !validAccessToken(parameters.get("token").getOrElse("")) + ! validAccessToken(parameters.getOrElse("token", "")) ) { - message = ErrorMessages.DirectLoginInvalidToken + parameters.get("token").getOrElse("") + message = ErrorMessages.DirectLoginInvalidToken + parameters.getOrElse("token", "") httpCode = 401 } //check if the application is registered and active else if ( requestType == "authorizationToken" && Props.getBool("direct_login_consumer_key_mandatory", true) && - !registeredApplication(parameters.get("consumer_key").getOrElse(""))) { + ! APIUtil.registeredApplication(parameters.getOrElse("consumer_key", ""))) { - logger.error("application: " + parameters.get("consumer_key").getOrElse("") + " not found") + logger.error("application: " + parameters.getOrElse("consumer_key", "") + " not found") message = ErrorMessages.InvalidConsumerKey httpCode = 401 } @@ -257,7 +246,7 @@ object DirectLogin extends RestHelper with Loggable { import code.model.{Token, TokenType} val token = Token.create token.tokenType(TokenType.Access) - Consumer.find(By(Consumer.key, directLoginParameters.get("consumer_key").getOrElse(""))) match { + Consumer.find(By(Consumer.key, directLoginParameters.getOrElse("consumer_key", ""))) match { case Full(consumer) => token.consumerId(consumer.id) case _ => None } @@ -279,56 +268,52 @@ object DirectLogin extends RestHelper with Loggable { case _ => "GET" } val (httpCode, message, directLoginParameters) = validator("protectedResource", httpMethod) - val user = getUser(200, if (directLoginParameters.isDefinedAt("token")) - directLoginParameters.get("token") - else - Empty - ) - if (user != Empty ) { - val res = Full(user.get) - res - } else { + val user = for { + u <- getUserFromToken(if (directLoginParameters.isDefinedAt("token")) directLoginParameters.get("token") else Empty) + } yield u + + if (user.isEmpty ) ParamFailure(message, Empty, Empty, APIFailure(message, httpCode)) - } + else + user } - private def getUser(directLoginParameters: Map[String, String]): Long = { - val username = directLoginParameters.get("username").getOrElse("").toString - val password = directLoginParameters.get("password").getOrElse("") match { - case p: String => p - case _ => "error" - } - var userId:Long = OBPUser.getUserId(username, password).getOrElse(0) - if (userId == 0) { - OBPUser.externalUserHelper(directLoginParameters.getOrElse("username", ""), directLoginParameters.getOrElse("password", "")) - userId = OBPUser.getUserId(username, password).getOrElse(0) + + private def getUserId(directLoginParameters: Map[String, String]): Box[Long] = { + val username = directLoginParameters.getOrElse("username", "") + val password = directLoginParameters.getOrElse("password", "") + + var userId = for {id <- OBPUser.getUserId(username, password)} yield id + + if (userId.isEmpty) { + OBPUser.externalUserHelper(username, password) + userId = for {id <- OBPUser.getUserId(username, password)} yield id } + userId } - def getUser(httpCode : Int, tokenID : Box[String]) : Box[User] = - if(httpCode==200) - { - import code.model.Token - logger.info("DirectLogin header correct ") - Token.find(By(Token.key, tokenID.getOrElse(""))) match { - case Full(token) => { - logger.info("access token: "+ token + " found") - val user = token.user - //just a log - user match { - case Full(u) => logger.info("user " + u.emailAddress + " was found from the DirectLogin token") - case _ => logger.info("no user was found for the DirectLogin token") - } - user - } - case _ =>{ - logger.warn("no token " + tokenID.getOrElse("") + " found") - Empty + + def getUserFromToken(tokenID : Box[String]) : Box[User] = { + logger.info("DirectLogin header correct ") + Token.find(By(Token.key, tokenID.getOrElse(""))) match { + case Full(token) => { + logger.info("access token: " + token + " found") + val user = token.user + //just a log + user match { + case Full(u) => logger.info("user " + u.emailAddress + " was found from the DirectLogin token") + case _ => logger.info("no user was found for the DirectLogin token") } + user + } + case _ => { + logger.warn("no token " + tokenID.getOrElse("") + " found") + Empty } } - else - Empty + } + + } diff --git a/src/main/scala/code/api/oauth1.0.scala b/src/main/scala/code/api/oauth1.0.scala index 4791c725f..88d12ec3f 100644 --- a/src/main/scala/code/api/oauth1.0.scala +++ b/src/main/scala/code/api/oauth1.0.scala @@ -30,17 +30,20 @@ import net.liftweb.http.Req import net.liftweb.http.PostRequest import net.liftweb.common.Box import net.liftweb.http.InMemoryResponse -import net.liftweb.common.{Full,Empty,Loggable} +import net.liftweb.common.{Empty, Full, Loggable} import net.liftweb.http.S -import code.model.{Nonce, Consumer, Token} +import code.model.{Consumer, Nonce, Token} import net.liftweb.mapper.By import java.util.Date -import java.net.{URLEncoder, URLDecoder} +import java.net.{URLDecoder, URLEncoder} import javax.crypto.spec.SecretKeySpec import javax.crypto.Mac + import net.liftweb.util.Helpers + import scala.compat.Platform import Helpers._ +import code.api.util.APIUtil import net.liftweb.util.Props import code.model.TokenType import code.model.User @@ -224,13 +227,6 @@ object OAuthHandshake extends RestHelper with Loggable { ) !=0 } - def registeredApplication(consumerKey : String ) : Boolean = { - Consumer.find(By(Consumer.key,consumerKey)) match { - case Full(application) => application.isActive - case _ => false - } - } - def correctSignature(OAuthparameters : Map[String, String], httpMethod : String) = { //Normalize an encode the request parameters as explained in Section 3.4.1.3.2 //of OAuth 1.0 specification (http://tools.ietf.org/html/rfc5849) @@ -362,7 +358,7 @@ object OAuthHandshake extends RestHelper with Loggable { httpCode = 400 } //check if the application is registered and active - else if(! registeredApplication(parameters.get("oauth_consumer_key").get)) + else if(! APIUtil.registeredApplication(parameters.get("oauth_consumer_key").get)) { logger.error("application: " + parameters.get("oauth_consumer_key").get + " not found") message = "Invalid consumer credentials" diff --git a/src/main/scala/code/api/util/APIUtil.scala b/src/main/scala/code/api/util/APIUtil.scala index d8bf013c7..857ea5835 100644 --- a/src/main/scala/code/api/util/APIUtil.scala +++ b/src/main/scala/code/api/util/APIUtil.scala @@ -47,6 +47,7 @@ import net.liftweb.http.js.JsExp import net.liftweb.http.{CurrentReq, JsonResponse, Req, S} import net.liftweb.json.JsonAST.JValue import net.liftweb.json.{Extraction, parse} +import net.liftweb.mapper.By import net.liftweb.util.Helpers._ import net.liftweb.util.{Helpers, Props, SecurityHelpers} @@ -166,6 +167,14 @@ object APIUtil extends Loggable { } } + def registeredApplication(consumerKey: String): Boolean = { + println(Consumer.findAll()) + Consumer.find(By(Consumer.key, consumerKey)) match { + case Full(application) => application.isActive + case _ => false + } + } + def logAPICall = { if(Props.getBool("write_metrics", false)) { val user =