summaryrefslogtreecommitdiff
path: root/cask
diff options
context:
space:
mode:
authorLi Haoyi <haoyi.sg@gmail.com>2018-08-13 03:54:50 +0800
committerLi Haoyi <haoyi.sg@gmail.com>2018-08-13 03:54:50 +0800
commit2fc9fd22084bb4a89a72be525c18fc409303ada5 (patch)
tree511734c255237e99ba0b5b302232073572faaa38 /cask
parent790deda0f38e36c7378ff05a9c234a56e14a5d6b (diff)
downloadcask-2fc9fd22084bb4a89a72be525c18fc409303ada5.tar.gz
cask-2fc9fd22084bb4a89a72be525c18fc409303ada5.tar.bz2
cask-2fc9fd22084bb4a89a72be525c18fc409303ada5.zip
Basic websocket support works
Diffstat (limited to 'cask')
-rw-r--r--cask/src/cask/endpoints/FormEndpoint.scala4
-rw-r--r--cask/src/cask/endpoints/JsonEndpoint.scala5
-rw-r--r--cask/src/cask/endpoints/StaticEndpoints.scala6
-rw-r--r--cask/src/cask/endpoints/WebEndpoints.scala6
-rw-r--r--cask/src/cask/endpoints/WebSocketEndpoint.scala52
-rw-r--r--cask/src/cask/internal/Router.scala2
-rw-r--r--cask/src/cask/main/Decorators.scala13
-rw-r--r--cask/src/cask/main/Main.scala100
-rw-r--r--cask/src/cask/main/Routes.scala10
-rw-r--r--cask/src/cask/model/Params.scala1
-rw-r--r--cask/src/cask/package.scala6
11 files changed, 153 insertions, 52 deletions
diff --git a/cask/src/cask/endpoints/FormEndpoint.scala b/cask/src/cask/endpoints/FormEndpoint.scala
index 48190ce..eb882fa 100644
--- a/cask/src/cask/endpoints/FormEndpoint.scala
+++ b/cask/src/cask/endpoints/FormEndpoint.scala
@@ -1,7 +1,7 @@
package cask.endpoints
import cask.internal.{Router, Util}
-import cask.main.{Endpoint, Routes}
+import cask.main.{Endpoint, HttpDecorator, Routes}
import cask.model._
import io.undertow.server.handlers.form.FormParserFactory
@@ -43,7 +43,7 @@ object FormReader{
def read(ctx: ParamContext, label: String, input: Seq[FormEntry]) = input.map(_.asInstanceOf[FormFile])
}
}
-class postForm(val path: String, override val subpath: Boolean = false) extends Endpoint{
+class postForm(val path: String, override val subpath: Boolean = false) extends Endpoint with HttpDecorator{
type Output = Response
val methods = Seq("post")
diff --git a/cask/src/cask/endpoints/JsonEndpoint.scala b/cask/src/cask/endpoints/JsonEndpoint.scala
index 51199c3..f3b0cae 100644
--- a/cask/src/cask/endpoints/JsonEndpoint.scala
+++ b/cask/src/cask/endpoints/JsonEndpoint.scala
@@ -4,7 +4,7 @@ import java.io.ByteArrayOutputStream
import cask.internal.{Router, Util}
import cask.internal.Router.EntryPoint
-import cask.main.{Endpoint, Routes}
+import cask.main.{Endpoint, HttpDecorator, Routes}
import cask.model.{ParamContext, Response}
@@ -26,12 +26,11 @@ object JsReader{
}
}
}
-class postJson(val path: String, override val subpath: Boolean = false) extends Endpoint{
+class postJson(val path: String, override val subpath: Boolean = false) extends Endpoint with HttpDecorator{
type Output = Response
val methods = Seq("post")
type Input = ujson.Js.Value
type InputParser[T] = JsReader[T]
-
def wrapFunction(ctx: ParamContext,
delegate: Map[String, Input] => Router.Result[Output]): Router.Result[Response] = {
val obj = for{
diff --git a/cask/src/cask/endpoints/StaticEndpoints.scala b/cask/src/cask/endpoints/StaticEndpoints.scala
index e93c09b..a9b3193 100644
--- a/cask/src/cask/endpoints/StaticEndpoints.scala
+++ b/cask/src/cask/endpoints/StaticEndpoints.scala
@@ -1,10 +1,10 @@
package cask.endpoints
import cask.internal.Router
-import cask.main.Endpoint
-import cask.model.{Response, ParamContext}
+import cask.main.{Endpoint, HttpDecorator}
+import cask.model.{ParamContext, Response}
-class static(val path: String) extends Endpoint {
+class static(val path: String) extends Endpoint with HttpDecorator{
type Output = String
val methods = Seq("get")
type Input = Seq[String]
diff --git a/cask/src/cask/endpoints/WebEndpoints.scala b/cask/src/cask/endpoints/WebEndpoints.scala
index 2125b4d..41c3113 100644
--- a/cask/src/cask/endpoints/WebEndpoints.scala
+++ b/cask/src/cask/endpoints/WebEndpoints.scala
@@ -1,13 +1,13 @@
package cask.endpoints
import cask.internal.Router
-import cask.main.Endpoint
-import cask.model.{Response, ParamContext}
+import cask.main.{Endpoint, HttpDecorator}
+import cask.model.{ParamContext, Response}
import collection.JavaConverters._
-trait WebEndpoint extends Endpoint{
+trait WebEndpoint extends Endpoint with HttpDecorator{
type Output = Response
type Input = Seq[String]
type InputParser[T] = QueryParamReader[T]
diff --git a/cask/src/cask/endpoints/WebSocketEndpoint.scala b/cask/src/cask/endpoints/WebSocketEndpoint.scala
new file mode 100644
index 0000000..a795afd
--- /dev/null
+++ b/cask/src/cask/endpoints/WebSocketEndpoint.scala
@@ -0,0 +1,52 @@
+package cask.endpoints
+
+import cask.internal.Router
+import cask.model.{ParamContext, Subpath}
+import io.undertow.server.HttpServerExchange
+import io.undertow.websockets.WebSocketConnectionCallback
+trait WebsocketParam[T] extends Router.ArgReader[Seq[String], T, cask.model.ParamContext]
+
+object WebsocketParam{
+ class NilParam[T](f: (ParamContext, String) => T) extends WebsocketParam[T]{
+ def arity = 0
+ def read(ctx: ParamContext, label: String, v: Seq[String]): T = f(ctx, label)
+ }
+ implicit object HttpExchangeParam extends NilParam[HttpServerExchange](
+ (ctx, label) => ctx.exchange
+ )
+ implicit object SubpathParam extends NilParam[Subpath](
+ (ctx, label) => new Subpath(ctx.remaining)
+ )
+ class SimpleParam[T](f: String => T) extends WebsocketParam[T]{
+ def arity = 1
+ def read(ctx: cask.model.ParamContext, label: String, v: Seq[String]): T = f(v.head)
+ }
+
+ implicit object StringParam extends SimpleParam[String](x => x)
+ implicit object BooleanParam extends SimpleParam[Boolean](_.toBoolean)
+ implicit object ByteParam extends SimpleParam[Byte](_.toByte)
+ implicit object ShortParam extends SimpleParam[Short](_.toShort)
+ implicit object IntParam extends SimpleParam[Int](_.toInt)
+ implicit object LongParam extends SimpleParam[Long](_.toLong)
+ implicit object DoubleParam extends SimpleParam[Double](_.toDouble)
+ implicit object FloatParam extends SimpleParam[Float](_.toFloat)
+}
+
+sealed trait WebsocketResult
+object WebsocketResult{
+ implicit class Response(val value: cask.model.Response) extends WebsocketResult
+ implicit class Listener(val value: WebSocketConnectionCallback) extends WebsocketResult
+}
+
+class websocket(val path: String, subpath: Boolean = false) extends cask.main.BaseEndpoint{
+ type Output = WebsocketResult
+ val methods = Seq("websocket")
+ type Input = Seq[String]
+ type InputParser[T] = WebsocketParam[T]
+ type Returned = Router.Result[WebsocketResult]
+ def wrapFunction(ctx: ParamContext, delegate: Delegate): Returned = delegate(Map())
+
+ def wrapPathSegment(s: String): Input = Seq(s)
+
+
+}
diff --git a/cask/src/cask/internal/Router.scala b/cask/src/cask/internal/Router.scala
index 36976af..9fa9b60 100644
--- a/cask/src/cask/internal/Router.scala
+++ b/cask/src/cask/internal/Router.scala
@@ -204,7 +204,7 @@ class Router[C <: Context](val c: C) {
def extractMethod(method: MethodSymbol,
curCls: c.universe.Type,
convertToResultType: c.Tree,
- ctx: c.Type,
+ ctx: c.Tree,
argReaders: Seq[c.Tree],
annotDeserializeTypes: Seq[c.Tree]): c.universe.Tree = {
val baseArgSym = TermName(c.freshName())
diff --git a/cask/src/cask/main/Decorators.scala b/cask/src/cask/main/Decorators.scala
index 73d8c19..239cab4 100644
--- a/cask/src/cask/main/Decorators.scala
+++ b/cask/src/cask/main/Decorators.scala
@@ -2,13 +2,15 @@ package cask.main
import cask.internal.Router
import cask.internal.Router.ArgReader
-import cask.model.{Response, ParamContext}
+import cask.model.{ParamContext, Response}
+
+trait Endpoint extends BaseEndpoint with HttpDecorator
/**
* Used to annotate a single Cask endpoint function; similar to a [[Decorator]]
* but with additional metadata and capabilities.
*/
-trait Endpoint extends BaseDecorator{
+trait BaseEndpoint extends BaseDecorator{
/**
* What is the path that this particular endpoint matches?
*/
@@ -45,10 +47,13 @@ trait BaseDecorator{
type InputParser[T] <: ArgReader[Input, T, ParamContext]
type Output
type Delegate = Map[String, Input] => Router.Result[Output]
- type Returned = Router.Result[Response]
+ type Returned <: Router.Result[Any]
def wrapFunction(ctx: ParamContext, delegate: Delegate): Returned
def getParamParser[T](implicit p: InputParser[T]) = p
}
+trait HttpDecorator extends BaseDecorator{
+ type Returned = Router.Result[Response]
+}
/**
* A decorator allows you to annotate a function to wrap it, via
@@ -61,7 +66,7 @@ trait BaseDecorator{
* to `wrapFunction`, which takes a `Map` representing any additional argument
* lists (if any).
*/
-trait Decorator extends BaseDecorator {
+trait Decorator extends HttpDecorator {
type Input = Any
type Output = Response
diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala
index 87c66c4..94d0e14 100644
--- a/cask/src/cask/main/Main.scala
+++ b/cask/src/cask/main/Main.scala
@@ -1,5 +1,6 @@
package cask.main
+import cask.endpoints.WebsocketResult
import cask.model._
import cask.internal.Router.EntryPoint
import cask.internal.{DispatchTrie, Router, Util}
@@ -27,7 +28,7 @@ abstract class BaseMain{
} yield (routes, route)
- lazy val routeTries = Seq("get", "put", "post")
+ lazy val routeTries = Seq("get", "put", "post", "websocket")
.map { method =>
method -> DispatchTrie.construct[(Routes, Routes.EndpointMetadata[_])](0,
for ((route, metadata) <- routeList if metadata.endpoint.methods.contains(method))
@@ -52,42 +53,81 @@ abstract class BaseMain{
)
}
+ def genericWebsocketHandler(exchange0: HttpServerExchange) =
+ hello(exchange0, "websocket", ParamContext(exchange0, _), exchange0.getRequestPath).foreach{ r =>
+ r.asInstanceOf[WebsocketResult] match{
+ case l: WebsocketResult.Listener =>
+ io.undertow.Handlers.websocket(l.value).handleRequest(exchange0)
+ case r: WebsocketResult.Response =>
+ writeResponseHandler(r).handleRequest(exchange0)
+ }
+ }
- def defaultHandler = new BlockingHandler(
+ def defaultHandler =
new HttpHandler() {
def handleRequest(exchange: HttpServerExchange): Unit = {
- routeTries(exchange.getRequestMethod.toString.toLowerCase()).lookup(Util.splitPath(exchange.getRequestPath).toList, Map()) match{
- case None => writeResponse(exchange, handleNotFound())
- case Some(((routes, metadata), extBindings, remaining)) =>
- val ctx = ParamContext(exchange, remaining)
- def rec(remaining: List[Decorator],
- bindings: List[Map[String, Any]]): Router.Result[Response] = try {
- remaining match {
- case head :: rest =>
- head.wrapFunction(ctx, args => rec(rest, args :: bindings))
-
- case Nil =>
- metadata.endpoint.wrapFunction(ctx, epBindings =>
- metadata.entryPoint
- .asInstanceOf[EntryPoint[cask.main.Routes, cask.model.ParamContext]]
- .invoke(routes, ctx, (epBindings ++ extBindings.mapValues(metadata.endpoint.wrapPathSegment)) :: bindings.reverse)
- .asInstanceOf[Router.Result[Nothing]]
- )
-
- }
- // Make sure we wrap any exceptions that bubble up from decorator
- // bodies, so outer decorators do not need to worry about their
- // delegate throwing on them
- }catch{case e: Throwable => Router.Result.Error.Exception(e) }
-
- rec((metadata.decorators ++ routes.decorators ++ mainDecorators).toList, Nil)match{
- case Router.Result.Success(response: Response) => writeResponse(exchange, response)
- case e: Router.Result.Error => writeResponse(exchange, handleEndpointError(exchange, routes, metadata, e))
- }
+ if (exchange.getRequestHeaders.getFirst("Upgrade") == "websocket") {
+
+ genericWebsocketHandler(exchange)
+ } else {
+ defaultHttpHandler.handleRequest(exchange)
}
}
}
+
+ def writeResponseHandler(r: WebsocketResult.Response) = new BlockingHandler(
+ new HttpHandler {
+ def handleRequest(exchange: HttpServerExchange): Unit = {
+ writeResponse(exchange, r.value)
+ }
+ }
)
+ def defaultHttpHandler = new BlockingHandler(
+ new HttpHandler() {
+ def handleRequest(exchange: HttpServerExchange) = {
+ hello(exchange, exchange.getRequestMethod.toString.toLowerCase(), ParamContext(exchange, _), exchange.getRequestPath).foreach{ r =>
+ writeResponse(exchange, r.asInstanceOf[Response])
+ }
+ }
+ }
+ )
+
+ def hello(exchange0: HttpServerExchange, effectiveMethod: String, ctx0: Seq[String] => ParamContext, path: String) = {
+ routeTries(effectiveMethod).lookup(Util.splitPath(path).toList, Map()) match{
+ case None =>
+ writeResponse(exchange0, handleNotFound())
+ None
+ case Some(((routes, metadata), extBindings, remaining)) =>
+ val ctx = ParamContext(exchange0, remaining)
+ val ctx1 = ctx0(remaining)
+ def rec(remaining: List[Decorator],
+ bindings: List[Map[String, Any]]): Router.Result[Any] = try {
+ remaining match {
+ case head :: rest =>
+ head.wrapFunction(ctx, args => rec(rest, args :: bindings).asInstanceOf[Router.Result[head.Output]])
+
+ case Nil =>
+ metadata.endpoint.wrapFunction(ctx, epBindings =>
+ metadata.entryPoint
+ .asInstanceOf[EntryPoint[cask.main.Routes, cask.model.ParamContext]]
+ .invoke(routes, ctx1, (epBindings ++ extBindings.mapValues(metadata.endpoint.wrapPathSegment)) :: bindings.reverse)
+ .asInstanceOf[Router.Result[Nothing]]
+ )
+ }
+ // Make sure we wrap any exceptions that bubble up from decorator
+ // bodies, so outer decorators do not need to worry about their
+ // delegate throwing on them
+ }catch{case e: Throwable => Router.Result.Error.Exception(e) }
+
+ rec((metadata.decorators ++ routes.decorators ++ mainDecorators).toList, Nil)match{
+ case Router.Result.Success(res) => Some(res)
+ case e: Router.Result.Error =>
+ writeResponse(exchange0, handleEndpointError(exchange0, routes, metadata, e))
+ None
+ }
+ }
+
+ }
def handleEndpointError(exchange: HttpServerExchange,
routes: Routes,
diff --git a/cask/src/cask/main/Routes.scala b/cask/src/cask/main/Routes.scala
index 7b47731..aaec832 100644
--- a/cask/src/cask/main/Routes.scala
+++ b/cask/src/cask/main/Routes.scala
@@ -8,8 +8,8 @@ import language.experimental.macros
object Routes{
case class EndpointMetadata[T](decorators: Seq[Decorator],
- endpoint: Endpoint,
- entryPoint: EntryPoint[T, ParamContext])
+ endpoint: BaseEndpoint,
+ entryPoint: EntryPoint[T, _])
case class RoutesEndpointsMetadata[T](value: EndpointMetadata[T]*)
object RoutesEndpointsMetadata{
implicit def initialize[T] = macro initializeImpl[T]
@@ -22,12 +22,12 @@ object Routes{
val annotations = m.annotations.filter(_.tree.tpe <:< c.weakTypeOf[BaseDecorator]).reverse
if annotations.nonEmpty
} yield {
- if(!(annotations.head.tree.tpe <:< weakTypeOf[Endpoint])) c.abort(
+ if(!(annotations.head.tree.tpe <:< weakTypeOf[BaseEndpoint])) c.abort(
annotations.head.tree.pos,
s"Last annotation applied to a function must be an instance of Endpoint, " +
s"not ${annotations.head.tree.tpe}"
)
- val allEndpoints = annotations.filter(_.tree.tpe <:< weakTypeOf[Endpoint])
+ val allEndpoints = annotations.filter(_.tree.tpe <:< weakTypeOf[BaseEndpoint])
if(allEndpoints.length > 1) c.abort(
annotations.head.tree.pos,
s"You can only apply one Endpoint annotation to a function, not " +
@@ -43,7 +43,7 @@ object Routes{
m.asInstanceOf[MethodSymbol],
weakTypeOf[T],
q"${annotObjectSyms.head}.convertToResultType",
- c.weakTypeOf[ParamContext],
+ tq"cask.ParamContext",
annotObjectSyms.map(annotObjectSym => q"$annotObjectSym.getParamParser"),
annotObjectSyms.map(annotObjectSym => tq"$annotObjectSym.Input")
diff --git a/cask/src/cask/model/Params.scala b/cask/src/cask/model/Params.scala
index 27c1a68..bd10161 100644
--- a/cask/src/cask/model/Params.scala
+++ b/cask/src/cask/model/Params.scala
@@ -6,6 +6,7 @@ import cask.endpoints.ParamReader.NilParam
import cask.internal.Util
import io.undertow.server.HttpServerExchange
import io.undertow.server.handlers.CookieImpl
+import io.undertow.websockets.spi.WebSocketHttpExchange
class Subpath(val value: Seq[String])
object Subpath{
diff --git a/cask/src/cask/package.scala b/cask/src/cask/package.scala
index 19dc675..b1a1550 100644
--- a/cask/src/cask/package.scala
+++ b/cask/src/cask/package.scala
@@ -22,6 +22,10 @@ package object cask {
val ParamContext = model.ParamContext
// endpoints
+ type websocket = endpoints.websocket
+ val WebsocketResult = endpoints.WebsocketResult
+ type WebsocketResult = endpoints.WebsocketResult
+
type get = endpoints.get
type post = endpoints.post
type put = endpoints.put
@@ -37,6 +41,6 @@ package object cask {
type Main = main.Main
type Decorator = main.Decorator
type Endpoint = main.Endpoint
- type BaseDecorator = main.BaseDecorator
+ type BaseDecorator = main.HttpDecorator
}