summaryrefslogtreecommitdiff
path: root/cask/src/cask/endpoints/WebSocketEndpoint.scala
blob: 6ca5def88a2cc978d301156a37ae2d4337015d2d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
package cask.endpoints

import java.nio.ByteBuffer

import cask.model.Request
import cask.router.Result
import cask.util.Logger
import io.undertow.websockets.WebSocketConnectionCallback
import io.undertow.websockets.core.{AbstractReceiveListener, BufferedBinaryMessage, BufferedTextMessage, CloseMessage, WebSocketChannel, WebSockets}
import io.undertow.websockets.spi.WebSocketHttpExchange

import scala.concurrent.ExecutionContext

/**
 * The value a [[websocket]] endpoint method may produce.
 *
 * It is either an ordinary `cask.model.Response` or an Undertow
 * `WebSocketConnectionCallback` that takes over the connection. Both are
 * reachable through the implicit wrappers in the companion object.
 */
sealed trait WebsocketResult
object WebsocketResult{
  /** Wraps a raw Undertow connection callback so it can be returned directly. */
  implicit class Listener(val value: WebSocketConnectionCallback) extends WebsocketResult

  /**
   * Wraps a `cask.model.Response[T]` whose payload can be converted to
   * `cask.model.Response.Data` via the implicit `f`.
   *
   * Note: `value` is a `def`, so the mapping is re-applied on every access.
   */
  implicit class Response[T](value0: cask.model.Response[T])
                            (implicit f: T => cask.model.Response.Data) extends WebsocketResult{
    def value = value0.map(payload => f(payload))
  }
}

/**
 * Endpoint definition for websocket routes.
 *
 * The endpoint is registered under the pseudo-method `"websocket"`. Path
 * segments are wrapped as single-element sequences, and the request's query
 * parameters are handed to the delegate.
 *
 * @param path    route path this endpoint is bound to
 * @param subpath whether the route also matches paths nested below `path`
 */
class websocket(val path: String, override val subpath: Boolean = false)
  extends cask.router.Endpoint[WebsocketResult, Seq[String]]{
  type InputParser[T] = QueryParamReader[T]
  type OuterReturned = Result[WebsocketResult]

  val methods: Seq[String] = Seq("websocket")

  def wrapPathSegment(s: String): Seq[String] = Seq(s)

  def wrapFunction(ctx: Request, delegate: Delegate): OuterReturned = {
    // Query-string parameters are the only input source bound for websocket routes.
    val queryParams = WebEndpoint.buildMapFromQueryParams(ctx)
    delegate(queryParams)
  }
}

/**
 * Websocket result that connects an Undertow `WebSocketChannel` to a
 * `cask.util.BatchActor`.
 *
 * On each new connection, `f` is called with a [[WsChannelActor]] that writes
 * to that channel. The actor `f` returns receives every incoming text, binary,
 * ping, pong and close frame as a [[WsActor.Event]].
 *
 * @param f builds the per-connection actor that handles incoming events, given
 *          an actor that can send events back over the same channel
 */
case class WsHandler(f: WsChannelActor => cask.util.BatchActor[WsActor.Event])
                    (implicit ec: ExecutionContext, log: Logger)
extends WebsocketResult with WebSocketConnectionCallback {
  def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = {
    val actor = f(new WsChannelActor(channel))
    channel.getReceiveSetter.set(
      new AbstractReceiveListener() {
        override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage): Unit = {
          actor.send(WsActor.Text(message.getData))
        }

        override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
          actor.send(WsActor.Binary(drainBytes(message)))
        }

        override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
          actor.send(WsActor.Ping(drainBytes(message)))
        }

        override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
          actor.send(WsActor.Pong(drainBytes(message)))
        }

        override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel): Unit = {
          actor.send(WsActor.Close(cm.getCode, cm.getReason))
        }
      }
    )
    channel.resumeReceives()
  }

  /**
   * Copies the full payload of a binary-style frame into a fresh byte array,
   * then releases the pooled buffers that Undertow handed over.
   *
   * `getData` returns pooled memory. Once the default listener callbacks are
   * overridden, freeing it is this code's job, so `free()` runs in `finally`
   * even if merging fails.
   */
  private def drainBytes(message: BufferedBinaryMessage): Array[Byte] = {
    val pooled = message.getData
    try {
      val merged = WebSockets.mergeBuffers(pooled.getResource: _*)
      // Copy with `get` rather than `array()`: the merged buffer is not guaranteed
      // to be heap-backed (e.g. the shared empty buffer for zero-length frames).
      val bytes = new Array[Byte](merged.remaining())
      merged.get(bytes)
      bytes
    } finally pooled.free()
  }
}

