summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLi Haoyi <haoyi.sg@gmail.com>2019-09-14 22:37:49 +0800
committerLi Haoyi <haoyi.sg@gmail.com>2019-09-14 22:37:49 +0800
commit9e58e95add96a075d2cb70aa477441261f481ebd (patch)
tree243fd851913972de6f0110cd254ffd5d6e3565f5
parentdedeed376e1e906ec2eb1574a73f08c24aba47c8 (diff)
downloadcask-9e58e95add96a075d2cb70aa477441261f481ebd.tar.gz
cask-9e58e95add96a075d2cb70aa477441261f481ebd.tar.bz2
cask-9e58e95add96a075d2cb70aa477441261f481ebd.zip
First pass at providing a convenient API for handling websockets
-rw-r--r--build.sc3
-rw-r--r--cask/src/cask/endpoints/WebSocketEndpoint.scala85
-rw-r--r--cask/src/cask/internal/BatchActor.scala37
-rw-r--r--cask/src/cask/main/Main.scala4
-rw-r--r--cask/src/cask/package.scala5
-rw-r--r--example/websockets/app/src/Websockets.scala23
-rw-r--r--example/websockets2/app/src/Websockets2.scala29
-rw-r--r--example/websockets2/app/test/src/ExampleTests.scala123
-rw-r--r--example/websockets2/build.sc17
9 files changed, 306 insertions, 20 deletions
diff --git a/build.sc b/build.sc
index f6050e3..69fae11 100644
--- a/build.sc
+++ b/build.sc
@@ -21,6 +21,7 @@ import $file.example.todoApi.build
import $file.example.twirl.build
import $file.example.variableRoutes.build
import $file.example.websockets.build
+import $file.example.websockets2.build
object cask extends ScalaModule with PublishModule {
def scalaVersion = "2.13.0"
@@ -82,6 +83,7 @@ object example extends Module{
object twirl extends $file.example.twirl.build.AppModule with LocalModule
object variableRoutes extends $file.example.variableRoutes.build.AppModule with LocalModule
object websockets extends $file.example.websockets.build.AppModule with LocalModule
+ object websockets2 extends $file.example.websockets2.build.AppModule with LocalModule
}
def publishVersion = T.input($file.ci.version.publishVersion)
@@ -124,6 +126,7 @@ def uploadToGithub(authKey: String) = T.command{
$file.example.twirl.build.millSourcePath,
$file.example.variableRoutes.build.millSourcePath,
$file.example.websockets.build.millSourcePath,
+ $file.example.websockets2.build.millSourcePath,
)
for(example <- examples){
val f = T.ctx().dest
diff --git a/cask/src/cask/endpoints/WebSocketEndpoint.scala b/cask/src/cask/endpoints/WebSocketEndpoint.scala
index 842d508..6728581 100644
--- a/cask/src/cask/endpoints/WebSocketEndpoint.scala
+++ b/cask/src/cask/endpoints/WebSocketEndpoint.scala
@@ -1,9 +1,15 @@
package cask.endpoints
-import cask.internal.Router
+import java.nio.ByteBuffer
+
+import cask.internal.{BatchActor, Router}
import cask.model.Request
import io.undertow.websockets.WebSocketConnectionCallback
-import collection.JavaConverters._
+import io.undertow.websockets.core.{AbstractReceiveListener, BufferedBinaryMessage, BufferedTextMessage, CloseMessage, WebSocketChannel, WebSockets}
+import io.undertow.websockets.spi.WebSocketHttpExchange
+
+import scala.concurrent.ExecutionContext
+
sealed trait WebsocketResult
object WebsocketResult{
implicit class Response[T](value0: cask.model.Response[T])
@@ -24,3 +30,78 @@ class websocket(val path: String, override val subpath: Boolean = false)
def wrapPathSegment(s: String): Seq[String] = Seq(s)
}
+
+case class WsHandler(f: WsChannelActor => BatchActor[WsActor.Event])(implicit ec: ExecutionContext)
+extends WebsocketResult with WebSocketConnectionCallback {
+ def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = {
+ val actor = f(new WsChannelActor(channel))
+ channel.getReceiveSetter.set(
+ new AbstractReceiveListener() {
+ override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = {
+ actor.send(WsActor.Text(message.getData))
+ }
+
+ override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
+ actor.send(WsActor.Binary(
+ WebSockets.mergeBuffers(message.getData.getResource:_*).array()
+ ))
+ }
+
+ override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
+ actor.send(WsActor.Ping(
+ WebSockets.mergeBuffers(message.getData.getResource:_*).array()
+ ))
+ }
+ override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = {
+ actor.send(WsActor.Pong(
+ WebSockets.mergeBuffers(message.getData.getResource:_*).array()
+ ))
+ }
+
+ override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel) = {
+ actor.send(WsActor.Close(cm.getCode, cm.getReason))
+ }
+ }
+ )
+ channel.resumeReceives()
+ }
+}
+
+class WsChannelActor(channel: WebSocketChannel)(implicit ec: ExecutionContext)
+extends BatchActor[WsActor.Event]{
+ def run(items: Seq[WsActor.Event]): Unit = items.foreach{
+ case WsActor.Text(value) => WebSockets.sendTextBlocking(value, channel)
+ case WsActor.Binary(value) => WebSockets.sendBinaryBlocking(ByteBuffer.wrap(value), channel)
+ case WsActor.Ping(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel)
+ case WsActor.Pong(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel)
+ case WsActor.Close(code, reason) => WebSockets.sendCloseBlocking(code, reason, channel)
+ }
+}
+
+case class WsActor(handle: PartialFunction[WsActor.Event, Unit])
+ (implicit ec: ExecutionContext)
+extends BatchActor[WsActor.Event]{
+ def run(items: Seq[WsActor.Event]): Unit = {
+ items.foreach(handle.applyOrElse(_, (x: WsActor.Event) => ()))
+ }
+}
+
+object WsActor{
+ trait Event
+ case class Text(value: String) extends Event
+ case class Binary(value: Array[Byte]) extends Event
+ case class Ping(value: Array[Byte] = Array.empty[Byte]) extends Event
+ case class Pong(value: Array[Byte] = Array.empty[Byte]) extends Event
+ case class Close(code: Int = Close.NormalClosure, reason: String = "") extends Event
+ object Close{
+ val NormalClosure = CloseMessage.NORMAL_CLOSURE
+ val GoingAway = CloseMessage.GOING_AWAY
+ val WrongCode = CloseMessage.WRONG_CODE
+ val ProtocolError = CloseMessage.PROTOCOL_ERROR
+ val MsgContainsInvalidData = CloseMessage.MSG_CONTAINS_INVALID_DATA
+ val MsgViolatesPolicy = CloseMessage.MSG_VIOLATES_POLICY
+ val MsgTooBig = CloseMessage.MSG_TOO_BIG
+ val MissingExtensions = CloseMessage.MISSING_EXTENSIONS
+ val UnexpectedError = CloseMessage.UNEXPECTED_ERROR
+ }
+}
diff --git a/cask/src/cask/internal/BatchActor.scala b/cask/src/cask/internal/BatchActor.scala
new file mode 100644
index 0000000..1566a18
--- /dev/null
+++ b/cask/src/cask/internal/BatchActor.scala
@@ -0,0 +1,37 @@
+package cask.internal
+
+import scala.collection.mutable
+import scala.concurrent.ExecutionContext
+
+/**
+ * A simple asynchrous actor, allowing safe concurrent asynchronous processing
+ * of queued items. `run` handles items in batches, to allow for batch
+ * processing optimizations to be used where relevant.
+ */
+abstract class BatchActor[T]()(implicit ec: ExecutionContext) {
+ def run(items: Seq[T]): Unit
+
+ private val queue = new mutable.Queue[T]()
+ private var scheduled = false
+ def send(t: => T): Unit = synchronized{
+ queue.enqueue(t)
+ if (!scheduled){
+ scheduled = true
+ ec.execute(() => runWithItems())
+ }
+ }
+
+ def runWithItems(): Unit = {
+ val items = synchronized(queue.dequeueAll(_ => true))
+ try run(items)
+ catch{case e: Throwable => e.printStackTrace()}
+ synchronized{
+ if (queue.nonEmpty) ec.execute(() => runWithItems())
+ else{
+ assert(scheduled)
+ scheduled = false
+ }
+ }
+
+ }
+}
diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala
index 32e6517..a138e8a 100644
--- a/cask/src/cask/main/Main.scala
+++ b/cask/src/cask/main/Main.scala
@@ -1,6 +1,6 @@
package cask.main
-import cask.endpoints.WebsocketResult
+import cask.endpoints.{WebsocketResult, WsHandler}
import cask.model._
import cask.internal.Router.EntryPoint
import cask.internal.{DispatchTrie, Router, Util}
@@ -73,6 +73,8 @@ abstract class BaseMain{
val (effectiveMethod, runner) = if (exchange.getRequestHeaders.getFirst("Upgrade") == "websocket") {
"websocket" -> ((r: Any) =>
r.asInstanceOf[WebsocketResult] match{
+ case l: WsHandler =>
+ io.undertow.Handlers.websocket(l).handleRequest(exchange)
case l: WebsocketResult.Listener =>
io.undertow.Handlers.websocket(l.value).handleRequest(exchange)
case r: WebsocketResult.Response[_] =>
diff --git a/cask/src/cask/package.scala b/cask/src/cask/package.scala
index d17cfe6..f55d33d 100644
--- a/cask/src/cask/package.scala
+++ b/cask/src/cask/package.scala
@@ -40,4 +40,9 @@ package object cask {
type RawDecorator = main.RawDecorator
type HttpEndpoint[InnerReturned, Input] = main.HttpEndpoint[InnerReturned, Input]
+ type WsHandler = cask.endpoints.WsHandler
+ val WsHandler = cask.endpoints.WsHandler
+ type WsActor = cask.endpoints.WsActor
+ val WsActor = cask.endpoints.WsActor
+ type WsChannelActor = cask.endpoints.WsChannelActor
}
diff --git a/example/websockets/app/src/Websockets.scala b/example/websockets/app/src/Websockets.scala
index a6ceb73..6cada0f 100644
--- a/example/websockets/app/src/Websockets.scala
+++ b/example/websockets/app/src/Websockets.scala
@@ -1,26 +1,15 @@
package app
-import io.undertow.websockets.WebSocketConnectionCallback
-import io.undertow.websockets.core.{AbstractReceiveListener, BufferedTextMessage, WebSocketChannel, WebSockets}
-import io.undertow.websockets.spi.WebSocketHttpExchange
-
+import concurrent.ExecutionContext.Implicits.global
object Websockets extends cask.MainRoutes{
@cask.websocket("/connect/:userName")
def showUserProfile(userName: String): cask.WebsocketResult = {
if (userName != "haoyi") cask.Response("", statusCode = 403)
- else new WebSocketConnectionCallback() {
- override def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = {
- channel.getReceiveSetter.set(
- new AbstractReceiveListener() {
- override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = {
- message.getData match{
- case "" => channel.close()
- case data => WebSockets.sendTextBlocking(userName + " " + data, channel)
- }
- }
- }
- )
- channel.resumeReceives()
+ else cask.WsHandler { channel =>
+ cask.WsActor {
+ case cask.WsActor.Text("") => channel.send(cask.WsActor.Close())
+ case cask.WsActor.Text(data) =>
+ channel.send(cask.WsActor.Text(userName + " " + data))
}
}
}
diff --git a/example/websockets2/app/src/Websockets2.scala b/example/websockets2/app/src/Websockets2.scala
new file mode 100644
index 0000000..b78bee3
--- /dev/null
+++ b/example/websockets2/app/src/Websockets2.scala
@@ -0,0 +1,29 @@
+package app
+
+import io.undertow.websockets.WebSocketConnectionCallback
+import io.undertow.websockets.core.{AbstractReceiveListener, BufferedTextMessage, WebSocketChannel, WebSockets}
+import io.undertow.websockets.spi.WebSocketHttpExchange
+
+object Websockets2 extends cask.MainRoutes{
+ @cask.websocket("/connect/:userName")
+ def showUserProfile(userName: String): cask.WebsocketResult = {
+ if (userName != "haoyi") cask.Response("", statusCode = 403)
+ else new WebSocketConnectionCallback() {
+ override def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = {
+ channel.getReceiveSetter.set(
+ new AbstractReceiveListener() {
+ override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = {
+ message.getData match{
+ case "" => channel.close()
+ case data => WebSockets.sendTextBlocking(userName + " " + data, channel)
+ }
+ }
+ }
+ )
+ channel.resumeReceives()
+ }
+ }
+ }
+
+ initialize()
+}
diff --git a/example/websockets2/app/test/src/ExampleTests.scala b/example/websockets2/app/test/src/ExampleTests.scala
new file mode 100644
index 0000000..27bff5e
--- /dev/null
+++ b/example/websockets2/app/test/src/ExampleTests.scala
@@ -0,0 +1,123 @@
+package app
+
+import java.util.concurrent.atomic.AtomicInteger
+
+import org.asynchttpclient.ws.{WebSocket, WebSocketListener, WebSocketUpgradeHandler}
+import utest._
+
+object ExampleTests extends TestSuite{
+
+
+ def withServer[T](example: cask.main.BaseMain)(f: String => T): T = {
+ val server = io.undertow.Undertow.builder
+ .addHttpListener(8080, "localhost")
+ .setHandler(example.defaultHandler)
+ .build
+ server.start()
+ val res =
+ try f("http://localhost:8080")
+ finally server.stop()
+ res
+ }
+
+ val tests = Tests{
+ test("Websockets") - withServer(Websockets2){ host =>
+ @volatile var out = List.empty[String]
+ val client = org.asynchttpclient.Dsl.asyncHttpClient();
+ try{
+
+ // 4. open websocket
+ val ws: WebSocket = client.prepareGet("ws://localhost:8080/connect/haoyi")
+ .execute(new WebSocketUpgradeHandler.Builder().addWebSocketListener(
+ new WebSocketListener() {
+
+ override def onTextFrame(payload: String, finalFragment: Boolean, rsv: Int) {
+ out = payload :: out
+ }
+
+ def onOpen(websocket: WebSocket) = ()
+
+ def onClose(websocket: WebSocket, code: Int, reason: String) = ()
+
+ def onError(t: Throwable) = ()
+ }).build()
+ ).get()
+
+ // 5. send messages
+ ws.sendTextFrame("hello")
+ ws.sendTextFrame("world")
+ ws.sendTextFrame("")
+ Thread.sleep(100)
+ out ==> List("haoyi world", "haoyi hello")
+
+ var error: String = ""
+ val cli2 = client.prepareGet("ws://localhost:8080/connect/nobody")
+ .execute(new WebSocketUpgradeHandler.Builder().addWebSocketListener(
+ new WebSocketListener() {
+
+ def onOpen(websocket: WebSocket) = ()
+
+ def onClose(websocket: WebSocket, code: Int, reason: String) = ()
+
+ def onError(t: Throwable) = {
+ error = t.toString
+ }
+ }).build()
+ ).get()
+
+ assert(error.contains("403"))
+
+ } finally{
+ client.close()
+ }
+ }
+
+ test("Websockets2000") - withServer(Websockets2){ host =>
+ @volatile var out = List.empty[String]
+ val closed = new AtomicInteger(0)
+ val client = org.asynchttpclient.Dsl.asyncHttpClient();
+ val ws = Seq.fill(2000)(client.prepareGet("ws://localhost:8080/connect/haoyi")
+ .execute(new WebSocketUpgradeHandler.Builder().addWebSocketListener(
+ new WebSocketListener() {
+
+ override def onTextFrame(payload: String, finalFragment: Boolean, rsv: Int) = {
+ ExampleTests.synchronized {
+ out = payload :: out
+ }
+ }
+
+ def onOpen(websocket: WebSocket) = ()
+
+ def onClose(websocket: WebSocket, code: Int, reason: String) = {
+ closed.incrementAndGet()
+ }
+
+ def onError(t: Throwable) = ()
+ }).build()
+ ).get())
+
+ try{
+ // 5. send messages
+ ws.foreach(_.sendTextFrame("hello"))
+
+ Thread.sleep(1500)
+ out.length ==> 2000
+
+ ws.foreach(_.sendTextFrame("world"))
+
+ Thread.sleep(1500)
+ out.length ==> 4000
+ closed.get() ==> 0
+
+ ws.foreach(_.sendTextFrame(""))
+
+ Thread.sleep(1500)
+ closed.get() ==> 2000
+
+ }finally{
+ client.close()
+ }
+ }
+
+ }
+}
diff --git a/example/websockets2/build.sc b/example/websockets2/build.sc
new file mode 100644
index 0000000..197e285
--- /dev/null
+++ b/example/websockets2/build.sc
@@ -0,0 +1,17 @@
+import mill._, scalalib._
+
+
+trait AppModule extends ScalaModule{
+ def scalaVersion = "2.13.0"
+ def ivyDeps = Agg[Dep](
+ )
+ object test extends Tests{
+ def testFrameworks = Seq("utest.runner.Framework")
+
+ def ivyDeps = Agg(
+ ivy"com.lihaoyi::utest::0.7.1",
+ ivy"com.lihaoyi::requests::0.2.0",
+ ivy"org.asynchttpclient:async-http-client:2.5.2"
+ )
+ }
+} \ No newline at end of file