summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLi Haoyi <haoyi.sg@gmail.com>2018-07-25 23:36:44 +0800
committerLi Haoyi <haoyi.sg@gmail.com>2018-07-25 23:36:44 +0800
commit712bafb0c903a14dc0bf6b07e5529007635e004a (patch)
tree9770624d1ed63346ddb120fc6aa4182330694a72
parentf88c2941efdd8a1bc8f4ea7362c2163918c44a1c (diff)
downloadcask-712bafb0c903a14dc0bf6b07e5529007635e004a.tar.gz
cask-712bafb0c903a14dc0bf6b07e5529007635e004a.tar.bz2
cask-712bafb0c903a14dc0bf6b07e5529007635e004a.zip
big refactoring in preparation for allowing endpoint filters
-rw-r--r--cask/src/cask/endpoints/FormEndpoint.scala19
-rw-r--r--cask/src/cask/endpoints/JsonEndpoint.scala17
-rw-r--r--cask/src/cask/endpoints/StaticEndpoints.scala22
-rw-r--r--cask/src/cask/endpoints/WebEndpoints.scala20
-rw-r--r--cask/src/cask/internal/Router.scala212
-rw-r--r--cask/src/cask/main/ErrorMsgs.scala6
-rw-r--r--cask/src/cask/main/Main.scala19
-rw-r--r--cask/src/cask/main/Routes.scala42
-rw-r--r--cask/src/cask/model/ParamContext.scala5
-rw-r--r--cask/test/src/test/cask/ExampleTests.scala5
-rw-r--r--cask/test/src/test/cask/FormJsonPost.scala2
-rw-r--r--cask/test/src/test/cask/StaticFiles.scala2
12 files changed, 179 insertions, 192 deletions
diff --git a/cask/src/cask/endpoints/FormEndpoint.scala b/cask/src/cask/endpoints/FormEndpoint.scala
index 10d1af3..99699e6 100644
--- a/cask/src/cask/endpoints/FormEndpoint.scala
+++ b/cask/src/cask/endpoints/FormEndpoint.scala
@@ -1,6 +1,5 @@
package cask.endpoints
-import cask.internal.Router.EntryPoint
import cask.internal.Router
import cask.main.Routes
import cask.model.{FormValue, ParamContext, Response}
@@ -38,27 +37,19 @@ object FormReader{
class postForm(val path: String, override val subpath: Boolean = false) extends Routes.Endpoint[Response]{
val methods = Seq("post")
type InputType = Seq[FormValue]
- def wrapMethodOutput(t: Response) = t
def parseMethodInput[T](implicit p: FormReader[T]) = p
- def handle(ctx: ParamContext,
- bindings: Map[String, String],
- routes: Routes,
- entryPoint: EntryPoint[Seq[FormValue], Routes, ParamContext]): Router.Result[Response] = {
+
+ def handle(ctx: ParamContext) = {
val formData = FormParserFactory.builder().build().createParser(ctx.exchange).parseBlocking()
val formDataBindings =
formData
.iterator()
.asScala
.map(k => (k, formData.get(k).asScala.map(FormValue.fromUndertow).toSeq))
-
- val pathBindings =
- bindings.map{case (k, v) => (k, Seq(new FormValue.Plain(v, new io.undertow.util.HeaderMap())))}
-
- val allBindings = pathBindings ++ formDataBindings
-
- entryPoint.invoke(routes, ctx, allBindings)
- .asInstanceOf[Router.Result[Response]]
+ .toMap
+ formDataBindings
}
+ def wrapPathSegment(s: String): InputType = Seq(FormValue.Plain(s, new io.undertow.util.HeaderMap))
}
diff --git a/cask/src/cask/endpoints/JsonEndpoint.scala b/cask/src/cask/endpoints/JsonEndpoint.scala
index f172b2b..f740148 100644
--- a/cask/src/cask/endpoints/JsonEndpoint.scala
+++ b/cask/src/cask/endpoints/JsonEndpoint.scala
@@ -27,19 +27,8 @@ object JsReader{
class postJson(val path: String, override val subpath: Boolean = false) extends Routes.Endpoint[Response]{
val methods = Seq("post")
type InputType = ujson.Js.Value
- def wrapMethodOutput(t: Response) = t
def parseMethodInput[T](implicit p: JsReader[T]) = p
- def handle(ctx: ParamContext,
- bindings: Map[String, String],
- routes: Routes,
- entryPoint: EntryPoint[ujson.Js.Value, Routes, cask.model.ParamContext]): Router.Result[Response] = {
-
- val js = ujson.read(new String(ctx.exchange.getInputStream.readAllBytes())).asInstanceOf[ujson.Js.Obj]
-
- js.obj
- val allBindings = bindings.mapValues(ujson.Js.Str(_))
-
- entryPoint.invoke(routes, ctx, js.obj.toMap ++ allBindings)
- .asInstanceOf[Router.Result[Response]]
- }
+ def handle(ctx: ParamContext) =
+ ujson.read(new String(ctx.exchange.getInputStream.readAllBytes())).obj.toMap
+ def wrapPathSegment(s: String): InputType = ujson.Js.Str(s)
}
diff --git a/cask/src/cask/endpoints/StaticEndpoints.scala b/cask/src/cask/endpoints/StaticEndpoints.scala
index 937b9c2..f644d86 100644
--- a/cask/src/cask/endpoints/StaticEndpoints.scala
+++ b/cask/src/cask/endpoints/StaticEndpoints.scala
@@ -1,9 +1,8 @@
package cask.endpoints
import cask.internal.Router
-import cask.internal.Router.EntryPoint
import cask.main.Routes
-import cask.model.{BaseResponse, ParamContext}
+import cask.model.ParamContext
class static(val path: String) extends Routes.Endpoint[String] {
val methods = Seq("get")
@@ -11,19 +10,10 @@ class static(val path: String) extends Routes.Endpoint[String] {
override def subpath = true
def wrapOutput(t: String) = t
def parseMethodInput[T](implicit p: QueryParamReader[T]) = p
- def wrapMethodOutput(t: String) = t
-
- def handle(ctx: ParamContext,
- bindings: Map[String, String],
- routes: Routes,
- entryPoint: EntryPoint[Seq[String], Routes, cask.model.ParamContext]): Router.Result[BaseResponse] = {
- entryPoint.invoke(routes, ctx, Map()).asInstanceOf[Router.Result[String]] match{
- case Router.Result.Success(s) =>
- Router.Result.Success(cask.model.Static(s + "/" + ctx.remaining.mkString("/")))
-
- case e: Router.Result.Error => e
-
- }
-
+ override def wrapMethodOutput(ctx: ParamContext, t: String) = {
+ Router.Result.Success(cask.model.Static(t + "/" + ctx.remaining.mkString("/")))
}
+
+ def handle(ctx: ParamContext) = Map()
+ def wrapPathSegment(s: String): InputType = Seq(s)
}
diff --git a/cask/src/cask/endpoints/WebEndpoints.scala b/cask/src/cask/endpoints/WebEndpoints.scala
index 14c21ce..9eca964 100644
--- a/cask/src/cask/endpoints/WebEndpoints.scala
+++ b/cask/src/cask/endpoints/WebEndpoints.scala
@@ -10,22 +10,12 @@ import collection.JavaConverters._
trait WebEndpoint extends Routes.Endpoint[BaseResponse]{
type InputType = Seq[String]
- def wrapMethodOutput(t: BaseResponse) = t
def parseMethodInput[T](implicit p: QueryParamReader[T]) = p
- def handle(ctx: ParamContext,
- bindings: Map[String, String],
- routes: Routes,
- entryPoint: EntryPoint[Seq[String], Routes, cask.model.ParamContext]): Router.Result[BaseResponse] = {
- val allBindings =
- bindings.map{case (k, v) => (k, Seq(v))} ++
- ctx.exchange.getQueryParameters
- .asScala
- .toSeq
- .map{case (k, vs) => (k, vs.asScala.toArray.toSeq)}
-
- entryPoint.invoke(routes, ctx, allBindings)
- .asInstanceOf[Router.Result[BaseResponse]]
- }
+ def handle(ctx: ParamContext) = ctx.exchange.getQueryParameters
+ .asScala
+ .map{case (k, vs) => (k, vs.asScala.toArray.toSeq)}
+ .toMap
+ def wrapPathSegment(s: String) = Seq(s)
}
class get(val path: String, override val subpath: Boolean = false) extends WebEndpoint{
val methods = Seq("get")
diff --git a/cask/src/cask/internal/Router.scala b/cask/src/cask/internal/Router.scala
index b89fef3..826b1f6 100644
--- a/cask/src/cask/internal/Router.scala
+++ b/cask/src/cask/internal/Router.scala
@@ -46,18 +46,22 @@ object Router{
* instead, which provides a nicer API to call it that mimmicks the API of
* calling a Scala method.
*/
- case class EntryPoint[I, T, C](name: String,
- argSignatures: Seq[ArgSig[I, T, _, C]],
- doc: Option[String],
- varargs: Boolean,
- invoke0: (T, C, Map[String, I]) => Result[Any]){
- def invoke(target: T, ctx: C, args: Map[String, I]): Result[Any] = {
- val unknown = args.keySet -- argSignatures.map(_.name).toSet
- val missing = argSignatures.filter(as => as.reads.arity != 0 && !args.contains(as.name) && as.default.isEmpty)
+ case class EntryPoint[T, C](name: String,
+ argSignatures: Seq[Seq[ArgSig[_, T, _, C]]],
+ doc: Option[String],
+ invoke0: (T, C, Seq[Map[String, Any]]) => Result[Any]){
+ def invoke(target: T,
+ ctx: C,
+ paramLists: Seq[Map[String, Any]]): Result[Any] = {
+
+ val unknown = paramLists.head.keySet -- argSignatures.head.map(_.name).toSet
+ val missing = argSignatures.head.filter(as =>
+ as.reads.arity != 0 && !paramLists.head.contains(as.name) && as.default.isEmpty
+ )
if (missing.nonEmpty || unknown.nonEmpty) Result.Error.MismatchedArguments(missing, unknown.toSeq)
else {
- try invoke0(target, ctx, args)
+ try invoke0(target, ctx, paramLists)
catch{case e: Throwable => Result.Error.Exception(e)}
}
}
@@ -187,128 +191,142 @@ class Router[C <: Context](val c: C) {
curCls: c.universe.Type,
wrapOutput: c.Tree => c.Tree,
ctx: c.Type,
- argReader: c.Tree,
- annotDeserializeType: c.Tree): c.universe.Tree = {
+ argReaders: Seq[c.Tree],
+ annotDeserializeTypes: Seq[c.Tree]): c.universe.Tree = {
val baseArgSym = TermName(c.freshName())
- val flattenedArgLists = meth.paramss.flatten
- def hasDefault(i: Int) = {
- val defaultName = s"${meth.name}$$default$$${i + 1}"
- if (curCls.members.exists(_.name.toString == defaultName)) Some(defaultName)
- else None
- }
- val argListSymbol = q"${c.fresh[TermName]("argsList")}"
-
- val defaults = for ((arg, i) <- flattenedArgLists.zipWithIndex) yield {
- val arg = TermName(c.freshName())
- hasDefault(i).map(defaultName => q"($arg: $curCls) => $arg.${newTermName(defaultName)}")
- }
def getDocAnnotation(annotations: List[Annotation]) = {
val (docTrees, remaining) = annotations.partition(_.tpe =:= typeOf[Router.doc])
val docValues = for {
doc <- docTrees
if doc.scalaArgs.head.isInstanceOf[Literal]
- l = doc.scalaArgs.head.asInstanceOf[Literal]
+ l = doc.scalaArgs.head.asInstanceOf[Literal]
if l.value.value.isInstanceOf[String]
} yield l.value.value.asInstanceOf[String]
(remaining, docValues.headOption)
}
+ val (_, methodDoc) = getDocAnnotation(meth.annotations)
+ val argListSymbol = q"${c.fresh[TermName]("argsList")}"
+ val argData = for(argListIndex <- 0 until meth.paramLists.length) yield{
+ val annotDeserializeType = annotDeserializeTypes(argListIndex)
+ val argReader = argReaders(argListIndex)
+ val flattenedArgLists = meth.paramss(argListIndex)
+ def hasDefault(i: Int) = {
+ val defaultName = s"${meth.name}$$default$$${i + 1}"
+ if (curCls.members.exists(_.name.toString == defaultName)) Some(defaultName)
+ else None
+ }
- def unwrapVarargType(arg: Symbol) = {
- val vararg = arg.typeSignature.typeSymbol == definitions.RepeatedParamClass
- val unwrappedType =
- if (!vararg) arg.typeSignature
- else arg.typeSignature.asInstanceOf[TypeRef].args(0)
- (vararg, unwrappedType)
- }
+ val defaults = for ((arg, i) <- flattenedArgLists.zipWithIndex) yield {
+ val arg = TermName(c.freshName())
+ hasDefault(i).map(defaultName => q"($arg: $curCls) => $arg.${newTermName(defaultName)}")
+ }
- val (_, methodDoc) = getDocAnnotation(meth.annotations)
- val readArgSigs = for(
- ((arg, defaultOpt), i) <- flattenedArgLists.zip(defaults).zipWithIndex
- ) yield {
- val (vararg, varargUnwrappedType) = unwrapVarargType(arg)
+ def unwrapVarargType(arg: Symbol) = {
+ val vararg = arg.typeSignature.typeSymbol == definitions.RepeatedParamClass
+ val unwrappedType =
+ if (!vararg) arg.typeSignature
+ else arg.typeSignature.asInstanceOf[TypeRef].args(0)
- val default =
- if (vararg) q"scala.Some(scala.Nil)"
- else defaultOpt match {
- case Some(defaultExpr) => q"scala.Some($defaultExpr($baseArgSym))"
- case None => q"scala.None"
+ (vararg, unwrappedType)
+ }
+
+
+
+ val readArgSigs = for (
+ ((arg, defaultOpt), i) <- flattenedArgLists.zip(defaults).zipWithIndex
+ ) yield {
+
+ val (vararg, varargUnwrappedType) = unwrapVarargType(arg)
+
+ val default =
+ if (vararg) q"scala.Some(scala.Nil)"
+ else defaultOpt match {
+ case Some(defaultExpr) => q"scala.Some($defaultExpr($baseArgSym))"
+ case None => q"scala.None"
+ }
+
+ val (docUnwrappedType, docOpt) = varargUnwrappedType match {
+ case t: AnnotatedType =>
+ import compat._
+ val (remaining, docValue) = getDocAnnotation(t.annotations)
+ if (remaining.isEmpty) (t.underlying, docValue)
+ else (c.universe.AnnotatedType(remaining, t.underlying), docValue)
+
+ case t => (t, None)
}
- val (docUnwrappedType, docOpt) = varargUnwrappedType match{
- case t: AnnotatedType =>
- import compat._
- val (remaining, docValue) = getDocAnnotation(t.annotations)
- if (remaining.isEmpty) (t.underlying, docValue)
- else (c.universe.AnnotatedType(remaining, t.underlying), docValue)
+ val docTree = docOpt match {
+ case None => q"scala.None"
+ case Some(s) => q"scala.Some($s)"
+ }
- case t => (t, None)
- }
+ val argSig =
+ q"""
+ cask.internal.Router.ArgSig[$annotDeserializeType, $curCls, $docUnwrappedType, $ctx](
+ ${arg.name.toString},
+ ${docUnwrappedType.toString + (if (vararg) "*" else "")},
+ $docTree,
+ $defaultOpt
+ )($argReader[$docUnwrappedType])
+ """
- val docTree = docOpt match{
- case None => q"scala.None"
- case Some(s) => q"scala.Some($s)"
+ val reader =
+ if (vararg) c.abort(meth.pos, "Varargs are not supported in cask routes")
+ else
+ q"""
+ cask.internal.Router.makeReadCall(
+ $argListSymbol($argListIndex),
+ ctx,
+ $default,
+ $argSig.asInstanceOf[cask.internal.Router.ArgSig[Any, _, _, cask.model.ParamContext]]
+ )
+ """
+ c.internal.setPos(reader, meth.pos)
+ (reader, argSig)
}
- val argSig = q"""
- cask.internal.Router.ArgSig[$annotDeserializeType, $curCls, $docUnwrappedType, $ctx](
- ${arg.name.toString},
- ${docUnwrappedType.toString + (if(vararg) "*" else "")},
- $docTree,
- $defaultOpt
- )($argReader[$docUnwrappedType])
- """
-
- val reader =
- if(vararg) c.abort(meth.pos, "Varargs are not supported in cask routes")
- else q"""
- cask.internal.Router.makeReadCall(
- $argListSymbol,
- ctx,
- $default,
- $argSig
- )
- """
- c.internal.setPos(reader, meth.pos)
- (reader, argSig, vararg)
- }
+ val (readArgs, argSigs) = readArgSigs.unzip
+ val (argNames, argNameCasts) = flattenedArgLists.map { arg =>
+ val (vararg, unwrappedType) = unwrapVarargType(arg)
+ (
+ pq"${arg.name.toTermName}",
+ if (!vararg) q"${arg.name.toTermName}.asInstanceOf[$unwrappedType]"
+ else q"${arg.name.toTermName}.asInstanceOf[Seq[$unwrappedType]]: _*"
- val (readArgs, argSigs, varargs) = readArgSigs.unzip3
- val (argNames, argNameCasts) = flattenedArgLists.map { arg =>
- val (vararg, unwrappedType) = unwrapVarargType(arg)
- (
- pq"${arg.name.toTermName}",
- if (!vararg) q"${arg.name.toTermName}.asInstanceOf[$unwrappedType]"
- else q"${arg.name.toTermName}.asInstanceOf[Seq[$unwrappedType]]: _*"
+ )
+ }.unzip
- )
- }.unzip
+ (argNameCasts, argSigs, argNames, readArgs)
+ }
+ val argNameCasts = argData.map(_._1)
+ val argSigs = argData.map(_._2)
+ val argNames = argData.map(_._3)
+ val readArgs = argData.map(_._4)
+ var methodCall: c.Tree = q"$baseArgSym.${meth.name.toTermName}"
+ for(argNameCast <- argNameCasts) methodCall = q"$baseArgSym.${meth.name.toTermName}(..$argNameCast)"
- val methCall =
- if (meth.paramLists.isEmpty) q"$baseArgSym.${meth.name.toTermName}"
- else q"$baseArgSym.${meth.name.toTermName}(..$argNameCasts)"
val res = q"""
- cask.internal.Router.EntryPoint[$annotDeserializeType, $curCls, $ctx](
+ cask.internal.Router.EntryPoint[$curCls, $ctx](
${meth.name.toString},
- scala.Seq(..$argSigs),
+ ${argSigs.toList},
${methodDoc match{
case None => q"scala.None"
case Some(s) => q"scala.Some($s)"
}},
- ${varargs.contains(true)},
- ($baseArgSym: $curCls, ctx: $ctx, $argListSymbol: Map[String, $annotDeserializeType]) =>
- cask.internal.Router.validate(Seq(..$readArgs)) match{
- case cask.internal.Router.Result.Success(Seq(..$argNames)) =>
- cask.internal.Router.Result.Success(
- ${wrapOutput(methCall)}
- )
+ ($baseArgSym: $curCls, ctx: $ctx, $argListSymbol: Seq[Map[String, Any]]) =>
+ cask.internal.Router.validate(Seq(..${readArgs.flatten.toList})) match{
+ case cask.internal.Router.Result.Success(Seq(..${argNames.flatten.toList})) =>
+
+ ${wrapOutput(methodCall)}
+
case x: cask.internal.Router.Result.Error => x
}
- ).asInstanceOf[cask.internal.Router.EntryPoint[Any, $curCls, $ctx]]
+ ).asInstanceOf[cask.internal.Router.EntryPoint[$curCls, $ctx]]
"""
c.internal.transform(res){(t, a) =>
@@ -318,4 +336,4 @@ class Router[C <: Context](val c: C) {
res
}
-} \ No newline at end of file
+}
diff --git a/cask/src/cask/main/ErrorMsgs.scala b/cask/src/cask/main/ErrorMsgs.scala
index f762f3a..c5ce978 100644
--- a/cask/src/cask/main/ErrorMsgs.scala
+++ b/cask/src/cask/main/ErrorMsgs.scala
@@ -37,11 +37,11 @@ object ErrorMsgs {
}
def formatMainMethodSignature[T](base: T,
- main: Router.EntryPoint[_, T, _],
+ main: Router.EntryPoint[T, _],
leftIndent: Int,
leftColWidth: Int) = {
// +2 for space on right of left col
- val args = main.argSignatures.map(renderArg(base, _, leftColWidth + leftIndent + 2 + 2, 80))
+ val args = main.argSignatures.last.map(as => renderArg(base, as, leftColWidth + leftIndent + 2 + 2, 80))
val leftIndentStr = " " * leftIndent
val argStrings =
@@ -60,7 +60,7 @@ object ErrorMsgs {
|${argStrings.map(_ + "\n").mkString}""".stripMargin
}
- def formatInvokeError[T](base: T, route: Router.EntryPoint[_, T, _], x: Router.Result.Error): String = {
+ def formatInvokeError[T](base: T, route: Router.EntryPoint[T, _], x: Router.Result.Error): String = {
def expectedMsg = formatMainMethodSignature(base: T, route, 0, 0)
x match{
diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala
index 9acaf19..bfb058e 100644
--- a/cask/src/cask/main/Main.scala
+++ b/cask/src/cask/main/Main.scala
@@ -31,8 +31,8 @@ abstract class BaseMain{
lazy val routeTries = Seq("get", "put", "post")
.map { method =>
method -> DispatchTrie.construct[(Routes, Routes.EndpointMetadata[_])](0,
- for ((route, metadata) <- routeList if metadata.endpoint.methods.contains(method))
- yield (Util.splitPath(metadata.endpoint.path): IndexedSeq[String], (route, metadata), metadata.endpoint.subpath)
+ for ((route, metadata) <- routeList if metadata.endpoints.exists(_.methods.contains(method)))
+ yield (Util.splitPath(metadata.endpoints.last.path): IndexedSeq[String], (route, metadata), metadata.endpoints.last.subpath)
)
}.toMap
@@ -59,22 +59,21 @@ abstract class BaseMain{
routeTries(exchange.getRequestMethod.toString.toLowerCase()).lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match{
case None => writeResponse(exchange, handleError(404))
case Some(((routes, metadata), bindings, remaining)) =>
- val result = metadata.endpoint.handle(
- ParamContext(exchange, remaining), bindings, routes,
- metadata.entryPoint.asInstanceOf[
- EntryPoint[metadata.endpoint.InputType, cask.main.Routes, cask.model.ParamContext]
- ]
+ val providers = metadata.endpoints.map(e =>
+ e.handle(ParamContext(exchange, remaining)) ++ bindings.mapValues(e.wrapPathSegment)
)
-
+ val result = metadata.entryPoint
+ .asInstanceOf[EntryPoint[cask.main.Routes, cask.model.ParamContext]]
+ .invoke(routes, ParamContext(exchange, remaining), providers)
result match{
- case Router.Result.Success(response) => writeResponse(exchange, response)
+ case Router.Result.Success(response: BaseResponse) => writeResponse(exchange, response)
case e: Router.Result.Error =>
writeResponse(exchange,
Response(
ErrorMsgs.formatInvokeError(
routes,
- metadata.entryPoint.asInstanceOf[EntryPoint[_, cask.main.Routes, _]],
+ metadata.entryPoint.asInstanceOf[EntryPoint[cask.main.Routes, _]],
e
),
statusCode = 500)
diff --git a/cask/src/cask/main/Routes.scala b/cask/src/cask/main/Routes.scala
index ee2b2b9..7299276 100644
--- a/cask/src/cask/main/Routes.scala
+++ b/cask/src/cask/main/Routes.scala
@@ -14,15 +14,13 @@ object Routes{
val path: String
val methods: Seq[String]
def subpath: Boolean = false
- def wrapMethodOutput(t: R): Any
- def handle(ctx: ParamContext,
- bindings: Map[String, String],
- routes: Routes,
- entryPoint: EntryPoint[InputType, Routes, cask.model.ParamContext]): Router.Result[BaseResponse]
+ def wrapMethodOutput(ctx: ParamContext,t: R): cask.internal.Router.Result[Any] = cask.internal.Router.Result.Success(t)
+ def handle(ctx: ParamContext): Map[String, InputType]
+ def wrapPathSegment(s: String): InputType
}
- case class EndpointMetadata[T](endpoint: Endpoint[_],
- entryPoint: EntryPoint[_, T, ParamContext])
+ case class EndpointMetadata[T](endpoints: Seq[Endpoint[_]],
+ entryPoint: EntryPoint[T, ParamContext])
case class RoutesEndpointsMetadata[T](value: EndpointMetadata[T]*)
object RoutesEndpointsMetadata{
implicit def initialize[T] = macro initializeImpl[T]
@@ -32,27 +30,39 @@ object Routes{
val routeParts = for{
m <- c.weakTypeOf[T].members
- annot <- m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[Endpoint[_]])
+ val annotations = m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[Endpoint[_]])
+ if annotations.nonEmpty
} yield {
- val annotObject = q"new ${annot.tree.tpe}(..${annot.tree.children.tail})"
- val annotObjectSym = c.universe.TermName(c.freshName("annotObject"))
+
+ val annotObjects =
+ for(annot <- annotations)
+ yield q"new ${annot.tree.tpe}(..${annot.tree.children.tail})"
+ val annotObjectSyms =
+ for(_ <- annotations.indices)
+ yield c.universe.TermName(c.freshName("annotObject"))
val route = router.extractMethod(
m.asInstanceOf[MethodSymbol],
weakTypeOf[T],
- (t: router.c.universe.Tree) => q"$annotObjectSym.wrapMethodOutput($t)",
+ (t: router.c.universe.Tree) => q"${annotObjectSyms.last}.wrapMethodOutput(ctx, $t)",
c.weakTypeOf[ParamContext],
- q"$annotObjectSym.parseMethodInput",
- tq"$annotObjectSym.InputType"
+ annotObjectSyms.map(annotObjectSym => q"$annotObjectSym.parseMethodInput"),
+ annotObjectSyms.map(annotObjectSym => tq"$annotObjectSym.InputType")
+
)
+ val declarations =
+ for((sym, obj) <- annotObjectSyms.zip(annotObjects))
+ yield q"val $sym = $obj"
- q"""{
- val $annotObjectSym = $annotObject
+ val res = q"""{
+ ..$declarations
cask.main.Routes.EndpointMetadata(
- $annotObjectSym,
+ Seq(..$annotObjectSyms),
$route
)
}"""
+// println(res)
+ res
}
c.Expr[RoutesEndpointsMetadata[T]](q"""cask.main.Routes.RoutesEndpointsMetadata(..$routeParts)""")
diff --git a/cask/src/cask/model/ParamContext.scala b/cask/src/cask/model/ParamContext.scala
index d9f8cfc..43da260 100644
--- a/cask/src/cask/model/ParamContext.scala
+++ b/cask/src/cask/model/ParamContext.scala
@@ -2,7 +2,4 @@ package cask.model
import io.undertow.server.HttpServerExchange
-case class ParamContext(exchange: HttpServerExchange, remaining: Seq[String]) {
-
-
-}
+case class ParamContext(exchange: HttpServerExchange, remaining: Seq[String])
diff --git a/cask/test/src/test/cask/ExampleTests.scala b/cask/test/src/test/cask/ExampleTests.scala
index c1c007e..b60f30c 100644
--- a/cask/test/src/test/cask/ExampleTests.scala
+++ b/cask/test/src/test/cask/ExampleTests.scala
@@ -10,8 +10,9 @@ object ExampleTests extends TestSuite{
.setHandler(new BlockingHandler(example.defaultHandler))
.build
server.start()
- val res = f("http://localhost:8080")
- server.stop()
+ val res =
+ try f("http://localhost:8080")
+ finally server.stop()
res
}
diff --git a/cask/test/src/test/cask/FormJsonPost.scala b/cask/test/src/test/cask/FormJsonPost.scala
index 2874a52..3679286 100644
--- a/cask/test/src/test/cask/FormJsonPost.scala
+++ b/cask/test/src/test/cask/FormJsonPost.scala
@@ -5,7 +5,9 @@ import io.undertow.server.HttpServerExchange
object FormJsonPost extends cask.MainRoutes{
@cask.postJson("/json")
+// @db.validateUser()
def jsonEndpoint(x: HttpServerExchange, value1: ujson.Js.Value, value2: Seq[Int]) = {
+// (user: db.User) = {
"OK " + value1 + " " + value2
}
diff --git a/cask/test/src/test/cask/StaticFiles.scala b/cask/test/src/test/cask/StaticFiles.scala
index d51b35a..8f4a8ef 100644
--- a/cask/test/src/test/cask/StaticFiles.scala
+++ b/cask/test/src/test/cask/StaticFiles.scala
@@ -7,7 +7,7 @@ object StaticFiles extends cask.MainRoutes{
}
@cask.static("/static")
- def staticRoutes = "cask/resources/cask"
+ def staticRoutes() = "cask/resources/cask"
initialize()
}