From 5638b1b2ed83deb15108c9e99a3c1d3f6fecbf9b Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Wed, 9 Oct 2019 11:35:22 +0800 Subject: Statically check the set of decorators applied to each endpoint method --- cask/src/cask/decorators/compress.scala | 3 +- cask/src/cask/router/RoutesEndpointMetadata.scala | 72 +++++++++++++++++++---- cask/test/src/test/cask/FailureTests.scala | 15 +++++ example/websockets4/app/src/Websockets4.scala | 3 +- 4 files changed, 79 insertions(+), 14 deletions(-) diff --git a/cask/src/cask/decorators/compress.scala b/cask/src/cask/decorators/compress.scala index 17931fd..097abe4 100644 --- a/cask/src/cask/decorators/compress.scala +++ b/cask/src/cask/decorators/compress.scala @@ -11,8 +11,7 @@ class compress extends cask.RawDecorator{ .toSeq .flatMap(_.asScala) .flatMap(_.split(", ")) - val r = delegate(Map()) - val finalResult = r.transform{ case v: cask.Response.Raw => + val finalResult = delegate(Map()).transform{ case v: cask.Response.Raw => val (newData, newHeaders) = if (acceptEncodings.exists(_.toLowerCase == "gzip")) { new Response.Data { def write(out: OutputStream): Unit = { diff --git a/cask/src/cask/router/RoutesEndpointMetadata.scala b/cask/src/cask/router/RoutesEndpointMetadata.scala index 6919166..3a00a36 100644 --- a/cask/src/cask/router/RoutesEndpointMetadata.scala +++ b/cask/src/cask/router/RoutesEndpointMetadata.scala @@ -7,8 +7,43 @@ import scala.reflect.macros.blackbox case class EndpointMetadata[T](decorators: Seq[Decorator[_, _, _]], endpoint: Endpoint[_, _, _], entryPoint: EntryPoint[T, _]) +object EndpointMetadata{ + // `seqify` is used to statically check that the decorators applied to each + // individual endpoint method line up, and each decorator's `OuterReturned` + // correctly matches the enclosing decorator's `InnerReturned`. We don't bother + // checking decorators defined as part of cask.Main or cask.Routes, since those + // are both more dynamic (and hard to check) and also less often used and thus + // less error prone + def seqify1(d: Decorator[_, _, _]) = Seq(d) + def seqify2[T1] + (d1: Decorator[T1, _, _]) + (d2: Decorator[_, T1, _]) = Seq(d1, d2) + def seqify3[T1, T2] + (d1: Decorator[T1, _, _]) + (d2: Decorator[T2, T1, _]) + (d3: Decorator[_, T2, _]) = Seq(d1, d2, d3) + def seqify4[T1, T2, T3] + (d1: Decorator[T1, _, _]) + (d2: Decorator[T2, T1, _]) + (d3: Decorator[T3, T2, _]) + (d4: Decorator[_, T3, _]) = Seq(d1, d2, d3, d4) + def seqify5[T1, T2, T3, T4] + (d1: Decorator[T1, _, _]) + (d2: Decorator[T2, T1, _]) + (d3: Decorator[T3, T2, _]) + (d4: Decorator[T4, T3, _]) + (d5: Decorator[_, T4, _]) = Seq(d1, d2, d3, d4, d5) + def seqify6[T1, T2, T3, T4, T5] + (d1: Decorator[T1, _, _]) + (d2: Decorator[T2, T1, _]) + (d3: Decorator[T3, T2, _]) + (d4: Decorator[T4, T3, _]) + (d5: Decorator[T5, T4, _]) + (d6: Decorator[_, T5, _]) = Seq(d1, d2, d3, d4) +} case class RoutesEndpointsMetadata[T](value: EndpointMetadata[T]*) object RoutesEndpointsMetadata{ + implicit def initialize[T]: RoutesEndpointsMetadata[T] = macro initializeImpl[T] implicit def initializeImpl[T: c.WeakTypeTag](c: blackbox.Context): c.Expr[RoutesEndpointsMetadata[T]] = { import c.universe._ @@ -33,11 +68,18 @@ object RoutesEndpointsMetadata{ val annotObjects = for(annot <- annotations) - yield q"new ${annot.tree.tpe}(..${annot.tree.children.tail})" + yield q"new ${annot.tree.tpe}(..${annot.tree.children.tail})" val annotObjectSyms = for(_ <- annotations.indices) - yield c.universe.TermName(c.freshName("annotObject")) + yield c.universe.TermName(c.freshName("annotObject")) + + val annotPositions = + for(a <- annotations) + yield a.tree.find(_.pos != NoPosition) match{ + case None => m.pos + case Some(t) => t.pos + } val route = router.extractMethod( m.asInstanceOf[MethodSymbol], @@ -52,15 +94,23 @@ object RoutesEndpointsMetadata{ for((sym, obj) <- annotObjectSyms.zip(annotObjects)) yield q"val $sym = $obj" - val res = q"""{ - ..$declarations - cask.router.EndpointMetadata( - Seq(..${annotObjectSyms.dropRight(1)}), - ${annotObjectSyms.last}, - $route - ) - }""" - res + val seqify = TermName("seqify" + annotObjectSyms.length) + + val seqifyCall = annotObjectSyms + .zip(annotPositions) + .reverse + .foldLeft[Tree](q"cask.router.EndpointMetadata.$seqify"){ + case (lhs, (rhs, pos)) => q"$lhs(${c.internal.setPos(q"$rhs", pos)})" + } + + q"""{ + ..$declarations + cask.router.EndpointMetadata( + $seqifyCall.reverse.dropRight(1), + ${annotObjectSyms.last}, + $route + ) + }""" } c.Expr[RoutesEndpointsMetadata[T]](q"""cask.router.RoutesEndpointsMetadata(..$routeParts)""") diff --git a/cask/test/src/test/cask/FailureTests.scala b/cask/test/src/test/cask/FailureTests.scala index bd27971..65018ce 100644 --- a/cask/test/src/test/cask/FailureTests.scala +++ b/cask/test/src/test/cask/FailureTests.scala @@ -9,8 +9,23 @@ object FailureTests extends TestSuite { delegate(Map("extra" -> 31337)) } } + val tests = Tests{ 'mismatchedDecorators - { + utest.compileError(""" + object Decorated extends cask.MainRoutes{ + @myDecorator + @cask.websocket("/hello/:world") + def hello(world: String)(extra: Int) = ??? + initialize() + } + """).check( + """ + def hello(world: String)(extra: Int) = ??? + ^ + """, + "required: cask.router.Decorator[_, cask.endpoints.WebsocketResult, _]" + ) utest.compileError(""" object Decorated extends cask.MainRoutes{ @cask.get("/hello/:world") diff --git a/example/websockets4/app/src/Websockets4.scala b/example/websockets4/app/src/Websockets4.scala index f275746..83c4f98 100644 --- a/example/websockets4/app/src/Websockets4.scala +++ b/example/websockets4/app/src/Websockets4.scala @@ -1,7 +1,8 @@ package app case class Websockets4()(implicit val log: cask.Logger) extends cask.Routes{ - @cask.decorators.compress // make sure compress decorator passes non-requests through correctly + // make sure compress decorator passes non-requests through correctly + override def decorators = Seq(new cask.decorators.compress()) @cask.websocket("/connect/:userName") def showUserProfile(userName: String): cask.WebsocketResult = { if (userName != "haoyi") cask.Response("", statusCode = 403) -- cgit v1.2.3