feature/add_required_field_annotation : refactor OBPRequired annotation extraction related code.

This commit is contained in:
shuang 2020-01-23 18:03:03 +08:00
parent d70939d19a
commit f29894206c
4 changed files with 69 additions and 56 deletions

View File

@ -108,9 +108,6 @@ object ApiVersion {
* and affect the follow OBP Standard versions.
* @param apiPathZero
*/
def setUrlPrefix(apiPathZero: String) = {
val urlPrefixField = classOf[ScannedApiVersion].getDeclaredField("urlPrefix")
urlPrefixField.setAccessible(true)
standardVersions.foreach(urlPrefixField.set(_, apiPathZero))
}
def setUrlPrefix(apiPathZero: String): Unit =
standardVersions.foreach(ReflectUtils.setField(_, "urlPrefix", apiPathZero))
}

View File

@ -19,6 +19,9 @@ object Functions {
case _ if false => ???
}
def doNothingFn[T](t: T): Unit = ()
def doNothingFn[T, D](t: T, d: D): Unit = ()
def truePredicate[T]: T => Boolean = _ => true
def falsePredicate[T]: T => Boolean = _ => false

View File

@ -30,20 +30,31 @@ object ReflectUtils {
* @param fn a callback to operate field, default value is do nothing
* @return the given value given field original value
*/
private def operateField[T](obj: AnyRef, fieldName: String, fn: ru.FieldMirror => Unit = _=>()): T = {
private def operateField[T](obj: AnyRef, fieldName: String)(fn: (InstanceMirror, TermSymbol) => Unit): T = {
val instanceMirror: ru.InstanceMirror = mirror.reflect(obj)
val fieldTerm: ru.TermName = ru.TermName(fieldName)
val field: ru.Symbol = getType(obj).member(fieldTerm)
if(field.isMethod) {// the field is a lazy val
val tp = getType(obj)
def isFieldOrCallByPath(term: ru.TermSymbol) = {
term.name.decodedName.toString.trim == fieldName &&
(term.isVal || term.isVal || term.isLazy || (term.isMethod && term.asMethod.paramLists.isEmpty))
}
val fields: Iterable[ru.TermSymbol] = tp.members.collect({
case term: TermSymbol if isFieldOrCallByPath(term) => term
})
assert(fields.nonEmpty, s"${tp.typeSymbol.fullName} have not field kind member '$fieldName'")
val field = fields.find(it => it.isVal || it.isVar).getOrElse(fields.head)
val result: T = if(field.isVal || field.isVar) {
val fieldMirror: ru.FieldMirror = instanceMirror.reflectField(field)
val originValue = fieldMirror.get
originValue.asInstanceOf[T]
} else {// the field is a lazy val or call by name or empty param list method
val method = field.asMethod
instanceMirror.reflectMethod(method).apply().asInstanceOf[T]
} else {
val fieldSymbol: ru.TermSymbol = field.asTerm.accessed.asTerm
val fieldMirror: ru.FieldMirror = instanceMirror.reflectField(fieldSymbol)
val originValue = fieldMirror.get
fn(fieldMirror)
originValue.asInstanceOf[T]
}
fn(instanceMirror, field)
result
}
def getFieldValues(obj: AnyRef)(predicate: TermSymbol => Boolean = _=>true): Map[String, Any] = {
@ -89,7 +100,7 @@ object ReflectUtils {
* @param fieldName field name
* @return the field value of obj
*/
def getField(obj: AnyRef, fieldName: String): Any = operateField[Any](obj, fieldName)
def getField(obj: AnyRef, fieldName: String): Any = operateField[Any](obj, fieldName)(Functions.doNothingFn)
/**
* according object name get corresponding field value
@ -126,7 +137,10 @@ object ReflectUtils {
* @tparam T field type
* @return the original field value
*/
def setField[T](obj: AnyRef, fieldName: String, fieldValue: T): T = operateField[T](obj, fieldName, _.set(fieldValue))
def setField[T](obj: AnyRef, fieldName: String, fieldValue: T): T = operateField[T](obj, fieldName) { (instanceMirror, term) =>
assert(term.isVal || term.isVar, s"${obj.getClass.getName} have no field name is '$fieldName'")
instanceMirror.reflectField(term).set(fieldValue)
}
/**
* modify given instance nested fields value

View File

@ -23,7 +23,7 @@ import net.liftweb.json.JsonDSL._
* > @OBPRequired(Array(ApiVersion.v3_0_0, ApiVersion.v4_1_0))
*
* required for all versions except some versions: [-v3_0_0, -v4_1_0]
* > @OBPRequired(include=Array(ApiVersion.allVersion), exclude=Array(ApiVersion.v3_0_0, ApiVersion.v4_1_0))
* > @OBPRequired(value=Array(ApiVersion.allVersion), exclude=Array(ApiVersion.v3_0_0, ApiVersion.v4_1_0))
*
* Note: The include and exclude parameter should not change order, because this is not a real class, it is annotation, scala's
* annotation not allowed switch parameter's order as these:
@ -113,7 +113,7 @@ case class RequiredArgs(fieldPath:String, include: Array[ApiVersion],
}
val apiVersions: List[String] = (include, exclude) match {
case (_, Array()) => include.toList.map(_.toString)
case _ => include.toList.map("-" + _.toString)
case _ => exclude.toList.map("-" + _.toString)
}
}
@ -153,52 +153,51 @@ object RequiredFieldValidation {
case _ => throw new IllegalArgumentException(s"$OBP_REQUIRED_NAME's parameter not correct.")
}
/**
* get all field name to OBPRequired annotation info
* @param tp to process type
* @return map of field name to RequiredArgs
*/
def getAnnotations(tp: Type): Iterable[RequiredArgs] = {
val members = tp.members
val constructors = members.filter(_.isConstructor).map(_.asMethod)
def isField(symbol: TermSymbol): Boolean =
symbol.isVal || symbol.isVal || symbol.isLazy || (symbol.isMethod && symbol.asMethod.paramLists.isEmpty)
def getFieldNameAndAnnotation(symbol: Symbol): Option[RequiredArgs] = {
val fieldName = symbol.name.decodedName.toString.trim
getAnnotation(fieldName, symbol) match{
case some: Some[RequiredArgs] => some
case _ => None
}
}
// constructor param name to RequiredArgs
val constructorParamToRequiredArgs: Iterable[RequiredArgs] = constructors
.flatMap(_.paramLists.head) // all the constructor's parameters
.map(getFieldNameAndAnnotation)
.collect {
case Some(requiredArgs) => requiredArgs
}
val constructorParamNames = constructorParamToRequiredArgs.map(_.fieldPath).toSet
// those annotated field name to RequiredArgs
val annotatedFieldNameToRequiredArgs: Iterable[RequiredArgs] =
members
.filter(it => {
!it.isConstructor && !constructorParamNames.contains(it.name.decodedName.toString.trim)
// constructor's parameters and fields
val members: Iterable[Symbol] =
tp.decls.filter(_.isConstructor).flatMap(_.asMethod.paramLists.head) ++
tp.members
.collect({
case t: TermSymbol if isField(t) => t
})
.map(getFieldNameAndAnnotation)
.collect {
case Some(requiredArgs) => requiredArgs
}
.distinctBy(_.fieldPath)
constructorParamToRequiredArgs ++ annotatedFieldNameToRequiredArgs
val directAnnotated = members.map(member => getAnnotation(member.name.decodedName.toString.trim, member, false))
.collect({case Some(requiredArgs) => requiredArgs})
.distinctBy(_.fieldPath)
val directAnnotatedNames = directAnnotated.map(_.fieldPath).toSet
val inDirectAnnotated = members.collect({
case member if !directAnnotatedNames.contains(member.name.decodedName.toString.trim) =>
getAnnotation(member.name.decodedName.toString.trim, member, true)
})
.collect({case Some(requiredArgs) => requiredArgs})
.distinctBy(_.fieldPath)
directAnnotated ++ inDirectAnnotated
}
def getAnnotation(fieldName: String, symbol: Symbol): Option[RequiredArgs] = {
private def getAnnotation(fieldName: String, symbol: Symbol, findOverrides: Boolean): Option[RequiredArgs] = {
val annotation: Option[Annotation] =
(symbol :: symbol.overrides)
.flatMap(_.annotations)
.find(_.tree.tpe <:< typeOf[OBPRequired])
if(findOverrides) {
symbol.overrides.
flatMap(_.annotations)
.find(_.tree.tpe <:< typeOf[OBPRequired])
} else {
symbol.annotations
.find(_.tree.tpe <:< typeOf[OBPRequired])
}
annotation.map { it: Annotation =>
it.tree.children.tail match {
@ -226,13 +225,13 @@ object RequiredFieldValidation {
// find all sub fields RequiredInfo
val subPathToRequiredInfo: Iterable[RequiredArgs] = tp.members.collect {
case m: MethodSymbolApi if m.isGetter => {
(m.name.decodedName.toString.trim, ReflectUtils.getNestFirstTypeArg(m.returnType))
case m: TermSymbol if m.isLazy=> {
(m.name.decodedName.toString.trim, ReflectUtils.getNestFirstTypeArg(m.asMethod.returnType))
}
case m: TermSymbolApi if m.isCaseAccessor || m.isVal => {
case m: TermSymbol if m.isVal || m.isVar => {
(m.name.decodedName.toString.trim, ReflectUtils.getNestFirstTypeArg(m.info))
}
} .filter(tuple => predicate(tuple._2))
} .collect({case tuple @(_, fieldType) if predicate(fieldType) => tuple})
.distinctBy(_._1)
.flatMap(pair => {
val (memberName, membersType) = pair