summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLi Haoyi <haoyi.sg@gmail.com>2018-07-21 18:55:23 +0800
committerLi Haoyi <haoyi.sg@gmail.com>2018-07-21 18:55:23 +0800
commit65f27110fa611c34fe8867dbe69ba608da23f592 (patch)
treed50e3fe3711ddda8e074e7758e709ca7a1e37e92
parentc1dcd6b5794bd1ceb3d92edc8a7f730c93098ef3 (diff)
downloadcask-65f27110fa611c34fe8867dbe69ba608da23f592.tar.gz
cask-65f27110fa611c34fe8867dbe69ba608da23f592.tar.bz2
cask-65f27110fa611c34fe8867dbe69ba608da23f592.zip
Route requests paths using a proper trie
-rw-r--r--cask/src/cask/DispatchTrie.scala64
-rw-r--r--cask/src/cask/Main.scala42
-rw-r--r--cask/src/cask/Routes.scala2
-rw-r--r--cask/src/cask/Util.scala22
-rw-r--r--cask/test/src/test/cask/CaskTest.scala113
5 files changed, 178 insertions, 65 deletions
diff --git a/cask/src/cask/DispatchTrie.scala b/cask/src/cask/DispatchTrie.scala
new file mode 100644
index 0000000..034bb55
--- /dev/null
+++ b/cask/src/cask/DispatchTrie.scala
@@ -0,0 +1,64 @@
+package cask
+import collection.mutable
+object DispatchTrie{
+ def construct[T](index: Int, inputs: Seq[(IndexedSeq[String], T)]): DispatchTrie[T] = {
+ val continuations = mutable.Map.empty[String, mutable.Buffer[(IndexedSeq[String], T)]]
+
+ val terminals = mutable.Buffer.empty[(IndexedSeq[String], T)]
+
+ for((path, endPoint) <- inputs) {
+ if (path.length < index) () // do nothing
+ else if (path.length == index) {
+ terminals.append(path -> endPoint)
+ } else if (path.length > index){
+ val buf = continuations.getOrElseUpdate(path(index), mutable.Buffer.empty)
+ buf.append(path -> endPoint)
+ }
+ }
+
+ val wildcards = continuations.filter(_._1(0) == ':')
+ if (terminals.length > 1){
+ throw new Exception(
+ "More than one endpoint has the same path: " +
+ terminals.map(_._1.map(_.mkString("/"))).mkString(", ")
+ )
+ } else if(wildcards.size >= 1 && continuations.size > 1) {
+ throw new Exception(
+ "Routes overlap with wildcards: " +
+ (wildcards ++ continuations).flatMap(_._2).map(_._1.mkString("/"))
+ )
+ }else{
+ DispatchTrie[T](
+ current = terminals.headOption.map(_._2),
+ children = continuations.map{ case (k, vs) =>
+ if (!k.startsWith("::")) (k, construct(index + 1, vs))
+ else (k, DispatchTrie(Some(vs.head._2), Map()))
+ }.toMap
+ )
+ }
+ }
+}
+
+case class DispatchTrie[T](current: Option[T],
+ children: Map[String, DispatchTrie[T]]){
+ final def lookup(input: List[String],
+ bindings: Map[String, String])
+ : Option[(T, Map[String, String])] = {
+ input match{
+ case Nil => current.map(_ -> bindings)
+ case head :: rest =>
+ if (children.size == 1 && children.keys.head.startsWith("::")){
+ children.values.head.lookup(Nil, bindings + (children.keys.head.drop(2) -> input.mkString("/")))
+ }else if (children.size == 1 && children.keys.head.startsWith(":")){
+ children.values.head.lookup(rest, bindings + (children.keys.head.drop(1) -> head))
+ }else{
+ children.get(head) match{
+ case None => None
+ case Some(continuation) => continuation.lookup(rest, bindings)
+ }
+ }
+
+ }
+ }
+
+}
diff --git a/cask/src/cask/Main.scala b/cask/src/cask/Main.scala
index 1d86de1..acf7a0d 100644
--- a/cask/src/cask/Main.scala
+++ b/cask/src/cask/Main.scala
@@ -9,41 +9,35 @@ import io.undertow.server.{HttpHandler, HttpServerExchange}
import io.undertow.util.{Headers, HttpString}
class MainRoutes extends BaseMain with Routes{
- def servers = Seq(this)
+ def allRoutes = Seq(this)
}
class Main(servers0: Routes*) extends BaseMain{
- def servers = servers0.toSeq
+ def allRoutes = servers0.toSeq
}
abstract class BaseMain{
- def servers: Seq[Routes]
+ def allRoutes: Seq[Routes]
val port: Int = 8080
val host: String = "localhost"
- val allRoutes = for{
- server <- servers
- route <- server.caskMetadata.value.map(x => x: Routes.RouteMetadata[_])
- } yield (server, route)
+ lazy val routeList = for{
+ routes <- allRoutes
+ route <- routes.caskMetadata.value.map(x => x: Routes.RouteMetadata[_])
+ } yield (routes, route)
- val defaultHandler = new HttpHandler() {
- def handleRequest(exchange: HttpServerExchange): Unit = {
- val routeOpt =
- allRoutes
- .iterator
- .map { case (s: Routes, r: Routes.RouteMetadata[_]) =>
- Util.matchRoute(r.metadata.path, exchange.getRequestPath).map((s, r, _))
- }
- .flatten
- .toStream
- .headOption
+ lazy val routeTrie = DispatchTrie.construct[(Routes, Router.EntryPoint[_, HttpServerExchange])](0,
+ for((route, metadata) <- routeList)
+ yield (Util.splitPath(metadata.metadata.path): IndexedSeq[String], (route, metadata.entryPoint))
+ )
-
- routeOpt match{
+ lazy val defaultHandler = new HttpHandler() {
+ def handleRequest(exchange: HttpServerExchange): Unit = {
+ routeTrie.lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match{
case None =>
exchange.setStatusCode(404)
exchange.getResponseHeaders.put(Headers.CONTENT_TYPE, "text/plain")
exchange.getResponseSender.send("404 Not Found")
- case Some((server, route, bindings)) =>
+ case Some(((routes, entrypoint), bindings)) =>
import collection.JavaConverters._
val allBindings =
bindings.toSeq ++
@@ -52,9 +46,9 @@ abstract class BaseMain{
.toSeq
.flatMap{case (k, vs) => vs.asScala.map((k, _))}
- val result = route.entryPoint
- .asInstanceOf[EntryPoint[server.type, HttpServerExchange]]
- .invoke(server, exchange, allBindings.map{case (k, v) => (k, Some(v))})
+ val result = entrypoint
+ .asInstanceOf[EntryPoint[routes.type, HttpServerExchange]]
+ .invoke(routes, exchange, allBindings.map{case (k, v) => (k, Some(v))})
result match{
case Router.Result.Success(response: Response) =>
diff --git a/cask/src/cask/Routes.scala b/cask/src/cask/Routes.scala
index 0e88a5e..5ce8f6c 100644
--- a/cask/src/cask/Routes.scala
+++ b/cask/src/cask/Routes.scala
@@ -38,7 +38,7 @@ object Routes{
val routeParts = for{
m <- c.weakTypeOf[T].members
- annot <- m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[RouteBase])
+ annot <- m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[AnnotationBase])
} yield {
val annotObject = q"new ${annot.tree.tpe}(..${annot.tree.children.tail})"
val annotObjectSym = c.universe.TermName(c.freshName("annotObject"))
diff --git a/cask/src/cask/Util.scala b/cask/src/cask/Util.scala
index 89adfa1..0f88883 100644
--- a/cask/src/cask/Util.scala
+++ b/cask/src/cask/Util.scala
@@ -1,27 +1,7 @@
package cask
object Util {
- def trimSplit(p: String) =
+ def splitPath(p: String) =
p.dropWhile(_ == '/').reverse.dropWhile(_ == '/').reverse.split('/').filter(_.nonEmpty)
- def matchRoute(route: String, path: String): Option[Map[String, String]] = {
- val routeSegments = trimSplit(route)
- val pathSegments = trimSplit(path)
-
- def rec(i: Int, bindings: Map[String, String]): Option[Map[String, String]] = {
- if (routeSegments.length == i && pathSegments.length == i) Some(bindings)
- else if ((routeSegments.length == i) != (pathSegments.length == i)) None
- else {
- val routeSeg = routeSegments(i)
- val pathSeg = pathSegments(i)
- if (routeSeg(0) == ':' && routeSeg(1) == ':') {
- Some(bindings + (routeSeg.drop(2) -> pathSegments.drop(i).mkString("/")))
- }
- else if (routeSeg(0) == ':') rec(i+1, bindings + (routeSeg.drop(1) -> pathSeg))
- else if (pathSeg == routeSeg) rec(i + 1, bindings)
- else None
- }
- }
- rec(0, Map.empty)
- }
}
diff --git a/cask/test/src/test/cask/CaskTest.scala b/cask/test/src/test/cask/CaskTest.scala
index 5330dce..7bbb8de 100644
--- a/cask/test/src/test/cask/CaskTest.scala
+++ b/cask/test/src/test/cask/CaskTest.scala
@@ -1,26 +1,101 @@
package test.cask
+import cask.DispatchTrie
+import utest._
-import io.undertow.server.HttpServerExchange
+object CaskTest extends TestSuite {
+ val tests = Tests{
-object MyServer extends cask.Routes{
- @cask.get("/user/:userName")
- def showUserProfile(userName: String) = {
- s"User $userName"
- }
+ 'hello - {
+ val x = DispatchTrie.construct(0,
+ Seq(Vector("hello") -> 1)
+ )
- @cask.get("/post/:postId")
- def showPost(postId: Int, query: Seq[String]) = {
- s"Post $postId $query"
- }
+ assert(
+ x.lookup(List("hello"), Map()) == Some((1, Map())),
+ x.lookup(List("hello", "world"), Map()) == None,
+ x.lookup(List("world"), Map()) == None
+ )
+ }
+ 'nested - {
+ val x = DispatchTrie.construct(0,
+ Seq(
+ Vector("hello", "world") -> 1,
+ Vector("hello", "cow") -> 2
+ )
+ )
+ assert(
+ x.lookup(List("hello", "world"), Map()) == Some((1, Map())),
+ x.lookup(List("hello", "cow"), Map()) == Some((2, Map())),
+ x.lookup(List("hello"), Map()) == None,
+ x.lookup(List("hello", "moo"), Map()) == None,
+ x.lookup(List("hello", "world", "moo"), Map()) == None
+ )
+ }
+ 'bindings - {
+ val x = DispatchTrie.construct(0,
+ Seq(Vector(":hello", ":world") -> 1)
+ )
+ assert(
+ x.lookup(List("hello", "world"), Map()) == Some((1, Map("hello" -> "hello", "world" -> "world"))),
+ x.lookup(List("world", "hello"), Map()) == Some((1, Map("hello" -> "world", "world" -> "hello"))),
- @cask.get("/path/::subPath")
- def showSubpath(x: HttpServerExchange, subPath: String) = {
- val length = x.getInputStream.readAllBytes().length
- println(x)
- s"Subpath $subPath + $length"
- }
+ x.lookup(List("hello", "world", "cow"), Map()) == None,
+ x.lookup(List("hello"), Map()) == None
+ )
+ }
- initialize()
-}
+ 'path - {
+ val x = DispatchTrie.construct(0,
+ Seq(Vector("hello", "::world") -> 1)
+ )
+ assert(
+ x.lookup(List("hello", "world"), Map()) == Some((1,Map("world" -> "world"))),
+ x.lookup(List("hello", "world", "cow"), Map()) == Some((1,Map("world" -> "world/cow"))),
+ x.lookup(List("hello"), Map()) == None
+ )
+ }
-object Main extends cask.Main(MyServer)
+ 'errors - {
+ intercept[Exception]{
+ DispatchTrie.construct(0,
+ Seq(
+ Vector("hello", ":world") -> 1,
+ Vector("hello", "world") -> 2
+ )
+ )
+ }
+ intercept[Exception]{
+ DispatchTrie.construct(0,
+ Seq(
+ Vector("hello", ":world") -> 1,
+ Vector("hello", "world", "omg") -> 2
+ )
+ )
+ }
+ intercept[Exception]{
+ DispatchTrie.construct(0,
+ Seq(
+ Vector("hello", "::world") -> 1,
+ Vector("hello", "cow", "omg") -> 2
+ )
+ )
+ }
+ intercept[Exception]{
+ DispatchTrie.construct(0,
+ Seq(
+ Vector("hello", ":world") -> 1,
+ Vector("hello", ":cow") -> 2
+ )
+ )
+ }
+ intercept[Exception]{
+ DispatchTrie.construct(0,
+ Seq(
+ Vector("hello", "world") -> 1,
+ Vector("hello", "world") -> 2
+ )
+ )
+ }
+ }
+ }
+} \ No newline at end of file