From bb7dbd6a1b188f07b512057264177972d0dec850 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Fri, 20 Jul 2018 10:31:12 +0800 Subject: Simple routing & serving now works in the 3 test routes given --- build.sc | 6 +- cask/src/cask/Cask.scala | 145 +++++++++-- cask/src/cask/Router.scala | 447 +++++++++++++++++++++++++++++++++ cask/src/cask/Util.scala | 25 ++ cask/test/src/test/cask/CaskTest.scala | 26 +- 5 files changed, 613 insertions(+), 36 deletions(-) create mode 100644 cask/src/cask/Router.scala create mode 100644 cask/src/cask/Util.scala diff --git a/build.sc b/build.sc index 2a3ca2c..663a9d0 100644 --- a/build.sc +++ b/build.sc @@ -2,7 +2,11 @@ import mill._, scalalib._ object cask extends ScalaModule{ def scalaVersion = "2.12.6" - def ivyDeps = Agg(ivy"org.scala-lang:scala-reflect:$scalaVersion") + def ivyDeps = Agg( + ivy"org.scala-lang:scala-reflect:$scalaVersion", + ivy"io.undertow:undertow-core:2.0.11.Final", + ivy"com.github.scopt::scopt:3.5.0" + ) object test extends Tests{ def ivyDeps = Agg(ivy"com.lihaoyi::utest::0.6.3") def testFrameworks = Seq("utest.runner.Framework") diff --git a/cask/src/cask/Cask.scala b/cask/src/cask/Cask.scala index c903cc2..3445594 100644 --- a/cask/src/cask/Cask.scala +++ b/cask/src/cask/Cask.scala @@ -1,40 +1,151 @@ package cask +import cask.Router.EntryPoint + import language.experimental.macros import scala.annotation.StaticAnnotation import scala.reflect.macros.blackbox.Context +import java.io.OutputStream +import java.nio.ByteBuffer + +import io.undertow.Undertow +import io.undertow.server.{HttpHandler, HttpServerExchange} +import io.undertow.util.{Headers, HttpString} class route(val path: String) extends StaticAnnotation -class Main(x: Any*) +class Main(servers: Routes*){ + val port: Int = 8080 + val host: String = "localhost" + def main(args: Array[String]): Unit = { + val allRoutes = for{ + server <- servers + route <- server.caskMetadata.value.map(x => x: Routes.RouteMetadata[_]) + } yield (server, route) + + val server = Undertow.builder + .addHttpListener(port, host) + .setHandler(new HttpHandler() { + def handleRequest(exchange: HttpServerExchange): Unit = { + val routeOpt = + allRoutes + .iterator + .map { case (s: Routes, r: Routes.RouteMetadata[_]) => + Util.matchRoute(r.metadata.path, exchange.getRequestPath).map((s, r, _)) + } + .flatten + .toStream + .headOption + + + routeOpt match{ + case None => + exchange.setStatusCode(404) + exchange.getResponseHeaders.put(Headers.CONTENT_TYPE, "text/plain") + exchange.getResponseSender.send("404 Not Found") + case Some((server, route, bindings)) => + import collection.JavaConverters._ + val allBindings = + bindings ++ + exchange.getQueryParameters + .asScala + .toSeq + .flatMap{case (k, vs) => vs.asScala.map((k, _))} + val result = route.entryPoint + .asInstanceOf[EntryPoint[server.type]] + .invoke(server, allBindings.mapValues(Some(_)).toSeq) + result match{ + case Router.Result.Success(response: Response) => + response.headers.foreach{case (k, v) => + exchange.getResponseHeaders.put(new HttpString(k), v) + } + + exchange.setStatusCode(response.statusCode) + + + response.data.write( + new OutputStream { + def write(b: Int) = { + exchange.getResponseSender.send(ByteBuffer.wrap(Array(b.toByte))) + } + override def write(b: Array[Byte]) = { + exchange.getResponseSender.send(ByteBuffer.wrap(b)) + } + override def write(b: Array[Byte], off: Int, len: Int) = { + exchange.getResponseSender.send(ByteBuffer.wrap(b.slice(off, off + len))) + } + } + ) + case err: Router.Result.Error => + exchange.setStatusCode(400) + exchange.getResponseHeaders.put(Headers.CONTENT_TYPE, "text/plain") + exchange.getResponseSender.send("400 Not Found " + result) + } + + + } + } + }) + .build + server.start() + } +} -object Server{ - case class Route(name: String, metadata: route) - case class Routes[T](value: Route*) - object Routes{ +case class Response(data: Response.Data, + statusCode: Int = 200, + headers: Seq[(String, String)] = Nil) +object Response{ + implicit def dataResponse[T](t: T)(implicit c: T => Data) = Response(t) + trait Data{ + def write(out: OutputStream): Unit + } + object Data{ + implicit class StringData(s: String) extends Data{ + def write(out: OutputStream) = out.write(s.getBytes) + } + implicit class BytesData(b: Array[Byte]) extends Data{ + def write(out: OutputStream) = out.write(b) + } + } +} +object Routes{ + case class RouteMetadata[T](metadata: route, entryPoint: EntryPoint[T]) + case class Metadata[T](value: RouteMetadata[T]*) + object Metadata{ implicit def initialize[T] = macro initializeImpl[T] - implicit def initializeImpl[T: c.WeakTypeTag](c: Context): c.Expr[Routes[T]] = { + implicit def initializeImpl[T: c.WeakTypeTag](c: Context): c.Expr[Metadata[T]] = { import c.universe._ - + val router = new cask.Router(c) val routes = c.weakTypeOf[T].members .map(m => (m, m.annotations.filter(_.tree.tpe =:= c.weakTypeOf[route]))) - .collect{case (m, Seq(a)) => (m, a)} + .collect{case (m, Seq(a)) => + ( + m, + a, + router.extractMethod( + m.asInstanceOf[router.c.universe.MethodSymbol], + weakTypeOf[T].asInstanceOf[router.c.universe.Type] + ).asInstanceOf[c.universe.Tree] + ) + } - val routeParts = for((m, a) <- routes) yield { + val routeParts = for((m, a, routeTree) <- routes) yield { val annotation = q"new ${a.tree.tpe}(..${a.tree.children.tail})" - q"cask.Server.Route(${m.name.toString}, $annotation)" + q"cask.Routes.RouteMetadata($annotation, $routeTree)" } - c.Expr[Routes[T]](q"""cask.Server.Routes(..$routeParts)""") + + + c.Expr[Metadata[T]](q"""cask.Routes.Metadata(..$routeParts)""") } } } -class Server[T](){ - private[this] var routes0: Server.Routes[this.type] = null - def routes = - if (routes0 != null) routes0 +class Routes{ + private[this] var metadata0: Routes.Metadata[this.type] = null + def caskMetadata = + if (metadata0 != null) metadata0 else throw new Exception("Routes not yet initialize") - protected[this] def initialize()(implicit routes: Server.Routes[this.type]): Unit = { - routes0 = routes + protected[this] def initialize()(implicit routes: Routes.Metadata[this.type]): Unit = { + metadata0 = routes } } diff --git a/cask/src/cask/Router.scala b/cask/src/cask/Router.scala new file mode 100644 index 0000000..ffa9f0b --- /dev/null +++ b/cask/src/cask/Router.scala @@ -0,0 +1,447 @@ +package cask + + +import language.experimental.macros + +import scala.annotation.StaticAnnotation +import scala.collection.mutable +import scala.reflect.macros.blackbox.Context + +/** + * More or less a minimal version of Autowire's Server that lets you generate + * a set of "routes" from the methods defined in an object, and call them + * using passing in name/args/kwargs via Java reflection, without having to + * generate/compile code or use Scala reflection. This saves us spinning up + * the Scala compiler and greatly reduces the startup time of cached scripts. + */ +object Router{ + /** + * Allows you to query how many things are overriden by the enclosing owner. + */ + case class Overrides(value: Int) + object Overrides{ + def apply()(implicit c: Overrides) = c.value + implicit def generate: Overrides = macro impl + def impl(c: Context): c.Tree = { + import c.universe._ + q"new _root_.cask.Router.Overrides(${c.internal.enclosingOwner.overrides.length})" + } + } + + class doc(s: String) extends StaticAnnotation + class main extends StaticAnnotation + def generateRoutes[T]: Seq[Router.EntryPoint[T]] = macro generateRoutesImpl[T] + def generateRoutesImpl[T: c.WeakTypeTag](c: Context): c.Expr[Seq[EntryPoint[T]]] = { + import c.universe._ + val r = new Router(c) + val allRoutes = r.getAllRoutesForClass( + weakTypeOf[T].asInstanceOf[r.c.Type] + ).asInstanceOf[Iterable[c.Tree]] + + c.Expr[Seq[EntryPoint[T]]](q"_root_.scala.Seq(..$allRoutes)") + } + + /** + * Models what is known by the router about a single argument: that it has + * a [[name]], a human-readable [[typeString]] describing what the type is + * (just for logging and reading, not a replacement for a `TypeTag`) and + * possible a function that can compute its default value + */ + case class ArgSig[T, V](name: String, + typeString: String, + doc: Option[String], + default: Option[T => V]) + (implicit val reads: scopt.Read[V]) + + def stripDashes(s: String) = { + if (s.startsWith("--")) s.drop(2) + else if (s.startsWith("-")) s.drop(1) + else s + } + /** + * What is known about a single endpoint for our routes. It has a [[name]], + * [[argSignatures]] for each argument, and a macro-generated [[invoke0]] + * that performs all the necessary argument parsing and de-serialization. + * + * Realistically, you will probably spend most of your time calling [[invoke]] + * instead, which provides a nicer API to call it that mimmicks the API of + * calling a Scala method. + */ + case class EntryPoint[T](name: String, + argSignatures: Seq[ArgSig[T, _]], + doc: Option[String], + varargs: Boolean, + invoke0: (T, Map[String, String], Seq[String]) => Result[Any], + overrides: Int){ + def invoke(target: T, groupedArgs: Seq[(String, Option[String])]): Result[Any] = { + var remainingArgSignatures = argSignatures.toList.filter(_.reads.arity > 0) + + val accumulatedKeywords = mutable.Map.empty[ArgSig[T, _], mutable.Buffer[String]] + val keywordableArgs = if (varargs) argSignatures.dropRight(1) else argSignatures + + for(arg <- keywordableArgs) accumulatedKeywords(arg) = mutable.Buffer.empty + + val leftoverArgs = mutable.Buffer.empty[String] + + val lookupArgSig = Map(argSignatures.map(x => (x.name, x)):_*) + + var incomplete: Option[ArgSig[T, _]] = None + + for(group <- groupedArgs){ + + group match{ + case (value, None) => + if (value(0) == '-' && !varargs){ + lookupArgSig.get(stripDashes(value)) match{ + case None => leftoverArgs.append(value) + case Some(sig) => incomplete = Some(sig) + } + + } else remainingArgSignatures match { + case Nil => leftoverArgs.append(value) + case last :: Nil if varargs => leftoverArgs.append(value) + case next :: rest => + accumulatedKeywords(next).append(value) + remainingArgSignatures = rest + } + case (rawKey, Some(value)) => + val key = stripDashes(rawKey) + lookupArgSig.get(key) match{ + case Some(x) if accumulatedKeywords.contains(x) => + if (accumulatedKeywords(x).nonEmpty && varargs){ + leftoverArgs.append(rawKey, value) + }else{ + accumulatedKeywords(x).append(value) + remainingArgSignatures = remainingArgSignatures.filter(_.name != key) + } + case _ => + leftoverArgs.append(rawKey, value) + } + } + } + + val missing0 = remainingArgSignatures + .filter(_.default.isEmpty) + + val missing = if(varargs) { + missing0.filter(_ != argSignatures.last) + } else { + missing0.filter(x => incomplete != Some(x)) + } + val duplicates = accumulatedKeywords.toSeq.filter(_._2.length > 1) + + if ( + incomplete.nonEmpty || + missing.nonEmpty || + duplicates.nonEmpty || + (leftoverArgs.nonEmpty && !varargs) + ){ + Result.Error.MismatchedArguments( + missing = missing, + unknown = leftoverArgs, + duplicate = duplicates, + incomplete = incomplete + + ) + } else { + val mapping = accumulatedKeywords + .iterator + .collect{case (k, Seq(single)) => (k.name, single)} + .toMap + + try invoke0(target, mapping, leftoverArgs) + catch{case e: Throwable => + Result.Error.Exception(e) + } + } + } + } + + def tryEither[T](t: => T, error: Throwable => Result.ParamError) = { + try Right(t) + catch{ case e: Throwable => Left(error(e))} + } + def readVarargs(arg: ArgSig[_, _], + values: Seq[String], + thunk: String => Any) = { + val attempts = + for(item <- values) + yield tryEither(thunk(item), Result.ParamError.Invalid(arg, item, _)) + + + val bad = attempts.collect{ case Left(x) => x} + if (bad.nonEmpty) Left(bad) + else Right(attempts.collect{case Right(x) => x}) + } + def read(dict: Map[String, String], + default: => Option[Any], + arg: ArgSig[_, _], + thunk: String => Any): FailMaybe = { + arg.reads.arity match{ + case 0 => + tryEither(thunk(null), Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_)) + case 1 => + dict.get(arg.name) match{ + case None => + tryEither(default.get, Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_)) + + case Some(x) => + tryEither(thunk(x), Result.ParamError.Invalid(arg, x, _)).left.map(Seq(_)) + } + } + + } + + /** + * Represents what comes out of an attempt to invoke an [[EntryPoint]]. + * Could succeed with a value, but could fail in many different ways. + */ + sealed trait Result[+T] + object Result{ + + /** + * Invoking the [[EntryPoint]] was totally successful, and returned a + * result + */ + case class Success[T](value: T) extends Result[T] + + /** + * Invoking the [[EntryPoint]] was not successful + */ + sealed trait Error extends Result[Nothing] + object Error{ + + /** + * Invoking the [[EntryPoint]] failed with an exception while executing + * code within it. + */ + case class Exception(t: Throwable) extends Error + + /** + * Invoking the [[EntryPoint]] failed because the arguments provided + * did not line up with the arguments expected + */ + case class MismatchedArguments(missing: Seq[ArgSig[_, _]], + unknown: Seq[String], + duplicate: Seq[(ArgSig[_, _], Seq[String])], + incomplete: Option[ArgSig[_, _]]) extends Error + /** + * Invoking the [[EntryPoint]] failed because there were problems + * deserializing/parsing individual arguments + */ + case class InvalidArguments(values: Seq[ParamError]) extends Error + } + + sealed trait ParamError + object ParamError{ + /** + * Something went wrong trying to de-serialize the input parameter; + * the thrown exception is stored in [[ex]] + */ + case class Invalid(arg: ArgSig[_, _], value: String, ex: Throwable) extends ParamError + /** + * Something went wrong trying to evaluate the default value + * for this input parameter + */ + case class DefaultFailed(arg: ArgSig[_, _], ex: Throwable) extends ParamError + } + } + + + type FailMaybe = Either[Seq[Result.ParamError], Any] + type FailAll = Either[Seq[Result.ParamError], Seq[Any]] + + def validate(args: Seq[FailMaybe]): Result[Seq[Any]] = { + val lefts = args.collect{case Left(x) => x}.flatten + + if (lefts.nonEmpty) Result.Error.InvalidArguments(lefts) + else { + val rights = args.collect{case Right(x) => x} + Result.Success(rights) + } + } + + def makeReadCall(dict: Map[String, String], + default: => Option[Any], + arg: ArgSig[_, _]) = { + read(dict, default, arg, arg.reads.reads(_)) + } + def makeReadVarargsCall(arg: ArgSig[_, _], values: Seq[String]) = { + readVarargs(arg, values, arg.reads.reads(_)) + } +} + + +class Router [C <: Context](val c: C) { + import c.universe._ + def getValsOrMeths(curCls: Type): Iterable[MethodSymbol] = { + def isAMemberOfAnyRef(member: Symbol) = { + // AnyRef is an alias symbol, we go to the real "owner" of these methods + val anyRefSym = c.mirror.universe.definitions.ObjectClass + member.owner == anyRefSym + } + val extractableMembers = for { + member <- curCls.members.toList.reverse + if !isAMemberOfAnyRef(member) + if !member.isSynthetic + if member.isPublic + if member.isTerm + memTerm = member.asTerm + if memTerm.isMethod + if !memTerm.isModule + } yield memTerm.asMethod + + extractableMembers flatMap { case memTerm => + if (memTerm.isSetter || memTerm.isConstructor || memTerm.isGetter) Nil + else Seq(memTerm) + + } + } + + + + def extractMethod(meth: MethodSymbol, curCls: c.universe.Type): c.universe.Tree = { + val baseArgSym = TermName(c.freshName()) + val flattenedArgLists = meth.paramss.flatten + def hasDefault(i: Int) = { + val defaultName = s"${meth.name}$$default$$${i + 1}" + if (curCls.members.exists(_.name.toString == defaultName)) Some(defaultName) + else None + } + val argListSymbol = q"${c.fresh[TermName]("argsList")}" + val extrasSymbol = q"${c.fresh[TermName]("extras")}" + val defaults = for ((arg, i) <- flattenedArgLists.zipWithIndex) yield { + val arg = TermName(c.freshName()) + hasDefault(i).map(defaultName => q"($arg: $curCls) => $arg.${newTermName(defaultName)}") + } + + def getDocAnnotation(annotations: List[Annotation]) = { + val (docTrees, remaining) = annotations.partition(_.tpe =:= typeOf[Router.doc]) + val docValues = for { + doc <- docTrees + if doc.scalaArgs.head.isInstanceOf[Literal] + l = doc.scalaArgs.head.asInstanceOf[Literal] + if l.value.value.isInstanceOf[String] + } yield l.value.value.asInstanceOf[String] + (remaining, docValues.headOption) + } + + def unwrapVarargType(arg: Symbol) = { + val vararg = arg.typeSignature.typeSymbol == definitions.RepeatedParamClass + val unwrappedType = + if (!vararg) arg.typeSignature + else arg.typeSignature.asInstanceOf[TypeRef].args(0) + + (vararg, unwrappedType) + } + + + val (_, methodDoc) = getDocAnnotation(meth.annotations) + val readArgSigs = for( + ((arg, defaultOpt), i) <- flattenedArgLists.zip(defaults).zipWithIndex + ) yield { + + val (vararg, varargUnwrappedType) = unwrapVarargType(arg) + + val default = + if (vararg) q"scala.Some(scala.Nil)" + else defaultOpt match { + case Some(defaultExpr) => q"scala.Some($defaultExpr($baseArgSym))" + case None => q"scala.None" + } + + val (docUnwrappedType, docOpt) = varargUnwrappedType match{ + case t: AnnotatedType => + import compat._ + val (remaining, docValue) = getDocAnnotation(t.annotations) + if (remaining.isEmpty) (t.underlying, docValue) + else (c.universe.AnnotatedType(remaining, t.underlying), docValue) + + case t => (t, None) + } + + val docTree = docOpt match{ + case None => q"scala.None" + case Some(s) => q"scala.Some($s)" + } + + val argSig = q""" + cask.Router.ArgSig[$curCls, $docUnwrappedType]( + ${arg.name.toString}, + ${docUnwrappedType.toString + (if(vararg) "*" else "")}, + $docTree, + $defaultOpt + ) + """ + + val reader = + if(vararg) q""" + cask.Router.makeReadVarargsCall( + $argSig, + $extrasSymbol + ) + """ else q""" + cask.Router.makeReadCall( + $argListSymbol, + $default, + $argSig + ) + """ + c.internal.setPos(reader, meth.pos) + (reader, argSig, vararg) + } + + val (readArgs, argSigs, varargs) = readArgSigs.unzip3 + val (argNames, argNameCasts) = flattenedArgLists.map { arg => + val (vararg, unwrappedType) = unwrapVarargType(arg) + ( + pq"${arg.name.toTermName}", + if (!vararg) q"${arg.name.toTermName}.asInstanceOf[$unwrappedType]" + else q"${arg.name.toTermName}.asInstanceOf[Seq[$unwrappedType]]: _*" + + ) + }.unzip + + + val res = q""" + cask.Router.EntryPoint[$curCls]( + ${meth.name.toString}, + scala.Seq(..$argSigs), + ${methodDoc match{ + case None => q"scala.None" + case Some(s) => q"scala.Some($s)" + }}, + ${varargs.contains(true)}, + ($baseArgSym: $curCls, $argListSymbol: Map[String, String], $extrasSymbol: Seq[String]) => + cask.Router.validate(Seq(..$readArgs)) match{ + case cask.Router.Result.Success(List(..$argNames)) => + cask.Router.Result.Success( + $baseArgSym.${meth.name.toTermName}(..$argNameCasts): cask.Response + ) + case x: cask.Router.Result.Error => x + }, + cask.Router.Overrides() + ) + """ + + c.internal.transform(res){(t, a) => + c.internal.setPos(t, meth.pos) + a.default(t) + } + res + } + + def hasMainAnnotation(t: MethodSymbol) = { + t.annotations.exists(_.tpe =:= typeOf[Router.main]) + } + def getAllRoutesForClass(curCls: Type, + pred: MethodSymbol => Boolean = hasMainAnnotation) + : Iterable[c.universe.Tree] = { + for{ + t <- getValsOrMeths(curCls) + if pred(t) + } yield { + extractMethod(t, curCls) + } + } +} \ No newline at end of file diff --git a/cask/src/cask/Util.scala b/cask/src/cask/Util.scala new file mode 100644 index 0000000..0856f4d --- /dev/null +++ b/cask/src/cask/Util.scala @@ -0,0 +1,25 @@ +package cask + +object Util { + def trimSplit(p: String) = p.dropWhile(_ == '/').reverse.dropWhile(_ == '/').reverse.split('/') + def matchRoute(route: String, path: String): Option[Map[String, String]] = { + val routeSegments = trimSplit(route) + val pathSegments = trimSplit(path) + + def rec(i: Int, bindings: Map[String, String]): Option[Map[String, String]] = { + if (routeSegments.length == i && pathSegments.length == i) Some(bindings) + else if ((routeSegments.length == i) != (pathSegments.length == i)) None + else { + val routeSeg = routeSegments(i) + val pathSeg = pathSegments(i) + if (routeSeg(0) == ':' && routeSeg(1) == ':') { + Some(bindings + (routeSeg.drop(2) -> pathSegments.drop(i).mkString("/"))) + } + else if (routeSeg(0) == ':') rec(i+1, bindings + (routeSeg.drop(1) -> pathSeg)) + else if (pathSeg == routeSeg) rec(i + 1, bindings) + else None + } + } + rec(0, Map.empty) + } +} diff --git a/cask/test/src/test/cask/CaskTest.scala b/cask/test/src/test/cask/CaskTest.scala index a1384f8..9a0876d 100644 --- a/cask/test/src/test/cask/CaskTest.scala +++ b/cask/test/src/test/cask/CaskTest.scala @@ -1,34 +1,24 @@ package test.cask +object MyServer extends cask.Routes{ -object MyServer extends cask.Server(){ - def x = "/ext" - @cask.route("/user/:username" + (x * 2)) + @cask.route("/user/:userName") def showUserProfile(userName: String) = { - // show the user profile for that user s"User $userName" } - @cask.route("/post/:int") - def showPost(postId: Int) = { - // show the post with the given id, the id is an integer - s"Post $postId" + @cask.route("/post/:postId") + def showPost(postId: Int, query: String) = { + s"Post $postId $query" } - @cask.route("/path/:subPath") - def show_subpath(subPath: String) = { - // show the subpath after /path/ + @cask.route("/path/::subPath") + def showSubpath(subPath: String) = { s"Subpath $subPath" } initialize() - println(routes.value) -} - -object Main extends cask.Main(MyServer){ - def main(args: Array[String]): Unit = { - MyServer - } } +object Main extends cask.Main(MyServer) -- cgit v1.2.3