package cask.endpoints import java.nio.ByteBuffer import cask.internal.{BatchActor, Router} import cask.model.Request import cask.util.Logger import io.undertow.websockets.WebSocketConnectionCallback import io.undertow.websockets.core.{AbstractReceiveListener, BufferedBinaryMessage, BufferedTextMessage, CloseMessage, WebSocketChannel, WebSockets} import io.undertow.websockets.spi.WebSocketHttpExchange import scala.concurrent.ExecutionContext sealed trait WebsocketResult object WebsocketResult{ implicit class Response[T](value0: cask.model.Response[T]) (implicit f: T => cask.model.Response.Data) extends WebsocketResult{ def value = value0.map(f) } implicit class Listener(val value: WebSocketConnectionCallback) extends WebsocketResult } class websocket(val path: String, override val subpath: Boolean = false) extends cask.main.Endpoint[WebsocketResult, Seq[String]]{ val methods = Seq("websocket") type InputParser[T] = QueryParamReader[T] type OuterReturned = Router.Result[WebsocketResult] def wrapFunction(ctx: Request, delegate: Delegate): OuterReturned = { delegate(WebEndpoint.buildMapFromQueryParams(ctx)) } def wrapPathSegment(s: String): Seq[String] = Seq(s) } case class WsHandler(f: WsChannelActor => BatchActor[WsActor.Event]) (implicit ec: ExecutionContext, log: Logger) extends WebsocketResult with WebSocketConnectionCallback { def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = { val actor = f(new WsChannelActor(channel)) channel.getReceiveSetter.set( new AbstractReceiveListener() { override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = { actor.send(WsActor.Text(message.getData)) } override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = { actor.send(WsActor.Binary( WebSockets.mergeBuffers(message.getData.getResource:_*).array() )) } override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = { actor.send(WsActor.Ping( WebSockets.mergeBuffers(message.getData.getResource:_*).array() )) } override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = { actor.send(WsActor.Pong( WebSockets.mergeBuffers(message.getData.getResource:_*).array() )) } override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel) = { actor.send(WsActor.Close(cm.getCode, cm.getReason)) } } ) channel.resumeReceives() } } class WsChannelActor(channel: WebSocketChannel) (implicit ec: ExecutionContext, log: Logger) extends BatchActor[WsActor.Event]{ def run(items: Seq[WsActor.Event]): Unit = items.foreach{ case WsActor.Text(value) => WebSockets.sendTextBlocking(value, channel) case WsActor.Binary(value) => WebSockets.sendBinaryBlocking(ByteBuffer.wrap(value), channel) case WsActor.Ping(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel) case WsActor.Pong(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel) case WsActor.Close(code, reason) => WebSockets.sendCloseBlocking(code, reason, channel) } } case class WsActor(handle: PartialFunction[WsActor.Event, Unit]) (implicit ec: ExecutionContext, log: Logger) extends BatchActor[WsActor.Event]{ def run(items: Seq[WsActor.Event]): Unit = { items.foreach(handle.applyOrElse(_, (x: WsActor.Event) => ())) } } object WsActor{ trait Event case class Text(value: String) extends Event case class Binary(value: Array[Byte]) extends Event case class Ping(value: Array[Byte] = Array.empty[Byte]) extends Event case class Pong(value: Array[Byte] = Array.empty[Byte]) extends Event case class Close(code: Int = Close.NormalClosure, reason: String = "") extends Event object Close{ val NormalClosure = CloseMessage.NORMAL_CLOSURE val GoingAway = CloseMessage.GOING_AWAY val WrongCode = CloseMessage.WRONG_CODE val ProtocolError = CloseMessage.PROTOCOL_ERROR val MsgContainsInvalidData = CloseMessage.MSG_CONTAINS_INVALID_DATA val MsgViolatesPolicy = CloseMessage.MSG_VIOLATES_POLICY val MsgTooBig = CloseMessage.MSG_TOO_BIG val MissingExtensions = CloseMessage.MISSING_EXTENSIONS val UnexpectedError = CloseMessage.UNEXPECTED_ERROR } }