/**
 * Actor that writes outgoing [[WsActor.Event]]s to an Undertow websocket channel.
 *
 * Each event in a batch is written, in order, with the matching blocking
 * `WebSockets.send*Blocking` call. `WsActor.Event` is not sealed and this match
 * has no default case, so any custom `Event` subtype reaching this actor would
 * raise a `MatchError`.
 *
 * @param channel the websocket connection that frames are written to
 */
class WsChannelActor(channel: WebSocketChannel)
                    (implicit ec: ExecutionContext, log: Logger)
extends cask.util.BatchActor[WsActor.Event]{
  def run(items: Seq[WsActor.Event]): Unit = items.foreach{
    case WsActor.Text(value) => WebSockets.sendTextBlocking(value, channel)
    case WsActor.Binary(value) => WebSockets.sendBinaryBlocking(ByteBuffer.wrap(value), channel)
    case WsActor.Ping(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel)
    // A Pong event must go out as a pong frame, not a ping, or the peer would
    // treat it as a new ping and reply to it instead of accepting it as a reply.
    case WsActor.Pong(value) => WebSockets.sendPongBlocking(ByteBuffer.wrap(value), channel)
    case WsActor.Close(code, reason) => WebSockets.sendCloseBlocking(code, reason, channel)
  }
}

/**
 * Convenience `cask.util.BatchActor` whose behavior is given by a partial function.
 *
 * Every event in a batch is offered to `handle` in order. Events at which
 * `handle` is not defined are silently dropped.
 *
 * @param handle callback applied to each event it is defined at
 */
case class WsActor(handle: PartialFunction[WsActor.Event, Unit])
                  (implicit ec: ExecutionContext, log: Logger)
extends cask.util.BatchActor[WsActor.Event]{
  def run(items: Seq[WsActor.Event]): Unit = {
    // `lift` evaluates the partial function once per event and yields None
    // (ignored here) for unhandled events instead of throwing.
    val dispatch = handle.lift
    for (event <- items) dispatch(event)
  }
}

/** Websocket events exchanged between a connection and its actors. */
object WsActor{
  /** Base type of every websocket event (open for extension: not sealed). */
  trait Event

  /** A text frame carrying `value`. */
  case class Text(value: String) extends Event

  /** A binary frame. `Array` fields compare by reference under case-class equality. */
  case class Binary(value: Array[Byte]) extends Event

  /** A ping frame with an optional payload (empty by default). */
  case class Ping(value: Array[Byte] = new Array[Byte](0)) extends Event

  /** A pong frame with an optional payload (empty by default). */
  case class Pong(value: Array[Byte] = new Array[Byte](0)) extends Event

  /** A close frame; defaults to a normal closure with an empty reason. */
  case class Close(code: Int = Close.NormalClosure, reason: String = "") extends Event

  /** Close status codes, re-exported from Undertow's `CloseMessage` constants. */
  object Close{
    val NormalClosure: Int = CloseMessage.NORMAL_CLOSURE
    val GoingAway: Int = CloseMessage.GOING_AWAY
    val WrongCode: Int = CloseMessage.WRONG_CODE
    val ProtocolError: Int = CloseMessage.PROTOCOL_ERROR
    val MsgContainsInvalidData: Int = CloseMessage.MSG_CONTAINS_INVALID_DATA
    val MsgViolatesPolicy: Int = CloseMessage.MSG_VIOLATES_POLICY
    val MsgTooBig: Int = CloseMessage.MSG_TOO_BIG
    val MissingExtensions: Int = CloseMessage.MISSING_EXTENSIONS
    val UnexpectedError: Int = CloseMessage.UNEXPECTED_ERROR
  }
}