summaryrefslogtreecommitdiff
path: root/cask/src/cask/endpoints/WebSocketEndpoint.scala
diff options
context:
space:
mode:
Diffstat (limited to 'cask/src/cask/endpoints/WebSocketEndpoint.scala')
-rw-r--r--cask/src/cask/endpoints/WebSocketEndpoint.scala60
1 files changed, 23 insertions, 37 deletions
diff --git a/cask/src/cask/endpoints/WebSocketEndpoint.scala b/cask/src/cask/endpoints/WebSocketEndpoint.scala
index 6ca5def..994f015 100644
--- a/cask/src/cask/endpoints/WebSocketEndpoint.scala
+++ b/cask/src/cask/endpoints/WebSocketEndpoint.scala
@@ -4,7 +4,7 @@ import java.nio.ByteBuffer
import cask.model.Request
import cask.router.Result
-import cask.util.Logger
+import cask.util.{Logger, Ws}
import io.undertow.websockets.WebSocketConnectionCallback
import io.undertow.websockets.core.{AbstractReceiveListener, BufferedBinaryMessage, BufferedTextMessage, CloseMessage, WebSocketChannel, WebSockets}
import io.undertow.websockets.spi.WebSocketHttpExchange
@@ -32,36 +32,41 @@ class websocket(val path: String, override val subpath: Boolean = false)
def wrapPathSegment(s: String): Seq[String] = Seq(s)
}
-case class WsHandler(f: WsChannelActor => cask.util.BatchActor[WsActor.Event])
+case class WsHandler(f: WsChannelActor => cask.util.BatchActor[Ws.Event])
(implicit ec: ExecutionContext, log: Logger)
extends WebsocketResult with WebSocketConnectionCallback {
def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = {
+ channel.suspendReceives()
val actor = f(new WsChannelActor(channel))
+ // Somehow browsers closing tabs and Java processes being killed appear
+ // as different events here; the former goes to AbstractReceiveListener#onClose,
+ // while the latter to ChannelListener#handleEvent. Make sure we handle both cases.
+ channel.addCloseTask(channel => actor.send(Ws.ChannelClosed()))
channel.getReceiveSetter.set(
new AbstractReceiveListener() {
override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = {
- actor.send(WsActor.Text(message.getData))
+ actor.send(Ws.Text(message.getData))
}
override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
- actor.send(WsActor.Binary(
+ actor.send(Ws.Binary(
WebSockets.mergeBuffers(message.getData.getResource:_*).array()
))
}
override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
- actor.send(WsActor.Ping(
+ actor.send(Ws.Ping(
WebSockets.mergeBuffers(message.getData.getResource:_*).array()
))
}
override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
- actor.send(WsActor.Pong(
+ actor.send(Ws.Pong(
WebSockets.mergeBuffers(message.getData.getResource:_*).array()
))
}
override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel) = {
- actor.send(WsActor.Close(cm.getCode, cm.getReason))
+ actor.send(Ws.Close(cm.getCode, cm.getReason))
}
}
)
@@ -71,40 +76,21 @@ extends WebsocketResult with WebSocketConnectionCallback {
class WsChannelActor(channel: WebSocketChannel)
(implicit ec: ExecutionContext, log: Logger)
-extends cask.util.BatchActor[WsActor.Event]{
- def run(items: Seq[WsActor.Event]): Unit = items.foreach{
- case WsActor.Text(value) => WebSockets.sendTextBlocking(value, channel)
- case WsActor.Binary(value) => WebSockets.sendBinaryBlocking(ByteBuffer.wrap(value), channel)
- case WsActor.Ping(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel)
- case WsActor.Pong(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel)
- case WsActor.Close(code, reason) => WebSockets.sendCloseBlocking(code, reason, channel)
+extends cask.util.BatchActor[Ws.Event]{
+ def run(items: Seq[Ws.Event]): Unit = items.foreach{
+ case Ws.Text(value) => WebSockets.sendTextBlocking(value, channel)
+ case Ws.Binary(value) => WebSockets.sendBinaryBlocking(ByteBuffer.wrap(value), channel)
+ case Ws.Ping(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel)
+ case Ws.Pong(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel)
+ case Ws.Close(code, reason) => WebSockets.sendCloseBlocking(code, reason, channel)
}
}
-case class WsActor(handle: PartialFunction[WsActor.Event, Unit])
+case class WsActor(handle: PartialFunction[Ws.Event, Unit])
(implicit ec: ExecutionContext, log: Logger)
-extends cask.util.BatchActor[WsActor.Event]{
- def run(items: Seq[WsActor.Event]): Unit = {
- items.foreach(handle.applyOrElse(_, (x: WsActor.Event) => ()))
+extends cask.util.BatchActor[Ws.Event]{
+ def run(items: Seq[Ws.Event]): Unit = {
+ items.foreach(handle.applyOrElse(_, (x: Ws.Event) => ()))
}
}
-object WsActor{
- trait Event
- case class Text(value: String) extends Event
- case class Binary(value: Array[Byte]) extends Event
- case class Ping(value: Array[Byte] = Array.empty[Byte]) extends Event
- case class Pong(value: Array[Byte] = Array.empty[Byte]) extends Event
- case class Close(code: Int = Close.NormalClosure, reason: String = "") extends Event
- object Close{
- val NormalClosure = CloseMessage.NORMAL_CLOSURE
- val GoingAway = CloseMessage.GOING_AWAY
- val WrongCode = CloseMessage.WRONG_CODE
- val ProtocolError = CloseMessage.PROTOCOL_ERROR
- val MsgContainsInvalidData = CloseMessage.MSG_CONTAINS_INVALID_DATA
- val MsgViolatesPolicy = CloseMessage.MSG_VIOLATES_POLICY
- val MsgTooBig = CloseMessage.MSG_TOO_BIG
- val MissingExtensions = CloseMessage.MISSING_EXTENSIONS
- val UnexpectedError = CloseMessage.UNEXPECTED_ERROR
- }
-}