summaryrefslogtreecommitdiff
path: root/cask/src/cask/router/Macros.scala
diff options
context:
space:
mode:
Diffstat (limited to 'cask/src/cask/router/Macros.scala')
-rw-r--r--cask/src/cask/router/Macros.scala178
1 files changed, 178 insertions, 0 deletions
diff --git a/cask/src/cask/router/Macros.scala b/cask/src/cask/router/Macros.scala
new file mode 100644
index 0000000..de27e5c
--- /dev/null
+++ b/cask/src/cask/router/Macros.scala
@@ -0,0 +1,178 @@
+package cask.router
+
+import scala.reflect.macros.blackbox
+
+
+class Macros[C <: blackbox.Context](val c: C) {
+ import c.universe._
+ def getValsOrMeths(curCls: Type): Iterable[MethodSymbol] = {
+ def isAMemberOfAnyRef(member: Symbol) = {
+ // AnyRef is an alias symbol, we go to the real "owner" of these methods
+ val anyRefSym = c.mirror.universe.definitions.ObjectClass
+ member.owner == anyRefSym
+ }
+ val extractableMembers = for {
+ member <- curCls.members.toList.reverse
+ if !isAMemberOfAnyRef(member)
+ if !member.isSynthetic
+ if member.isPublic
+ if member.isTerm
+ memTerm = member.asTerm
+ if memTerm.isMethod
+ if !memTerm.isModule
+ } yield memTerm.asMethod
+
+ extractableMembers flatMap { case memTerm =>
+ if (memTerm.isSetter || memTerm.isConstructor || memTerm.isGetter) Nil
+ else Seq(memTerm)
+
+ }
+ }
+
+
+
+ def unwrapVarargType(arg: Symbol) = {
+ val vararg = arg.typeSignature.typeSymbol == definitions.RepeatedParamClass
+ val unwrappedType =
+ if (!vararg) arg.typeSignature
+ else arg.typeSignature.asInstanceOf[TypeRef].args(0)
+
+ (vararg, unwrappedType)
+ }
+
+ def extractMethod(method: MethodSymbol,
+ curCls: c.universe.Type,
+ convertToResultType: c.Tree,
+ ctx: c.Tree,
+ argReaders: Seq[c.Tree],
+ annotDeserializeTypes: Seq[c.Tree]): c.universe.Tree = {
+ val baseArgSym = TermName(c.freshName())
+
+ def getDocAnnotation(annotations: List[Annotation]) = {
+ val (docTrees, remaining) = annotations.partition(_.tpe =:= typeOf[doc])
+ val docValues = for {
+ doc <- docTrees
+ if doc.scalaArgs.head.isInstanceOf[Literal]
+ l = doc.scalaArgs.head.asInstanceOf[Literal]
+ if l.value.value.isInstanceOf[String]
+ } yield l.value.value.asInstanceOf[String]
+ (remaining, docValues.headOption)
+ }
+ val (_, methodDoc) = getDocAnnotation(method.annotations)
+ val argValuesSymbol = q"${c.fresh[TermName]("argValues")}"
+ val argSigsSymbol = q"${c.fresh[TermName]("argSigs")}"
+ val ctxSymbol = q"${c.fresh[TermName]("ctx")}"
+ val argData = for(argListIndex <- method.paramLists.indices) yield{
+ val annotDeserializeType = annotDeserializeTypes.lift(argListIndex).getOrElse(tq"scala.Any")
+ val argReader = argReaders.lift(argListIndex).getOrElse(q"cask.router.NoOpParser.instanceAny")
+ val flattenedArgLists = method.paramss(argListIndex)
+ def hasDefault(i: Int) = {
+ val defaultName = s"${method.name}$$default$$${i + 1}"
+ if (curCls.members.exists(_.name.toString == defaultName)) Some(defaultName)
+ else None
+ }
+
+ val defaults = for (i <- flattenedArgLists.indices) yield {
+ val arg = TermName(c.freshName())
+ hasDefault(i).map(defaultName => q"($arg: $curCls) => $arg.${newTermName(defaultName)}")
+ }
+
+ val readArgSigs = for (
+ ((arg, defaultOpt), i) <- flattenedArgLists.zip(defaults).zipWithIndex
+ ) yield {
+
+ if (arg.typeSignature.typeSymbol == definitions.RepeatedParamClass) c.abort(method.pos, "Varargs are not supported in cask routes")
+
+ val default = defaultOpt match {
+ case Some(defaultExpr) => q"scala.Some($defaultExpr($baseArgSym))"
+ case None => q"scala.None"
+ }
+
+ val (docUnwrappedType, docOpt) = arg.typeSignature match {
+ case t: AnnotatedType =>
+ import compat._
+ val (remaining, docValue) = getDocAnnotation(t.annotations)
+ if (remaining.isEmpty) (t.underlying, docValue)
+ else (c.universe.AnnotatedType(remaining, t.underlying), docValue)
+
+ case t => (t, None)
+ }
+
+ val docTree = docOpt match {
+ case None => q"scala.None"
+ case Some(s) => q"scala.Some($s)"
+ }
+
+ val argSig =
+ q"""
+ cask.router.ArgSig[$annotDeserializeType, $curCls, $docUnwrappedType, $ctx](
+ ${arg.name.toString},
+ ${docUnwrappedType.toString},
+ $docTree,
+ $defaultOpt
+ )($argReader[$docUnwrappedType])
+ """
+
+ val reader = q"""
+ cask.router.Runtime.makeReadCall(
+ $argValuesSymbol($argListIndex),
+ $ctxSymbol,
+ $default,
+ $argSigsSymbol($argListIndex)($i)
+ )
+ """
+
+ c.internal.setPos(reader, method.pos)
+ (reader, argSig)
+ }
+
+ val (readArgs, argSigs) = readArgSigs.unzip
+ val (argNames, argNameCasts) = flattenedArgLists.map { arg =>
+ val (vararg, unwrappedType) = unwrapVarargType(arg)
+ (
+ pq"${arg.name.toTermName}",
+ if (!vararg) q"${arg.name.toTermName}.asInstanceOf[$unwrappedType]"
+ else q"${arg.name.toTermName}.asInstanceOf[Seq[$unwrappedType]]: _*"
+
+ )
+ }.unzip
+
+ (argNameCasts, argSigs, argNames, readArgs)
+ }
+
+ val argNameCasts = argData.map(_._1)
+ val argSigs = argData.map(_._2)
+ val argNames = argData.map(_._3)
+ val readArgs = argData.map(_._4)
+ var methodCall: c.Tree = q"$baseArgSym.${method.name.toTermName}"
+ for(argNameCast <- argNameCasts) methodCall = q"$methodCall(..$argNameCast)"
+
+ val res = q"""
+ cask.router.EntryPoint[$curCls, $ctx](
+ ${method.name.toString},
+ ${argSigs.toList},
+ ${methodDoc match{
+ case None => q"scala.None"
+ case Some(s) => q"scala.Some($s)"
+ }},
+ (
+ $baseArgSym: $curCls,
+ $ctxSymbol: $ctx,
+ $argValuesSymbol: Seq[Map[String, Any]],
+ $argSigsSymbol: scala.Seq[scala.Seq[cask.router.ArgSig[Any, _, _, $ctx]]]
+ ) =>
+ cask.router.Runtime.validate(Seq(..${readArgs.flatten.toList})).map{
+ case Seq(..${argNames.flatten.toList}) => $convertToResultType($methodCall)
+ }
+ )
+ """
+
+ c.internal.transform(res){(t, a) =>
+ c.internal.setPos(t, method.pos)
+ a.default(t)
+ }
+
+ res
+ }
+
+}