From 65f27110fa611c34fe8867dbe69ba608da23f592 Mon Sep 17 00:00:00 2001 From: Li Haoyi Date: Sat, 21 Jul 2018 18:55:23 +0800 Subject: Route requests paths using a proper trie --- cask/src/cask/DispatchTrie.scala | 64 +++++++++++++++++++ cask/src/cask/Main.scala | 42 ++++++------ cask/src/cask/Routes.scala | 2 +- cask/src/cask/Util.scala | 22 +------ cask/test/src/test/cask/CaskTest.scala | 113 +++++++++++++++++++++++++++------ 5 files changed, 178 insertions(+), 65 deletions(-) create mode 100644 cask/src/cask/DispatchTrie.scala diff --git a/cask/src/cask/DispatchTrie.scala b/cask/src/cask/DispatchTrie.scala new file mode 100644 index 0000000..034bb55 --- /dev/null +++ b/cask/src/cask/DispatchTrie.scala @@ -0,0 +1,64 @@ +package cask +import collection.mutable +object DispatchTrie{ + def construct[T](index: Int, inputs: Seq[(IndexedSeq[String], T)]): DispatchTrie[T] = { + val continuations = mutable.Map.empty[String, mutable.Buffer[(IndexedSeq[String], T)]] + + val terminals = mutable.Buffer.empty[(IndexedSeq[String], T)] + + for((path, endPoint) <- inputs) { + if (path.length < index) () // do nothing + else if (path.length == index) { + terminals.append(path -> endPoint) + } else if (path.length > index){ + val buf = continuations.getOrElseUpdate(path(index), mutable.Buffer.empty) + buf.append(path -> endPoint) + } + } + + val wildcards = continuations.filter(_._1(0) == ':') + if (terminals.length > 1){ + throw new Exception( + "More than one endpoint has the same path: " + + terminals.map(_._1.map(_.mkString("/"))).mkString(", ") + ) + } else if(wildcards.size >= 1 && continuations.size > 1) { + throw new Exception( + "Routes overlap with wildcards: " + + (wildcards ++ continuations).flatMap(_._2).map(_._1.mkString("/")) + ) + }else{ + DispatchTrie[T]( + current = terminals.headOption.map(_._2), + children = continuations.map{ case (k, vs) => + if (!k.startsWith("::")) (k, construct(index + 1, vs)) + else (k, DispatchTrie(Some(vs.head._2), Map())) + }.toMap + ) + } + } +} + +case class DispatchTrie[T](current: Option[T], + children: Map[String, DispatchTrie[T]]){ + final def lookup(input: List[String], + bindings: Map[String, String]) + : Option[(T, Map[String, String])] = { + input match{ + case Nil => current.map(_ -> bindings) + case head :: rest => + if (children.size == 1 && children.keys.head.startsWith("::")){ + children.values.head.lookup(Nil, bindings + (children.keys.head.drop(2) -> input.mkString("/"))) + }else if (children.size == 1 && children.keys.head.startsWith(":")){ + children.values.head.lookup(rest, bindings + (children.keys.head.drop(1) -> head)) + }else{ + children.get(head) match{ + case None => None + case Some(continuation) => continuation.lookup(rest, bindings) + } + } + + } + } + +} diff --git a/cask/src/cask/Main.scala b/cask/src/cask/Main.scala index 1d86de1..acf7a0d 100644 --- a/cask/src/cask/Main.scala +++ b/cask/src/cask/Main.scala @@ -9,41 +9,35 @@ import io.undertow.server.{HttpHandler, HttpServerExchange} import io.undertow.util.{Headers, HttpString} class MainRoutes extends BaseMain with Routes{ - def servers = Seq(this) + def allRoutes = Seq(this) } class Main(servers0: Routes*) extends BaseMain{ - def servers = servers0.toSeq + def allRoutes = servers0.toSeq } abstract class BaseMain{ - def servers: Seq[Routes] + def allRoutes: Seq[Routes] val port: Int = 8080 val host: String = "localhost" - val allRoutes = for{ - server <- servers - route <- server.caskMetadata.value.map(x => x: Routes.RouteMetadata[_]) - } yield (server, route) + lazy val routeList = for{ + routes <- allRoutes + route <- routes.caskMetadata.value.map(x => x: Routes.RouteMetadata[_]) + } yield (routes, route) - val defaultHandler = new HttpHandler() { - def handleRequest(exchange: HttpServerExchange): Unit = { - val routeOpt = - allRoutes - .iterator - .map { case (s: Routes, r: Routes.RouteMetadata[_]) => - Util.matchRoute(r.metadata.path, exchange.getRequestPath).map((s, r, _)) - } - .flatten - .toStream - .headOption + lazy val routeTrie = DispatchTrie.construct[(Routes, Router.EntryPoint[_, HttpServerExchange])](0, + for((route, metadata) <- routeList) + yield (Util.splitPath(metadata.metadata.path): IndexedSeq[String], (route, metadata.entryPoint)) + ) - - routeOpt match{ + lazy val defaultHandler = new HttpHandler() { + def handleRequest(exchange: HttpServerExchange): Unit = { + routeTrie.lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match{ case None => exchange.setStatusCode(404) exchange.getResponseHeaders.put(Headers.CONTENT_TYPE, "text/plain") exchange.getResponseSender.send("404 Not Found") - case Some((server, route, bindings)) => + case Some(((routes, entrypoint), bindings)) => import collection.JavaConverters._ val allBindings = bindings.toSeq ++ @@ -52,9 +46,9 @@ abstract class BaseMain{ .toSeq .flatMap{case (k, vs) => vs.asScala.map((k, _))} - val result = route.entryPoint - .asInstanceOf[EntryPoint[server.type, HttpServerExchange]] - .invoke(server, exchange, allBindings.map{case (k, v) => (k, Some(v))}) + val result = entrypoint + .asInstanceOf[EntryPoint[routes.type, HttpServerExchange]] + .invoke(routes, exchange, allBindings.map{case (k, v) => (k, Some(v))}) result match{ case Router.Result.Success(response: Response) => diff --git a/cask/src/cask/Routes.scala b/cask/src/cask/Routes.scala index 0e88a5e..5ce8f6c 100644 --- a/cask/src/cask/Routes.scala +++ b/cask/src/cask/Routes.scala @@ -38,7 +38,7 @@ object Routes{ val routeParts = for{ m <- c.weakTypeOf[T].members - annot <- m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[RouteBase]) + annot <- m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[AnnotationBase]) } yield { val annotObject = q"new ${annot.tree.tpe}(..${annot.tree.children.tail})" val annotObjectSym = c.universe.TermName(c.freshName("annotObject")) diff --git a/cask/src/cask/Util.scala b/cask/src/cask/Util.scala index 89adfa1..0f88883 100644 --- a/cask/src/cask/Util.scala +++ b/cask/src/cask/Util.scala @@ -1,27 +1,7 @@ package cask object Util { - def trimSplit(p: String) = + def splitPath(p: String) = p.dropWhile(_ == '/').reverse.dropWhile(_ == '/').reverse.split('/').filter(_.nonEmpty) - def matchRoute(route: String, path: String): Option[Map[String, String]] = { - val routeSegments = trimSplit(route) - val pathSegments = trimSplit(path) - - def rec(i: Int, bindings: Map[String, String]): Option[Map[String, String]] = { - if (routeSegments.length == i && pathSegments.length == i) Some(bindings) - else if ((routeSegments.length == i) != (pathSegments.length == i)) None - else { - val routeSeg = routeSegments(i) - val pathSeg = pathSegments(i) - if (routeSeg(0) == ':' && routeSeg(1) == ':') { - Some(bindings + (routeSeg.drop(2) -> pathSegments.drop(i).mkString("/"))) - } - else if (routeSeg(0) == ':') rec(i+1, bindings + (routeSeg.drop(1) -> pathSeg)) - else if (pathSeg == routeSeg) rec(i + 1, bindings) - else None - } - } - rec(0, Map.empty) - } } diff --git a/cask/test/src/test/cask/CaskTest.scala b/cask/test/src/test/cask/CaskTest.scala index 5330dce..7bbb8de 100644 --- a/cask/test/src/test/cask/CaskTest.scala +++ b/cask/test/src/test/cask/CaskTest.scala @@ -1,26 +1,101 @@ package test.cask +import cask.DispatchTrie +import utest._ -import io.undertow.server.HttpServerExchange +object CaskTest extends TestSuite { + val tests = Tests{ -object MyServer extends cask.Routes{ - @cask.get("/user/:userName") - def showUserProfile(userName: String) = { - s"User $userName" - } + 'hello - { + val x = DispatchTrie.construct(0, + Seq(Vector("hello") -> 1) + ) - @cask.get("/post/:postId") - def showPost(postId: Int, query: Seq[String]) = { - s"Post $postId $query" - } + assert( + x.lookup(List("hello"), Map()) == Some((1, Map())), + x.lookup(List("hello", "world"), Map()) == None, + x.lookup(List("world"), Map()) == None + ) + } + 'nested - { + val x = DispatchTrie.construct(0, + Seq( + Vector("hello", "world") -> 1, + Vector("hello", "cow") -> 2 + ) + ) + assert( + x.lookup(List("hello", "world"), Map()) == Some((1, Map())), + x.lookup(List("hello", "cow"), Map()) == Some((2, Map())), + x.lookup(List("hello"), Map()) == None, + x.lookup(List("hello", "moo"), Map()) == None, + x.lookup(List("hello", "world", "moo"), Map()) == None + ) + } + 'bindings - { + val x = DispatchTrie.construct(0, + Seq(Vector(":hello", ":world") -> 1) + ) + assert( + x.lookup(List("hello", "world"), Map()) == Some((1, Map("hello" -> "hello", "world" -> "world"))), + x.lookup(List("world", "hello"), Map()) == Some((1, Map("hello" -> "world", "world" -> "hello"))), - @cask.get("/path/::subPath") - def showSubpath(x: HttpServerExchange, subPath: String) = { - val length = x.getInputStream.readAllBytes().length - println(x) - s"Subpath $subPath + $length" - } + x.lookup(List("hello", "world", "cow"), Map()) == None, + x.lookup(List("hello"), Map()) == None + ) + } - initialize() -} + 'path - { + val x = DispatchTrie.construct(0, + Seq(Vector("hello", "::world") -> 1) + ) + assert( + x.lookup(List("hello", "world"), Map()) == Some((1,Map("world" -> "world"))), + x.lookup(List("hello", "world", "cow"), Map()) == Some((1,Map("world" -> "world/cow"))), + x.lookup(List("hello"), Map()) == None + ) + } -object Main extends cask.Main(MyServer) + 'errors - { + intercept[Exception]{ + DispatchTrie.construct(0, + Seq( + Vector("hello", ":world") -> 1, + Vector("hello", "world") -> 2 + ) + ) + } + intercept[Exception]{ + DispatchTrie.construct(0, + Seq( + Vector("hello", ":world") -> 1, + Vector("hello", "world", "omg") -> 2 + ) + ) + } + intercept[Exception]{ + DispatchTrie.construct(0, + Seq( + Vector("hello", "::world") -> 1, + Vector("hello", "cow", "omg") -> 2 + ) + ) + } + intercept[Exception]{ + DispatchTrie.construct(0, + Seq( + Vector("hello", ":world") -> 1, + Vector("hello", ":cow") -> 2 + ) + ) + } + intercept[Exception]{ + DispatchTrie.construct(0, + Seq( + Vector("hello", "world") -> 1, + Vector("hello", "world") -> 2 + ) + ) + } + } + } +} \ No newline at end of file -- cgit v1.2.3