summaryrefslogtreecommitdiff
path: root/cask/src/cask/endpoints/WebSocketEndpoint.scala
diff options
context:
space:
mode:
Diffstat (limited to 'cask/src/cask/endpoints/WebSocketEndpoint.scala')
-rw-r--r--cask/src/cask/endpoints/WebSocketEndpoint.scala85
1 files changed, 83 insertions, 2 deletions
diff --git a/cask/src/cask/endpoints/WebSocketEndpoint.scala b/cask/src/cask/endpoints/WebSocketEndpoint.scala
index 842d508..6728581 100644
--- a/cask/src/cask/endpoints/WebSocketEndpoint.scala
+++ b/cask/src/cask/endpoints/WebSocketEndpoint.scala
@@ -1,9 +1,15 @@
package cask.endpoints
-import cask.internal.Router
+import java.nio.ByteBuffer
+
+import cask.internal.{BatchActor, Router}
import cask.model.Request
import io.undertow.websockets.WebSocketConnectionCallback
-import collection.JavaConverters._
+import io.undertow.websockets.core.{AbstractReceiveListener, BufferedBinaryMessage, BufferedTextMessage, CloseMessage, WebSocketChannel, WebSockets}
+import io.undertow.websockets.spi.WebSocketHttpExchange
+
+import scala.concurrent.ExecutionContext
+
sealed trait WebsocketResult
object WebsocketResult{
implicit class Response[T](value0: cask.model.Response[T])
@@ -24,3 +30,78 @@ class websocket(val path: String, override val subpath: Boolean = false)
def wrapPathSegment(s: String): Seq[String] = Seq(s)
}
+
+case class WsHandler(f: WsChannelActor => BatchActor[WsActor.Event])(implicit ec: ExecutionContext)
+extends WebsocketResult with WebSocketConnectionCallback {
+ def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = {
+ val actor = f(new WsChannelActor(channel))
+ channel.getReceiveSetter.set(
+ new AbstractReceiveListener() {
+ override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = {
+ actor.send(WsActor.Text(message.getData))
+ }
+
+ override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
+ actor.send(WsActor.Binary(
+ WebSockets.mergeBuffers(message.getData.getResource:_*).array()
+ ))
+ }
+
+ override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
+ actor.send(WsActor.Ping(
+ WebSockets.mergeBuffers(message.getData.getResource:_*).array()
+ ))
+ }
+ override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
+ actor.send(WsActor.Pong(
+ WebSockets.mergeBuffers(message.getData.getResource:_*).array()
+ ))
+ }
+
+ override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel) = {
+ actor.send(WsActor.Close(cm.getCode, cm.getReason))
+ }
+ }
+ )
+ channel.resumeReceives()
+ }
+}
+
+class WsChannelActor(channel: WebSocketChannel)(implicit ec: ExecutionContext)
+extends BatchActor[WsActor.Event]{
+ def run(items: Seq[WsActor.Event]): Unit = items.foreach{
+ case WsActor.Text(value) => WebSockets.sendTextBlocking(value, channel)
+ case WsActor.Binary(value) => WebSockets.sendBinaryBlocking(ByteBuffer.wrap(value), channel)
+ case WsActor.Ping(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel)
+ case WsActor.Pong(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel)
+ case WsActor.Close(code, reason) => WebSockets.sendCloseBlocking(code, reason, channel)
+ }
+}
+
+case class WsActor(handle: PartialFunction[WsActor.Event, Unit])
+ (implicit ec: ExecutionContext)
+extends BatchActor[WsActor.Event]{
+ def run(items: Seq[WsActor.Event]): Unit = {
+ items.foreach(handle.applyOrElse(_, (x: WsActor.Event) => ()))
+ }
+}
+
+object WsActor{
+ trait Event
+ case class Text(value: String) extends Event
+ case class Binary(value: Array[Byte]) extends Event
+ case class Ping(value: Array[Byte] = Array.empty[Byte]) extends Event
+ case class Pong(value: Array[Byte] = Array.empty[Byte]) extends Event
+ case class Close(code: Int = Close.NormalClosure, reason: String = "") extends Event
+ object Close{
+ val NormalClosure = CloseMessage.NORMAL_CLOSURE
+ val GoingAway = CloseMessage.GOING_AWAY
+ val WrongCode = CloseMessage.WRONG_CODE
+ val ProtocolError = CloseMessage.PROTOCOL_ERROR
+ val MsgContainsInvalidData = CloseMessage.MSG_CONTAINS_INVALID_DATA
+ val MsgViolatesPolicy = CloseMessage.MSG_VIOLATES_POLICY
+ val MsgTooBig = CloseMessage.MSG_TOO_BIG
+ val MissingExtensions = CloseMessage.MISSING_EXTENSIONS
+ val UnexpectedError = CloseMessage.UNEXPECTED_ERROR
+ }
+}