summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLi Haoyi <haoyi.sg@gmail.com>2018-07-20 10:31:12 +0800
committerLi Haoyi <haoyi.sg@gmail.com>2018-07-20 10:43:11 +0800
commitbb7dbd6a1b188f07b512057264177972d0dec850 (patch)
tree3005e706218d8f1f40ede824fe9d9a0c0bd8464e
parent37bfff46d2d875d2dec27a24973ec8f6784f1bfb (diff)
downloadcask-bb7dbd6a1b188f07b512057264177972d0dec850.tar.gz
cask-bb7dbd6a1b188f07b512057264177972d0dec850.tar.bz2
cask-bb7dbd6a1b188f07b512057264177972d0dec850.zip
Simple routing & serving now works in the 3 test routes given
-rw-r--r--build.sc6
-rw-r--r--cask/src/cask/Cask.scala145
-rw-r--r--cask/src/cask/Router.scala447
-rw-r--r--cask/src/cask/Util.scala25
-rw-r--r--cask/test/src/test/cask/CaskTest.scala26
5 files changed, 613 insertions, 36 deletions
diff --git a/build.sc b/build.sc
index 2a3ca2c..663a9d0 100644
--- a/build.sc
+++ b/build.sc
@@ -2,7 +2,11 @@ import mill._, scalalib._
object cask extends ScalaModule{
def scalaVersion = "2.12.6"
- def ivyDeps = Agg(ivy"org.scala-lang:scala-reflect:$scalaVersion")
+ def ivyDeps = Agg(
+ ivy"org.scala-lang:scala-reflect:$scalaVersion",
+ ivy"io.undertow:undertow-core:2.0.11.Final",
+ ivy"com.github.scopt::scopt:3.5.0"
+ )
object test extends Tests{
def ivyDeps = Agg(ivy"com.lihaoyi::utest::0.6.3")
def testFrameworks = Seq("utest.runner.Framework")
diff --git a/cask/src/cask/Cask.scala b/cask/src/cask/Cask.scala
index c903cc2..3445594 100644
--- a/cask/src/cask/Cask.scala
+++ b/cask/src/cask/Cask.scala
@@ -1,40 +1,151 @@
package cask
+import cask.Router.EntryPoint
+
import language.experimental.macros
import scala.annotation.StaticAnnotation
import scala.reflect.macros.blackbox.Context
+import java.io.OutputStream
+import java.nio.ByteBuffer
+
+import io.undertow.Undertow
+import io.undertow.server.{HttpHandler, HttpServerExchange}
+import io.undertow.util.{Headers, HttpString}
class route(val path: String) extends StaticAnnotation
-class Main(x: Any*)
+class Main(servers: Routes*){
+ val port: Int = 8080
+ val host: String = "localhost"
+ def main(args: Array[String]): Unit = {
+ val allRoutes = for{
+ server <- servers
+ route <- server.caskMetadata.value.map(x => x: Routes.RouteMetadata[_])
+ } yield (server, route)
+
+ val server = Undertow.builder
+ .addHttpListener(port, host)
+ .setHandler(new HttpHandler() {
+ def handleRequest(exchange: HttpServerExchange): Unit = {
+ val routeOpt =
+ allRoutes
+ .iterator
+ .map { case (s: Routes, r: Routes.RouteMetadata[_]) =>
+ Util.matchRoute(r.metadata.path, exchange.getRequestPath).map((s, r, _))
+ }
+ .flatten
+ .toStream
+ .headOption
+
+
+ routeOpt match{
+ case None =>
+ exchange.setStatusCode(404)
+ exchange.getResponseHeaders.put(Headers.CONTENT_TYPE, "text/plain")
+ exchange.getResponseSender.send("404 Not Found")
+ case Some((server, route, bindings)) =>
+ import collection.JavaConverters._
+ val allBindings =
+ bindings ++
+ exchange.getQueryParameters
+ .asScala
+ .toSeq
+ .flatMap{case (k, vs) => vs.asScala.map((k, _))}
+ val result = route.entryPoint
+ .asInstanceOf[EntryPoint[server.type]]
+ .invoke(server, allBindings.mapValues(Some(_)).toSeq)
+ result match{
+ case Router.Result.Success(response: Response) =>
+ response.headers.foreach{case (k, v) =>
+ exchange.getResponseHeaders.put(new HttpString(k), v)
+ }
+
+ exchange.setStatusCode(response.statusCode)
+
+
+ response.data.write(
+ new OutputStream {
+ def write(b: Int) = {
+ exchange.getResponseSender.send(ByteBuffer.wrap(Array(b.toByte)))
+ }
+ override def write(b: Array[Byte]) = {
+ exchange.getResponseSender.send(ByteBuffer.wrap(b))
+ }
+ override def write(b: Array[Byte], off: Int, len: Int) = {
+ exchange.getResponseSender.send(ByteBuffer.wrap(b.slice(off, off + len)))
+ }
+ }
+ )
+ case err: Router.Result.Error =>
+ exchange.setStatusCode(400)
+ exchange.getResponseHeaders.put(Headers.CONTENT_TYPE, "text/plain")
+ exchange.getResponseSender.send("400 Not Found " + result)
+ }
+
+
+ }
+ }
+ })
+ .build
+ server.start()
+ }
+}
-object Server{
- case class Route(name: String, metadata: route)
- case class Routes[T](value: Route*)
- object Routes{
+case class Response(data: Response.Data,
+ statusCode: Int = 200,
+ headers: Seq[(String, String)] = Nil)
+object Response{
+ implicit def dataResponse[T](t: T)(implicit c: T => Data) = Response(t)
+ trait Data{
+ def write(out: OutputStream): Unit
+ }
+ object Data{
+ implicit class StringData(s: String) extends Data{
+ def write(out: OutputStream) = out.write(s.getBytes)
+ }
+ implicit class BytesData(b: Array[Byte]) extends Data{
+ def write(out: OutputStream) = out.write(b)
+ }
+ }
+}
+object Routes{
+ case class RouteMetadata[T](metadata: route, entryPoint: EntryPoint[T])
+ case class Metadata[T](value: RouteMetadata[T]*)
+ object Metadata{
implicit def initialize[T] = macro initializeImpl[T]
- implicit def initializeImpl[T: c.WeakTypeTag](c: Context): c.Expr[Routes[T]] = {
+ implicit def initializeImpl[T: c.WeakTypeTag](c: Context): c.Expr[Metadata[T]] = {
import c.universe._
-
+ val router = new cask.Router(c)
val routes = c.weakTypeOf[T].members
.map(m => (m, m.annotations.filter(_.tree.tpe =:= c.weakTypeOf[route])))
- .collect{case (m, Seq(a)) => (m, a)}
+ .collect{case (m, Seq(a)) =>
+ (
+ m,
+ a,
+ router.extractMethod(
+ m.asInstanceOf[router.c.universe.MethodSymbol],
+ weakTypeOf[T].asInstanceOf[router.c.universe.Type]
+ ).asInstanceOf[c.universe.Tree]
+ )
+ }
- val routeParts = for((m, a) <- routes) yield {
+ val routeParts = for((m, a, routeTree) <- routes) yield {
val annotation = q"new ${a.tree.tpe}(..${a.tree.children.tail})"
- q"cask.Server.Route(${m.name.toString}, $annotation)"
+ q"cask.Routes.RouteMetadata($annotation, $routeTree)"
}
- c.Expr[Routes[T]](q"""cask.Server.Routes(..$routeParts)""")
+
+
+ c.Expr[Metadata[T]](q"""cask.Routes.Metadata(..$routeParts)""")
}
}
}
-class Server[T](){
- private[this] var routes0: Server.Routes[this.type] = null
- def routes =
- if (routes0 != null) routes0
+class Routes{
+ private[this] var metadata0: Routes.Metadata[this.type] = null
+ def caskMetadata =
+ if (metadata0 != null) metadata0
else throw new Exception("Routes not yet initialize")
- protected[this] def initialize()(implicit routes: Server.Routes[this.type]): Unit = {
- routes0 = routes
+ protected[this] def initialize()(implicit routes: Routes.Metadata[this.type]): Unit = {
+ metadata0 = routes
}
}
diff --git a/cask/src/cask/Router.scala b/cask/src/cask/Router.scala
new file mode 100644
index 0000000..ffa9f0b
--- /dev/null
+++ b/cask/src/cask/Router.scala
@@ -0,0 +1,447 @@
+package cask
+
+
+import language.experimental.macros
+
+import scala.annotation.StaticAnnotation
+import scala.collection.mutable
+import scala.reflect.macros.blackbox.Context
+
+/**
+ * More or less a minimal version of Autowire's Server that lets you generate
+ * a set of "routes" from the methods defined in an object, and call them
+ * using passing in name/args/kwargs via Java reflection, without having to
+ * generate/compile code or use Scala reflection. This saves us spinning up
+ * the Scala compiler and greatly reduces the startup time of cached scripts.
+ */
+object Router{
+ /**
+ * Allows you to query how many things are overriden by the enclosing owner.
+ */
+ case class Overrides(value: Int)
+ object Overrides{
+ def apply()(implicit c: Overrides) = c.value
+ implicit def generate: Overrides = macro impl
+ def impl(c: Context): c.Tree = {
+ import c.universe._
+ q"new _root_.cask.Router.Overrides(${c.internal.enclosingOwner.overrides.length})"
+ }
+ }
+
+ class doc(s: String) extends StaticAnnotation
+ class main extends StaticAnnotation
+ def generateRoutes[T]: Seq[Router.EntryPoint[T]] = macro generateRoutesImpl[T]
+ def generateRoutesImpl[T: c.WeakTypeTag](c: Context): c.Expr[Seq[EntryPoint[T]]] = {
+ import c.universe._
+ val r = new Router(c)
+ val allRoutes = r.getAllRoutesForClass(
+ weakTypeOf[T].asInstanceOf[r.c.Type]
+ ).asInstanceOf[Iterable[c.Tree]]
+
+ c.Expr[Seq[EntryPoint[T]]](q"_root_.scala.Seq(..$allRoutes)")
+ }
+
+ /**
+ * Models what is known by the router about a single argument: that it has
+ * a [[name]], a human-readable [[typeString]] describing what the type is
+ * (just for logging and reading, not a replacement for a `TypeTag`) and
+ * possible a function that can compute its default value
+ */
+ case class ArgSig[T, V](name: String,
+ typeString: String,
+ doc: Option[String],
+ default: Option[T => V])
+ (implicit val reads: scopt.Read[V])
+
+ def stripDashes(s: String) = {
+ if (s.startsWith("--")) s.drop(2)
+ else if (s.startsWith("-")) s.drop(1)
+ else s
+ }
+ /**
+ * What is known about a single endpoint for our routes. It has a [[name]],
+ * [[argSignatures]] for each argument, and a macro-generated [[invoke0]]
+ * that performs all the necessary argument parsing and de-serialization.
+ *
+ * Realistically, you will probably spend most of your time calling [[invoke]]
+ * instead, which provides a nicer API to call it that mimmicks the API of
+ * calling a Scala method.
+ */
+ case class EntryPoint[T](name: String,
+ argSignatures: Seq[ArgSig[T, _]],
+ doc: Option[String],
+ varargs: Boolean,
+ invoke0: (T, Map[String, String], Seq[String]) => Result[Any],
+ overrides: Int){
+ def invoke(target: T, groupedArgs: Seq[(String, Option[String])]): Result[Any] = {
+ var remainingArgSignatures = argSignatures.toList.filter(_.reads.arity > 0)
+
+ val accumulatedKeywords = mutable.Map.empty[ArgSig[T, _], mutable.Buffer[String]]
+ val keywordableArgs = if (varargs) argSignatures.dropRight(1) else argSignatures
+
+ for(arg <- keywordableArgs) accumulatedKeywords(arg) = mutable.Buffer.empty
+
+ val leftoverArgs = mutable.Buffer.empty[String]
+
+ val lookupArgSig = Map(argSignatures.map(x => (x.name, x)):_*)
+
+ var incomplete: Option[ArgSig[T, _]] = None
+
+ for(group <- groupedArgs){
+
+ group match{
+ case (value, None) =>
+ if (value(0) == '-' && !varargs){
+ lookupArgSig.get(stripDashes(value)) match{
+ case None => leftoverArgs.append(value)
+ case Some(sig) => incomplete = Some(sig)
+ }
+
+ } else remainingArgSignatures match {
+ case Nil => leftoverArgs.append(value)
+ case last :: Nil if varargs => leftoverArgs.append(value)
+ case next :: rest =>
+ accumulatedKeywords(next).append(value)
+ remainingArgSignatures = rest
+ }
+ case (rawKey, Some(value)) =>
+ val key = stripDashes(rawKey)
+ lookupArgSig.get(key) match{
+ case Some(x) if accumulatedKeywords.contains(x) =>
+ if (accumulatedKeywords(x).nonEmpty && varargs){
+ leftoverArgs.append(rawKey, value)
+ }else{
+ accumulatedKeywords(x).append(value)
+ remainingArgSignatures = remainingArgSignatures.filter(_.name != key)
+ }
+ case _ =>
+ leftoverArgs.append(rawKey, value)
+ }
+ }
+ }
+
+ val missing0 = remainingArgSignatures
+ .filter(_.default.isEmpty)
+
+ val missing = if(varargs) {
+ missing0.filter(_ != argSignatures.last)
+ } else {
+ missing0.filter(x => incomplete != Some(x))
+ }
+ val duplicates = accumulatedKeywords.toSeq.filter(_._2.length > 1)
+
+ if (
+ incomplete.nonEmpty ||
+ missing.nonEmpty ||
+ duplicates.nonEmpty ||
+ (leftoverArgs.nonEmpty && !varargs)
+ ){
+ Result.Error.MismatchedArguments(
+ missing = missing,
+ unknown = leftoverArgs,
+ duplicate = duplicates,
+ incomplete = incomplete
+
+ )
+ } else {
+ val mapping = accumulatedKeywords
+ .iterator
+ .collect{case (k, Seq(single)) => (k.name, single)}
+ .toMap
+
+ try invoke0(target, mapping, leftoverArgs)
+ catch{case e: Throwable =>
+ Result.Error.Exception(e)
+ }
+ }
+ }
+ }
+
+ def tryEither[T](t: => T, error: Throwable => Result.ParamError) = {
+ try Right(t)
+ catch{ case e: Throwable => Left(error(e))}
+ }
+ def readVarargs(arg: ArgSig[_, _],
+ values: Seq[String],
+ thunk: String => Any) = {
+ val attempts =
+ for(item <- values)
+ yield tryEither(thunk(item), Result.ParamError.Invalid(arg, item, _))
+
+
+ val bad = attempts.collect{ case Left(x) => x}
+ if (bad.nonEmpty) Left(bad)
+ else Right(attempts.collect{case Right(x) => x})
+ }
+ def read(dict: Map[String, String],
+ default: => Option[Any],
+ arg: ArgSig[_, _],
+ thunk: String => Any): FailMaybe = {
+ arg.reads.arity match{
+ case 0 =>
+ tryEither(thunk(null), Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_))
+ case 1 =>
+ dict.get(arg.name) match{
+ case None =>
+ tryEither(default.get, Result.ParamError.DefaultFailed(arg, _)).left.map(Seq(_))
+
+ case Some(x) =>
+ tryEither(thunk(x), Result.ParamError.Invalid(arg, x, _)).left.map(Seq(_))
+ }
+ }
+
+ }
+
+ /**
+ * Represents what comes out of an attempt to invoke an [[EntryPoint]].
+ * Could succeed with a value, but could fail in many different ways.
+ */
+ sealed trait Result[+T]
+ object Result{
+
+ /**
+ * Invoking the [[EntryPoint]] was totally successful, and returned a
+ * result
+ */
+ case class Success[T](value: T) extends Result[T]
+
+ /**
+ * Invoking the [[EntryPoint]] was not successful
+ */
+ sealed trait Error extends Result[Nothing]
+ object Error{
+
+ /**
+ * Invoking the [[EntryPoint]] failed with an exception while executing
+ * code within it.
+ */
+ case class Exception(t: Throwable) extends Error
+
+ /**
+ * Invoking the [[EntryPoint]] failed because the arguments provided
+ * did not line up with the arguments expected
+ */
+ case class MismatchedArguments(missing: Seq[ArgSig[_, _]],
+ unknown: Seq[String],
+ duplicate: Seq[(ArgSig[_, _], Seq[String])],
+ incomplete: Option[ArgSig[_, _]]) extends Error
+ /**
+ * Invoking the [[EntryPoint]] failed because there were problems
+ * deserializing/parsing individual arguments
+ */
+ case class InvalidArguments(values: Seq[ParamError]) extends Error
+ }
+
+ sealed trait ParamError
+ object ParamError{
+ /**
+ * Something went wrong trying to de-serialize the input parameter;
+ * the thrown exception is stored in [[ex]]
+ */
+ case class Invalid(arg: ArgSig[_, _], value: String, ex: Throwable) extends ParamError
+ /**
+ * Something went wrong trying to evaluate the default value
+ * for this input parameter
+ */
+ case class DefaultFailed(arg: ArgSig[_, _], ex: Throwable) extends ParamError
+ }
+ }
+
+
+ type FailMaybe = Either[Seq[Result.ParamError], Any]
+ type FailAll = Either[Seq[Result.ParamError], Seq[Any]]
+
+ def validate(args: Seq[FailMaybe]): Result[Seq[Any]] = {
+ val lefts = args.collect{case Left(x) => x}.flatten
+
+ if (lefts.nonEmpty) Result.Error.InvalidArguments(lefts)
+ else {
+ val rights = args.collect{case Right(x) => x}
+ Result.Success(rights)
+ }
+ }
+
+ def makeReadCall(dict: Map[String, String],
+ default: => Option[Any],
+ arg: ArgSig[_, _]) = {
+ read(dict, default, arg, arg.reads.reads(_))
+ }
+ def makeReadVarargsCall(arg: ArgSig[_, _], values: Seq[String]) = {
+ readVarargs(arg, values, arg.reads.reads(_))
+ }
+}
+
+
+class Router [C <: Context](val c: C) {
+ import c.universe._
+ def getValsOrMeths(curCls: Type): Iterable[MethodSymbol] = {
+ def isAMemberOfAnyRef(member: Symbol) = {
+ // AnyRef is an alias symbol, we go to the real "owner" of these methods
+ val anyRefSym = c.mirror.universe.definitions.ObjectClass
+ member.owner == anyRefSym
+ }
+ val extractableMembers = for {
+ member <- curCls.members.toList.reverse
+ if !isAMemberOfAnyRef(member)
+ if !member.isSynthetic
+ if member.isPublic
+ if member.isTerm
+ memTerm = member.asTerm
+ if memTerm.isMethod
+ if !memTerm.isModule
+ } yield memTerm.asMethod
+
+ extractableMembers flatMap { case memTerm =>
+ if (memTerm.isSetter || memTerm.isConstructor || memTerm.isGetter) Nil
+ else Seq(memTerm)
+
+ }
+ }
+
+
+
+ def extractMethod(meth: MethodSymbol, curCls: c.universe.Type): c.universe.Tree = {
+ val baseArgSym = TermName(c.freshName())
+ val flattenedArgLists = meth.paramss.flatten
+ def hasDefault(i: Int) = {
+ val defaultName = s"${meth.name}$$default$$${i + 1}"
+ if (curCls.members.exists(_.name.toString == defaultName)) Some(defaultName)
+ else None
+ }
+ val argListSymbol = q"${c.fresh[TermName]("argsList")}"
+ val extrasSymbol = q"${c.fresh[TermName]("extras")}"
+ val defaults = for ((arg, i) <- flattenedArgLists.zipWithIndex) yield {
+ val arg = TermName(c.freshName())
+ hasDefault(i).map(defaultName => q"($arg: $curCls) => $arg.${newTermName(defaultName)}")
+ }
+
+ def getDocAnnotation(annotations: List[Annotation]) = {
+ val (docTrees, remaining) = annotations.partition(_.tpe =:= typeOf[Router.doc])
+ val docValues = for {
+ doc <- docTrees
+ if doc.scalaArgs.head.isInstanceOf[Literal]
+ l = doc.scalaArgs.head.asInstanceOf[Literal]
+ if l.value.value.isInstanceOf[String]
+ } yield l.value.value.asInstanceOf[String]
+ (remaining, docValues.headOption)
+ }
+
+ def unwrapVarargType(arg: Symbol) = {
+ val vararg = arg.typeSignature.typeSymbol == definitions.RepeatedParamClass
+ val unwrappedType =
+ if (!vararg) arg.typeSignature
+ else arg.typeSignature.asInstanceOf[TypeRef].args(0)
+
+ (vararg, unwrappedType)
+ }
+
+
+ val (_, methodDoc) = getDocAnnotation(meth.annotations)
+ val readArgSigs = for(
+ ((arg, defaultOpt), i) <- flattenedArgLists.zip(defaults).zipWithIndex
+ ) yield {
+
+ val (vararg, varargUnwrappedType) = unwrapVarargType(arg)
+
+ val default =
+ if (vararg) q"scala.Some(scala.Nil)"
+ else defaultOpt match {
+ case Some(defaultExpr) => q"scala.Some($defaultExpr($baseArgSym))"
+ case None => q"scala.None"
+ }
+
+ val (docUnwrappedType, docOpt) = varargUnwrappedType match{
+ case t: AnnotatedType =>
+ import compat._
+ val (remaining, docValue) = getDocAnnotation(t.annotations)
+ if (remaining.isEmpty) (t.underlying, docValue)
+ else (c.universe.AnnotatedType(remaining, t.underlying), docValue)
+
+ case t => (t, None)
+ }
+
+ val docTree = docOpt match{
+ case None => q"scala.None"
+ case Some(s) => q"scala.Some($s)"
+ }
+
+ val argSig = q"""
+ cask.Router.ArgSig[$curCls, $docUnwrappedType](
+ ${arg.name.toString},
+ ${docUnwrappedType.toString + (if(vararg) "*" else "")},
+ $docTree,
+ $defaultOpt
+ )
+ """
+
+ val reader =
+ if(vararg) q"""
+ cask.Router.makeReadVarargsCall(
+ $argSig,
+ $extrasSymbol
+ )
+ """ else q"""
+ cask.Router.makeReadCall(
+ $argListSymbol,
+ $default,
+ $argSig
+ )
+ """
+ c.internal.setPos(reader, meth.pos)
+ (reader, argSig, vararg)
+ }
+
+ val (readArgs, argSigs, varargs) = readArgSigs.unzip3
+ val (argNames, argNameCasts) = flattenedArgLists.map { arg =>
+ val (vararg, unwrappedType) = unwrapVarargType(arg)
+ (
+ pq"${arg.name.toTermName}",
+ if (!vararg) q"${arg.name.toTermName}.asInstanceOf[$unwrappedType]"
+ else q"${arg.name.toTermName}.asInstanceOf[Seq[$unwrappedType]]: _*"
+
+ )
+ }.unzip
+
+
+ val res = q"""
+ cask.Router.EntryPoint[$curCls](
+ ${meth.name.toString},
+ scala.Seq(..$argSigs),
+ ${methodDoc match{
+ case None => q"scala.None"
+ case Some(s) => q"scala.Some($s)"
+ }},
+ ${varargs.contains(true)},
+ ($baseArgSym: $curCls, $argListSymbol: Map[String, String], $extrasSymbol: Seq[String]) =>
+ cask.Router.validate(Seq(..$readArgs)) match{
+ case cask.Router.Result.Success(List(..$argNames)) =>
+ cask.Router.Result.Success(
+ $baseArgSym.${meth.name.toTermName}(..$argNameCasts): cask.Response
+ )
+ case x: cask.Router.Result.Error => x
+ },
+ cask.Router.Overrides()
+ )
+ """
+
+ c.internal.transform(res){(t, a) =>
+ c.internal.setPos(t, meth.pos)
+ a.default(t)
+ }
+ res
+ }
+
+ def hasMainAnnotation(t: MethodSymbol) = {
+ t.annotations.exists(_.tpe =:= typeOf[Router.main])
+ }
+ def getAllRoutesForClass(curCls: Type,
+ pred: MethodSymbol => Boolean = hasMainAnnotation)
+ : Iterable[c.universe.Tree] = {
+ for{
+ t <- getValsOrMeths(curCls)
+ if pred(t)
+ } yield {
+ extractMethod(t, curCls)
+ }
+ }
+} \ No newline at end of file
diff --git a/cask/src/cask/Util.scala b/cask/src/cask/Util.scala
new file mode 100644
index 0000000..0856f4d
--- /dev/null
+++ b/cask/src/cask/Util.scala
@@ -0,0 +1,25 @@
+package cask
+
+object Util {
+ def trimSplit(p: String) = p.dropWhile(_ == '/').reverse.dropWhile(_ == '/').reverse.split('/')
+ def matchRoute(route: String, path: String): Option[Map[String, String]] = {
+ val routeSegments = trimSplit(route)
+ val pathSegments = trimSplit(path)
+
+ def rec(i: Int, bindings: Map[String, String]): Option[Map[String, String]] = {
+ if (routeSegments.length == i && pathSegments.length == i) Some(bindings)
+ else if ((routeSegments.length == i) != (pathSegments.length == i)) None
+ else {
+ val routeSeg = routeSegments(i)
+ val pathSeg = pathSegments(i)
+ if (routeSeg(0) == ':' && routeSeg(1) == ':') {
+ Some(bindings + (routeSeg.drop(2) -> pathSegments.drop(i).mkString("/")))
+ }
+ else if (routeSeg(0) == ':') rec(i+1, bindings + (routeSeg.drop(1) -> pathSeg))
+ else if (pathSeg == routeSeg) rec(i + 1, bindings)
+ else None
+ }
+ }
+ rec(0, Map.empty)
+ }
+}
diff --git a/cask/test/src/test/cask/CaskTest.scala b/cask/test/src/test/cask/CaskTest.scala
index a1384f8..9a0876d 100644
--- a/cask/test/src/test/cask/CaskTest.scala
+++ b/cask/test/src/test/cask/CaskTest.scala
@@ -1,34 +1,24 @@
package test.cask
+object MyServer extends cask.Routes{
-object MyServer extends cask.Server(){
- def x = "/ext"
- @cask.route("/user/:username" + (x * 2))
+ @cask.route("/user/:userName")
def showUserProfile(userName: String) = {
- // show the user profile for that user
s"User $userName"
}
- @cask.route("/post/:int")
- def showPost(postId: Int) = {
- // show the post with the given id, the id is an integer
- s"Post $postId"
+ @cask.route("/post/:postId")
+ def showPost(postId: Int, query: String) = {
+ s"Post $postId $query"
}
- @cask.route("/path/:subPath")
- def show_subpath(subPath: String) = {
- // show the subpath after /path/
+ @cask.route("/path/::subPath")
+ def showSubpath(subPath: String) = {
s"Subpath $subPath"
}
initialize()
- println(routes.value)
-}
-
-object Main extends cask.Main(MyServer){
- def main(args: Array[String]): Unit = {
- MyServer
- }
}
+object Main extends cask.Main(MyServer)