diff options
-rw-r--r-- | cask/src/cask/Main.scala (renamed from cask/src/cask/Cask.scala) | 84 | ||||
-rw-r--r-- | cask/src/cask/Router.scala | 67 | ||||
-rw-r--r-- | cask/src/cask/Routes.scala | 96 | ||||
-rw-r--r-- | cask/test/src/test/cask/CaskTest.scala | 25 |
4 files changed, 150 insertions, 122 deletions
diff --git a/cask/src/cask/Cask.scala b/cask/src/cask/Main.scala index 97c145a..705ed35 100644 --- a/cask/src/cask/Cask.scala +++ b/cask/src/cask/Main.scala @@ -1,21 +1,12 @@ package cask import cask.Router.EntryPoint -import language.experimental.macros -import scala.annotation.StaticAnnotation -import scala.reflect.macros.blackbox.Context import java.io.OutputStream import java.nio.ByteBuffer import io.undertow.Undertow import io.undertow.server.{HttpHandler, HttpServerExchange} import io.undertow.util.{Headers, HttpString} -trait RouteBase{ - val path: String -} -class get(val path: String) extends StaticAnnotation with RouteBase -class post(val path: String) extends StaticAnnotation with RouteBase -class put(val path: String) extends StaticAnnotation with RouteBase class Main(servers: Routes*){ val port: Int = 8080 @@ -43,20 +34,23 @@ class Main(servers: Routes*){ routeOpt match{ case None => + exchange.setStatusCode(404) exchange.getResponseHeaders.put(Headers.CONTENT_TYPE, "text/plain") exchange.getResponseSender.send("404 Not Found") case Some((server, route, bindings)) => import collection.JavaConverters._ val allBindings = - bindings ++ - exchange.getQueryParameters - .asScala - .toSeq - .flatMap{case (k, vs) => vs.asScala.map((k, _))} + bindings.toSeq ++ + exchange.getQueryParameters + .asScala + .toSeq + .flatMap{case (k, vs) => vs.asScala.map((k, _))} + val result = route.entryPoint .asInstanceOf[EntryPoint[server.type]] - .invoke(server, allBindings.mapValues(Some(_)).toSeq) + .invoke(server, exchange, allBindings.map{case (k, v) => (k, Some(v))}) + result match{ case Router.Result.Success(response: Response) => response.headers.foreach{case (k, v) => @@ -94,64 +88,4 @@ class Main(servers: Routes*){ } } -case class Response(data: Response.Data, - statusCode: Int = 200, - headers: Seq[(String, String)] = Nil) -object Response{ - implicit def dataResponse[T](t: T)(implicit c: T => Data) = Response(t) - trait Data{ - def write(out: OutputStream): Unit - } - object Data{ - implicit class StringData(s: String) extends Data{ - def write(out: OutputStream) = out.write(s.getBytes) - } - implicit class BytesData(b: Array[Byte]) extends Data{ - def write(out: OutputStream) = out.write(b) - } - } -} -object Routes{ - case class RouteMetadata[T](metadata: RouteBase, entryPoint: EntryPoint[T]) - case class Metadata[T](value: RouteMetadata[T]*) - object Metadata{ - implicit def initialize[T] = macro initializeImpl[T] - implicit def initializeImpl[T: c.WeakTypeTag](c: Context): c.Expr[Metadata[T]] = { - import c.universe._ - val router = new cask.Router(c) - val routes = c.weakTypeOf[T].members - .map(m => (m, m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[RouteBase]))) - .collect{case (m, Seq(a)) => - ( - m, - a, - router.extractMethod( - m.asInstanceOf[router.c.universe.MethodSymbol], - weakTypeOf[T].asInstanceOf[router.c.universe.Type] - ).asInstanceOf[c.universe.Tree] - ) - } - - val routeParts = for((m, a, routeTree) <- routes) yield { - val annotation = q"new ${a.tree.tpe}(..${a.tree.children.tail})" - q"cask.Routes.RouteMetadata($annotation, $routeTree)" - } - - - c.Expr[Metadata[T]](q"""cask.Routes.Metadata(..$routeParts)""") - } - } -} - -class Routes{ - private[this] var metadata0: Routes.Metadata[this.type] = null - def caskMetadata = - if (metadata0 != null) metadata0 - else throw new Exception("Routes not yet initialize") - - protected[this] def initialize()(implicit routes: Routes.Metadata[this.type]): Unit = { - metadata0 = routes - } -} - diff --git a/cask/src/cask/Router.scala b/cask/src/cask/Router.scala index ffa9f0b..798e977 100644 --- a/cask/src/cask/Router.scala +++ b/cask/src/cask/Router.scala @@ -1,8 +1,9 @@ package cask -import language.experimental.macros +import io.undertow.server.HttpServerExchange +import language.experimental.macros import scala.annotation.StaticAnnotation import scala.collection.mutable import scala.reflect.macros.blackbox.Context @@ -51,7 +52,7 @@ object Router{ typeString: String, doc: Option[String], default: Option[T => V]) - (implicit val reads: scopt.Read[V]) + (implicit val reads: ParamType[V]) def stripDashes(s: String) = { if (s.startsWith("--")) s.drop(2) @@ -71,9 +72,9 @@ object Router{ argSignatures: Seq[ArgSig[T, _]], doc: Option[String], varargs: Boolean, - invoke0: (T, Map[String, String], Seq[String]) => Result[Any], + invoke0: (T, HttpServerExchange, Map[String, Seq[String]], Seq[String]) => Result[Any], overrides: Int){ - def invoke(target: T, groupedArgs: Seq[(String, Option[String])]): Result[Any] = { + def invoke(target: T, exchange: HttpServerExchange, groupedArgs: Seq[(String, Option[String])]): Result[Any] = { var remainingArgSignatures = argSignatures.toList.filter(_.reads.arity > 0) val accumulatedKeywords = mutable.Map.empty[ArgSig[T, _], mutable.Buffer[String]] @@ -128,31 +129,26 @@ object Router{ } else { missing0.filter(x => incomplete != Some(x)) } - val duplicates = accumulatedKeywords.toSeq.filter(_._2.length > 1) if ( incomplete.nonEmpty || missing.nonEmpty || - duplicates.nonEmpty || (leftoverArgs.nonEmpty && !varargs) ){ Result.Error.MismatchedArguments( missing = missing, unknown = leftoverArgs, - duplicate = duplicates, + duplicate = Nil, incomplete = incomplete ) } else { val mapping = accumulatedKeywords - .iterator - .collect{case (k, Seq(single)) => (k.name, single)} + .map{case (k, single) => (k.name, single)} .toMap - try invoke0(target, mapping, leftoverArgs) - catch{case e: Throwable => - Result.Error.Exception(e) - } + try invoke0(target, exchange, mapping, leftoverArgs) + catch{case e: Throwable => Result.Error.Exception(e)} } } } @@ -161,22 +157,11 @@ object Router{ try Right(t) catch{ case e: Throwable => Left(error(e))} } - def readVarargs(arg: ArgSig[_, _], - values: Seq[String], - thunk: String => Any) = { - val attempts = - for(item <- values) - yield tryEither(thunk(item), Result.ParamError.Invalid(arg, item, _)) - - val bad = attempts.collect{ case Left(x) => x} - if (bad.nonEmpty) Left(bad) - else Right(attempts.collect{case Right(x) => x}) - } - def read(dict: Map[String, String], + def read(dict: Map[String, Seq[String]], default: => Option[Any], arg: ArgSig[_, _], - thunk: String => Any): FailMaybe = { + thunk: Seq[String] => Any): FailMaybe = { arg.reads.arity match{ case 0 => tryEither(thunk(null), Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_)) @@ -238,7 +223,7 @@ object Router{ * Something went wrong trying to de-serialize the input parameter; * the thrown exception is stored in [[ex]] */ - case class Invalid(arg: ArgSig[_, _], value: String, ex: Throwable) extends ParamError + case class Invalid(arg: ArgSig[_, _], value: Seq[String], ex: Throwable) extends ParamError /** * Something went wrong trying to evaluate the default value * for this input parameter @@ -261,14 +246,13 @@ object Router{ } } - def makeReadCall(dict: Map[String, String], + def makeReadCall(dict: Map[String, Seq[String]], + exchange: HttpServerExchange, default: => Option[Any], arg: ArgSig[_, _]) = { - read(dict, default, arg, arg.reads.reads(_)) - } - def makeReadVarargsCall(arg: ArgSig[_, _], values: Seq[String]) = { - readVarargs(arg, values, arg.reads.reads(_)) + read(dict, default, arg, arg.reads.read(exchange, _)) } + } @@ -375,17 +359,14 @@ class Router [C <: Context](val c: C) { """ val reader = - if(vararg) q""" - cask.Router.makeReadVarargsCall( - $argSig, - $extrasSymbol + if(vararg) c.abort(meth.pos, "Varargs are not supported in cask routes") + else q""" + cask.Router.makeReadCall( + $argListSymbol, + exchange, + $default, + $argSig ) - """ else q""" - cask.Router.makeReadCall( - $argListSymbol, - $default, - $argSig - ) """ c.internal.setPos(reader, meth.pos) (reader, argSig, vararg) @@ -412,7 +393,7 @@ class Router [C <: Context](val c: C) { case Some(s) => q"scala.Some($s)" }}, ${varargs.contains(true)}, - ($baseArgSym: $curCls, $argListSymbol: Map[String, String], $extrasSymbol: Seq[String]) => + ($baseArgSym: $curCls, exchange: io.undertow.server.HttpServerExchange, $argListSymbol: Map[String, Seq[String]], $extrasSymbol: Seq[String]) => cask.Router.validate(Seq(..$readArgs)) match{ case cask.Router.Result.Success(List(..$argNames)) => cask.Router.Result.Success( diff --git a/cask/src/cask/Routes.scala b/cask/src/cask/Routes.scala new file mode 100644 index 0000000..156fdbf --- /dev/null +++ b/cask/src/cask/Routes.scala @@ -0,0 +1,96 @@ +package cask +import language.experimental.macros +import java.io.OutputStream + +import cask.Router.EntryPoint +import io.undertow.server.HttpServerExchange + +import scala.annotation.StaticAnnotation +import scala.reflect.macros.blackbox.Context + +class ParamType[T](val arity: Int, val read: (HttpServerExchange, Seq[String]) => T) +object ParamType{ + implicit object StringParam extends ParamType[String](1, (h, x) => x.head) + implicit object BooleanParam extends ParamType[Boolean](1, (h, x) => x.head.toBoolean) + implicit object ByteParam extends ParamType[Byte](1, (h, x) => x.head.toByte) + implicit object ShortParam extends ParamType[Short](1, (h, x) => x.head.toShort) + implicit object IntParam extends ParamType[Int](1, (h, x) => x.head.toInt) + implicit object LongParam extends ParamType[Long](1, (h, x) => x.head.toLong) + implicit object DoubleParam extends ParamType[Double](1, (h, x) => x.head.toDouble) + implicit object FloatParam extends ParamType[Float](1, (h, x) => x.head.toFloat) + implicit def SeqParam[T: ParamType] = + new ParamType[Seq[T]](1, (h, s) => s.map(x => implicitly[ParamType[T]].read(h, Seq(x)))) + + implicit object HttpExchangeParam extends ParamType[HttpServerExchange](0, (h, x) => h) +} + + +trait RouteBase{ + val path: String +} +class get(val path: String) extends StaticAnnotation with RouteBase +class post(val path: String) extends StaticAnnotation with RouteBase +class put(val path: String) extends StaticAnnotation with RouteBase +class route(val path: String, val methods: Seq[String]) extends StaticAnnotation with RouteBase + +case class Response(data: Response.Data, + statusCode: Int = 200, + headers: Seq[(String, String)] = Nil) +object Response{ + implicit def dataResponse[T](t: T)(implicit c: T => Data) = Response(t) + trait Data{ + def write(out: OutputStream): Unit + } + object Data{ + implicit class StringData(s: String) extends Data{ + def write(out: OutputStream) = out.write(s.getBytes) + } + implicit class BytesData(b: Array[Byte]) extends Data{ + def write(out: OutputStream) = out.write(b) + } + } +} + +object Routes{ + case class RouteMetadata[T](metadata: RouteBase, entryPoint: EntryPoint[T]) + case class Metadata[T](value: RouteMetadata[T]*) + object Metadata{ + implicit def initialize[T] = macro initializeImpl[T] + implicit def initializeImpl[T: c.WeakTypeTag](c: Context): c.Expr[Metadata[T]] = { + import c.universe._ + val router = new cask.Router(c) + val routes = c.weakTypeOf[T].members + .map(m => (m, m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[RouteBase]))) + .collect{case (m, Seq(a)) => + ( + m, + a, + router.extractMethod( + m.asInstanceOf[router.c.universe.MethodSymbol], + weakTypeOf[T].asInstanceOf[router.c.universe.Type] + ).asInstanceOf[c.universe.Tree] + ) + } + + val routeParts = for((m, a, routeTree) <- routes) yield { + val annotation = q"new ${a.tree.tpe}(..${a.tree.children.tail})" + q"cask.Routes.RouteMetadata($annotation, $routeTree)" + } + + + c.Expr[Metadata[T]](q"""cask.Routes.Metadata(..$routeParts)""") + } + } +} + +class Routes{ + private[this] var metadata0: Routes.Metadata[this.type] = null + def caskMetadata = + if (metadata0 != null) metadata0 + else throw new Exception("Routes not yet initialize") + + protected[this] def initialize()(implicit routes: Routes.Metadata[this.type]): Unit = { + metadata0 = routes + } +} + diff --git a/cask/test/src/test/cask/CaskTest.scala b/cask/test/src/test/cask/CaskTest.scala index 01c905a..a2e0f47 100644 --- a/cask/test/src/test/cask/CaskTest.scala +++ b/cask/test/src/test/cask/CaskTest.scala @@ -1,21 +1,38 @@ package test.cask +import io.undertow.io.Receiver.{ErrorCallback, FullBytesCallback} +import io.undertow.server.HttpServerExchange + object MyServer extends cask.Routes{ @cask.get("/user/:userName") def showUserProfile(userName: String) = { s"User $userName" } - @cask.post("/post/:postId") - def showPost(postId: Int, query: String) = { + @cask.get("/post/:postId") + def showPost(postId: Int, query: Seq[String]) = { s"Post $postId $query" } - @cask.put("/path/::subPath") - def showSubpath(subPath: String) = { + @cask.get("/path/::subPath") + def showSubpath(x: HttpServerExchange, subPath: String) = { + x.getRequestReceiver().receiveFullBytes((exchange, data) => { + + }: FullBytesCallback, + (exchange, exception) => { + + }: ErrorCallback + ) + println(x) s"Subpath $subPath" } +// @cask.post("/echo-size") +// def echoSize(x: HttpServerExchange, subPath: String) = { +// println(x) +// s"Subpath $subPath" +// } + initialize() } |