diff options
Diffstat (limited to 'cask/src')
-rw-r--r-- | cask/src/cask/endpoints/WebSocketEndpoint.scala | 60 | ||||
-rw-r--r-- | cask/src/cask/main/Routes.scala | 2 | ||||
-rw-r--r-- | cask/src/cask/package.scala | 3 |
3 files changed, 27 insertions, 38 deletions
diff --git a/cask/src/cask/endpoints/WebSocketEndpoint.scala b/cask/src/cask/endpoints/WebSocketEndpoint.scala index 6ca5def..994f015 100644 --- a/cask/src/cask/endpoints/WebSocketEndpoint.scala +++ b/cask/src/cask/endpoints/WebSocketEndpoint.scala @@ -4,7 +4,7 @@ import java.nio.ByteBuffer import cask.model.Request import cask.router.Result -import cask.util.Logger +import cask.util.{Logger, Ws} import io.undertow.websockets.WebSocketConnectionCallback import io.undertow.websockets.core.{AbstractReceiveListener, BufferedBinaryMessage, BufferedTextMessage, CloseMessage, WebSocketChannel, WebSockets} import io.undertow.websockets.spi.WebSocketHttpExchange @@ -32,36 +32,41 @@ class websocket(val path: String, override val subpath: Boolean = false) def wrapPathSegment(s: String): Seq[String] = Seq(s) } -case class WsHandler(f: WsChannelActor => cask.util.BatchActor[WsActor.Event]) +case class WsHandler(f: WsChannelActor => cask.util.BatchActor[Ws.Event]) (implicit ec: ExecutionContext, log: Logger) extends WebsocketResult with WebSocketConnectionCallback { def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = { + channel.suspendReceives() val actor = f(new WsChannelActor(channel)) + // Somehow browsers closing tabs and Java processes being killed appear + // as different events here; the former goes to AbstractReceiveListener#onClose, + // while the latter to ChannelListener#handleEvent. Make sure we handle both cases. + channel.addCloseTask(channel => actor.send(Ws.ChannelClosed())) channel.getReceiveSetter.set( new AbstractReceiveListener() { override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = { - actor.send(WsActor.Text(message.getData)) + actor.send(Ws.Text(message.getData)) } override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = { - actor.send(WsActor.Binary( + actor.send(Ws.Binary( WebSockets.mergeBuffers(message.getData.getResource:_*).array() )) } override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = { - actor.send(WsActor.Ping( + actor.send(Ws.Ping( WebSockets.mergeBuffers(message.getData.getResource:_*).array() )) } override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = { - actor.send(WsActor.Pong( + actor.send(Ws.Pong( WebSockets.mergeBuffers(message.getData.getResource:_*).array() )) } override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel) = { - actor.send(WsActor.Close(cm.getCode, cm.getReason)) + actor.send(Ws.Close(cm.getCode, cm.getReason)) } } ) @@ -71,40 +76,21 @@ extends WebsocketResult with WebSocketConnectionCallback { class WsChannelActor(channel: WebSocketChannel) (implicit ec: ExecutionContext, log: Logger) -extends cask.util.BatchActor[WsActor.Event]{ - def run(items: Seq[WsActor.Event]): Unit = items.foreach{ - case WsActor.Text(value) => WebSockets.sendTextBlocking(value, channel) - case WsActor.Binary(value) => WebSockets.sendBinaryBlocking(ByteBuffer.wrap(value), channel) - case WsActor.Ping(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel) - case WsActor.Pong(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel) - case WsActor.Close(code, reason) => WebSockets.sendCloseBlocking(code, reason, channel) +extends cask.util.BatchActor[Ws.Event]{ + def run(items: Seq[Ws.Event]): Unit = items.foreach{ + case Ws.Text(value) => WebSockets.sendTextBlocking(value, channel) + case Ws.Binary(value) => WebSockets.sendBinaryBlocking(ByteBuffer.wrap(value), channel) + case Ws.Ping(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel) + case Ws.Pong(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel) + case Ws.Close(code, reason) => WebSockets.sendCloseBlocking(code, reason, channel) } } -case class WsActor(handle: PartialFunction[WsActor.Event, Unit]) +case class WsActor(handle: PartialFunction[Ws.Event, Unit]) (implicit ec: ExecutionContext, log: Logger) -extends cask.util.BatchActor[WsActor.Event]{ - def run(items: Seq[WsActor.Event]): Unit = { - items.foreach(handle.applyOrElse(_, (x: WsActor.Event) => ())) +extends cask.util.BatchActor[Ws.Event]{ + def run(items: Seq[Ws.Event]): Unit = { + items.foreach(handle.applyOrElse(_, (x: Ws.Event) => ())) } } -object WsActor{ - trait Event - case class Text(value: String) extends Event - case class Binary(value: Array[Byte]) extends Event - case class Ping(value: Array[Byte] = Array.empty[Byte]) extends Event - case class Pong(value: Array[Byte] = Array.empty[Byte]) extends Event - case class Close(code: Int = Close.NormalClosure, reason: String = "") extends Event - object Close{ - val NormalClosure = CloseMessage.NORMAL_CLOSURE - val GoingAway = CloseMessage.GOING_AWAY - val WrongCode = CloseMessage.WRONG_CODE - val ProtocolError = CloseMessage.PROTOCOL_ERROR - val MsgContainsInvalidData = CloseMessage.MSG_CONTAINS_INVALID_DATA - val MsgViolatesPolicy = CloseMessage.MSG_VIOLATES_POLICY - val MsgTooBig = CloseMessage.MSG_TOO_BIG - val MissingExtensions = CloseMessage.MISSING_EXTENSIONS - val UnexpectedError = CloseMessage.UNEXPECTED_ERROR - } -} diff --git a/cask/src/cask/main/Routes.scala b/cask/src/cask/main/Routes.scala index 98c5b78..f93e641 100644 --- a/cask/src/cask/main/Routes.scala +++ b/cask/src/cask/main/Routes.scala @@ -17,5 +17,5 @@ trait Routes{ metadata0 = routes } - def log: cask.util.Logger + implicit def log: cask.util.Logger } diff --git a/cask/src/cask/package.scala b/cask/src/cask/package.scala index d34fe26..7c1d61c 100644 --- a/cask/src/cask/package.scala +++ b/cask/src/cask/package.scala @@ -47,6 +47,9 @@ package object cask { type WsActor = cask.endpoints.WsActor val WsActor = cask.endpoints.WsActor type WsChannelActor = cask.endpoints.WsChannelActor + type WsClient = cask.util.WsClient + val WsClient = cask.util.WsClient + val Ws = cask.util.Ws // util type Logger = util.Logger |