summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLi Haoyi <haoyi.sg@gmail.com>2019-10-09 11:35:22 +0800
committerLi Haoyi <haoyi.sg@gmail.com>2019-10-09 11:42:03 +0800
commit5638b1b2ed83deb15108c9e99a3c1d3f6fecbf9b (patch)
treeedf49c618c824b58df9227654becdd0d8d7ab5b6
parent8571a1eae3a3798dde92022fac042ab3732e0d6f (diff)
downloadcask-5638b1b2ed83deb15108c9e99a3c1d3f6fecbf9b.tar.gz
cask-5638b1b2ed83deb15108c9e99a3c1d3f6fecbf9b.tar.bz2
cask-5638b1b2ed83deb15108c9e99a3c1d3f6fecbf9b.zip
Statically check the set of decorators applied to each endpoint method
-rw-r--r--cask/src/cask/decorators/compress.scala3
-rw-r--r--cask/src/cask/router/RoutesEndpointMetadata.scala72
-rw-r--r--cask/test/src/test/cask/FailureTests.scala15
-rw-r--r--example/websockets4/app/src/Websockets4.scala3
4 files changed, 79 insertions, 14 deletions
diff --git a/cask/src/cask/decorators/compress.scala b/cask/src/cask/decorators/compress.scala
index 17931fd..097abe4 100644
--- a/cask/src/cask/decorators/compress.scala
+++ b/cask/src/cask/decorators/compress.scala
@@ -11,8 +11,7 @@ class compress extends cask.RawDecorator{
.toSeq
.flatMap(_.asScala)
.flatMap(_.split(", "))
- val r = delegate(Map())
- val finalResult = r.transform{ case v: cask.Response.Raw =>
+ val finalResult = delegate(Map()).transform{ case v: cask.Response.Raw =>
val (newData, newHeaders) = if (acceptEncodings.exists(_.toLowerCase == "gzip")) {
new Response.Data {
def write(out: OutputStream): Unit = {
diff --git a/cask/src/cask/router/RoutesEndpointMetadata.scala b/cask/src/cask/router/RoutesEndpointMetadata.scala
index 6919166..3a00a36 100644
--- a/cask/src/cask/router/RoutesEndpointMetadata.scala
+++ b/cask/src/cask/router/RoutesEndpointMetadata.scala
@@ -7,8 +7,43 @@ import scala.reflect.macros.blackbox
case class EndpointMetadata[T](decorators: Seq[Decorator[_, _, _]],
endpoint: Endpoint[_, _, _],
entryPoint: EntryPoint[T, _])
+object EndpointMetadata{
+ // `seqify` is used to statically check that the decorators applied to each
+ // individual endpoint method line up, and each decorator's `OuterReturned`
+ // correctly matches the enclosing decorator's `InnerReturned`. We don't bother
+ // checking decorators defined as part of cask.Main or cask.Routes, since those
+ // are both more dynamic (and hard to check) and also less often used and thus
+ // less error prone
+ def seqify1(d: Decorator[_, _, _]) = Seq(d)
+ def seqify2[T1]
+ (d1: Decorator[T1, _, _])
+ (d2: Decorator[_, T1, _]) = Seq(d1, d2)
+ def seqify3[T1, T2]
+ (d1: Decorator[T1, _, _])
+ (d2: Decorator[T2, T1, _])
+ (d3: Decorator[_, T2, _]) = Seq(d1, d2, d3)
+ def seqify4[T1, T2, T3]
+ (d1: Decorator[T1, _, _])
+ (d2: Decorator[T2, T1, _])
+ (d3: Decorator[T3, T2, _])
+ (d4: Decorator[_, T3, _]) = Seq(d1, d2, d3, d4)
+ def seqify5[T1, T2, T3, T4]
+ (d1: Decorator[T1, _, _])
+ (d2: Decorator[T2, T1, _])
+ (d3: Decorator[T3, T2, _])
+ (d4: Decorator[T4, T3, _])
+ (d5: Decorator[_, T4, _]) = Seq(d1, d2, d3, d4, d5)
+ def seqify6[T1, T2, T3, T4, T5]
+ (d1: Decorator[T1, _, _])
+ (d2: Decorator[T2, T1, _])
+ (d3: Decorator[T3, T2, _])
+ (d4: Decorator[T4, T3, _])
+ (d5: Decorator[T5, T4, _])
+ (d6: Decorator[_, T5, _]) = Seq(d1, d2, d3, d4)
+}
case class RoutesEndpointsMetadata[T](value: EndpointMetadata[T]*)
object RoutesEndpointsMetadata{
+
implicit def initialize[T]: RoutesEndpointsMetadata[T] = macro initializeImpl[T]
implicit def initializeImpl[T: c.WeakTypeTag](c: blackbox.Context): c.Expr[RoutesEndpointsMetadata[T]] = {
import c.universe._
@@ -33,11 +68,18 @@ object RoutesEndpointsMetadata{
val annotObjects =
for(annot <- annotations)
- yield q"new ${annot.tree.tpe}(..${annot.tree.children.tail})"
+ yield q"new ${annot.tree.tpe}(..${annot.tree.children.tail})"
val annotObjectSyms =
for(_ <- annotations.indices)
- yield c.universe.TermName(c.freshName("annotObject"))
+ yield c.universe.TermName(c.freshName("annotObject"))
+
+ val annotPositions =
+ for(a <- annotations)
+ yield a.tree.find(_.pos != NoPosition) match{
+ case None => m.pos
+ case Some(t) => t.pos
+ }
val route = router.extractMethod(
m.asInstanceOf[MethodSymbol],
@@ -52,15 +94,23 @@ object RoutesEndpointsMetadata{
for((sym, obj) <- annotObjectSyms.zip(annotObjects))
yield q"val $sym = $obj"
- val res = q"""{
- ..$declarations
- cask.router.EndpointMetadata(
- Seq(..${annotObjectSyms.dropRight(1)}),
- ${annotObjectSyms.last},
- $route
- )
- }"""
- res
+ val seqify = TermName("seqify" + annotObjectSyms.length)
+
+ val seqifyCall = annotObjectSyms
+ .zip(annotPositions)
+ .reverse
+ .foldLeft[Tree](q"cask.router.EndpointMetadata.$seqify"){
+ case (lhs, (rhs, pos)) => q"$lhs(${c.internal.setPos(q"$rhs", pos)})"
+ }
+
+ q"""{
+ ..$declarations
+ cask.router.EndpointMetadata(
+ $seqifyCall.reverse.dropRight(1),
+ ${annotObjectSyms.last},
+ $route
+ )
+ }"""
}
c.Expr[RoutesEndpointsMetadata[T]](q"""cask.router.RoutesEndpointsMetadata(..$routeParts)""")
diff --git a/cask/test/src/test/cask/FailureTests.scala b/cask/test/src/test/cask/FailureTests.scala
index bd27971..65018ce 100644
--- a/cask/test/src/test/cask/FailureTests.scala
+++ b/cask/test/src/test/cask/FailureTests.scala
@@ -9,10 +9,25 @@ object FailureTests extends TestSuite {
delegate(Map("extra" -> 31337))
}
}
+
val tests = Tests{
'mismatchedDecorators - {
utest.compileError("""
object Decorated extends cask.MainRoutes{
+ @myDecorator
+ @cask.websocket("/hello/:world")
+ def hello(world: String)(extra: Int) = ???
+ initialize()
+ }
+ """).check(
+ """
+ def hello(world: String)(extra: Int) = ???
+ ^
+ """,
+ "required: cask.router.Decorator[_, cask.endpoints.WebsocketResult, _]"
+ )
+ utest.compileError("""
+ object Decorated extends cask.MainRoutes{
@cask.get("/hello/:world")
@myDecorator()
def hello(world: String)(extra: Int)= world
diff --git a/example/websockets4/app/src/Websockets4.scala b/example/websockets4/app/src/Websockets4.scala
index f275746..83c4f98 100644
--- a/example/websockets4/app/src/Websockets4.scala
+++ b/example/websockets4/app/src/Websockets4.scala
@@ -1,7 +1,8 @@
package app
case class Websockets4()(implicit val log: cask.Logger) extends cask.Routes{
- @cask.decorators.compress // make sure compress decorator passes non-requests through correctly
+ // make sure compress decorator passes non-requests through correctly
+ override def decorators = Seq(new cask.decorators.compress())
@cask.websocket("/connect/:userName")
def showUserProfile(userName: String): cask.WebsocketResult = {
if (userName != "haoyi") cask.Response("", statusCode = 403)