diff options
-rw-r--r-- | build.sc | 3 | ||||
-rw-r--r-- | cask/src/cask/endpoints/WebSocketEndpoint.scala | 85 | ||||
-rw-r--r-- | cask/src/cask/internal/BatchActor.scala | 37 | ||||
-rw-r--r-- | cask/src/cask/main/Main.scala | 4 | ||||
-rw-r--r-- | cask/src/cask/package.scala | 5 | ||||
-rw-r--r-- | example/websockets/app/src/Websockets.scala | 23 | ||||
-rw-r--r-- | example/websockets2/app/src/Websockets2.scala | 29 | ||||
-rw-r--r-- | example/websockets2/app/test/src/ExampleTests.scala | 123 | ||||
-rw-r--r-- | example/websockets2/build.sc | 17 |
9 files changed, 306 insertions, 20 deletions
@@ -21,6 +21,7 @@ import $file.example.todoApi.build import $file.example.twirl.build import $file.example.variableRoutes.build import $file.example.websockets.build +import $file.example.websockets2.build object cask extends ScalaModule with PublishModule { def scalaVersion = "2.13.0" @@ -82,6 +83,7 @@ object example extends Module{ object twirl extends $file.example.twirl.build.AppModule with LocalModule object variableRoutes extends $file.example.variableRoutes.build.AppModule with LocalModule object websockets extends $file.example.websockets.build.AppModule with LocalModule + object websockets2 extends $file.example.websockets2.build.AppModule with LocalModule } def publishVersion = T.input($file.ci.version.publishVersion) @@ -124,6 +126,7 @@ def uploadToGithub(authKey: String) = T.command{ $file.example.twirl.build.millSourcePath, $file.example.variableRoutes.build.millSourcePath, $file.example.websockets.build.millSourcePath, + $file.example.websockets2.build.millSourcePath, ) for(example <- examples){ val f = T.ctx().dest diff --git a/cask/src/cask/endpoints/WebSocketEndpoint.scala b/cask/src/cask/endpoints/WebSocketEndpoint.scala index 842d508..6728581 100644 --- a/cask/src/cask/endpoints/WebSocketEndpoint.scala +++ b/cask/src/cask/endpoints/WebSocketEndpoint.scala @@ -1,9 +1,15 @@ package cask.endpoints -import cask.internal.Router +import java.nio.ByteBuffer + +import cask.internal.{BatchActor, Router} import cask.model.Request import io.undertow.websockets.WebSocketConnectionCallback -import collection.JavaConverters._ +import io.undertow.websockets.core.{AbstractReceiveListener, BufferedBinaryMessage, BufferedTextMessage, CloseMessage, WebSocketChannel, WebSockets} +import io.undertow.websockets.spi.WebSocketHttpExchange + +import scala.concurrent.ExecutionContext + sealed trait WebsocketResult object WebsocketResult{ implicit class Response[T](value0: cask.model.Response[T]) @@ -24,3 +30,78 @@ class websocket(val path: String, override val subpath: Boolean = false) def wrapPathSegment(s: String): Seq[String] = Seq(s) } + +case class WsHandler(f: WsChannelActor => BatchActor[WsActor.Event])(implicit ec: ExecutionContext) +extends WebsocketResult with WebSocketConnectionCallback { + def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = { + val actor = f(new WsChannelActor(channel)) + channel.getReceiveSetter.set( + new AbstractReceiveListener() { + override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = { + actor.send(WsActor.Text(message.getData)) + } + + override def onFullBinaryMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = { + actor.send(WsActor.Binary( + WebSockets.mergeBuffers(message.getData.getResource:_*).array() + )) + } + + override def onFullPingMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = { + actor.send(WsActor.Ping( + WebSockets.mergeBuffers(message.getData.getResource:_*).array() + )) + } + override def onFullPongMessage(channel: WebSocketChannel, message: BufferedBinaryMessage): Unit = { + actor.send(WsActor.Pong( + WebSockets.mergeBuffers(message.getData.getResource:_*).array() + )) + } + + override def onCloseMessage(cm: CloseMessage, channel: WebSocketChannel) = { + actor.send(WsActor.Close(cm.getCode, cm.getReason)) + } + } + ) + channel.resumeReceives() + } +} + +class WsChannelActor(channel: WebSocketChannel)(implicit ec: ExecutionContext) +extends BatchActor[WsActor.Event]{ + def run(items: Seq[WsActor.Event]): Unit = items.foreach{ + case WsActor.Text(value) => WebSockets.sendTextBlocking(value, channel) + case WsActor.Binary(value) => WebSockets.sendBinaryBlocking(ByteBuffer.wrap(value), channel) + case WsActor.Ping(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel) + case WsActor.Pong(value) => WebSockets.sendPingBlocking(ByteBuffer.wrap(value), channel) + case WsActor.Close(code, reason) => WebSockets.sendCloseBlocking(code, reason, channel) + } +} + +case class WsActor(handle: PartialFunction[WsActor.Event, Unit]) + (implicit ec: ExecutionContext) +extends BatchActor[WsActor.Event]{ + def run(items: Seq[WsActor.Event]): Unit = { + items.foreach(handle.applyOrElse(_, (x: WsActor.Event) => ())) + } +} + +object WsActor{ + trait Event + case class Text(value: String) extends Event + case class Binary(value: Array[Byte]) extends Event + case class Ping(value: Array[Byte] = Array.empty[Byte]) extends Event + case class Pong(value: Array[Byte] = Array.empty[Byte]) extends Event + case class Close(code: Int = Close.NormalClosure, reason: String = "") extends Event + object Close{ + val NormalClosure = CloseMessage.NORMAL_CLOSURE + val GoingAway = CloseMessage.GOING_AWAY + val WrongCode = CloseMessage.WRONG_CODE + val ProtocolError = CloseMessage.PROTOCOL_ERROR + val MsgContainsInvalidData = CloseMessage.MSG_CONTAINS_INVALID_DATA + val MsgViolatesPolicy = CloseMessage.MSG_VIOLATES_POLICY + val MsgTooBig = CloseMessage.MSG_TOO_BIG + val MissingExtensions = CloseMessage.MISSING_EXTENSIONS + val UnexpectedError = CloseMessage.UNEXPECTED_ERROR + } +} diff --git a/cask/src/cask/internal/BatchActor.scala b/cask/src/cask/internal/BatchActor.scala new file mode 100644 index 0000000..1566a18 --- /dev/null +++ b/cask/src/cask/internal/BatchActor.scala @@ -0,0 +1,37 @@ +package cask.internal + +import scala.collection.mutable +import scala.concurrent.ExecutionContext + +/** + * A simple asynchrous actor, allowing safe concurrent asynchronous processing + * of queued items. `run` handles items in batches, to allow for batch + * processing optimizations to be used where relevant. + */ +abstract class BatchActor[T]()(implicit ec: ExecutionContext) { + def run(items: Seq[T]): Unit + + private val queue = new mutable.Queue[T]() + private var scheduled = false + def send(t: => T): Unit = synchronized{ + queue.enqueue(t) + if (!scheduled){ + scheduled = true + ec.execute(() => runWithItems()) + } + } + + def runWithItems(): Unit = { + val items = synchronized(queue.dequeueAll(_ => true)) + try run(items) + catch{case e: Throwable => e.printStackTrace()} + synchronized{ + if (queue.nonEmpty) ec.execute(() => runWithItems()) + else{ + assert(scheduled) + scheduled = false + } + } + + } +} diff --git a/cask/src/cask/main/Main.scala b/cask/src/cask/main/Main.scala index 32e6517..a138e8a 100644 --- a/cask/src/cask/main/Main.scala +++ b/cask/src/cask/main/Main.scala @@ -1,6 +1,6 @@ package cask.main -import cask.endpoints.WebsocketResult +import cask.endpoints.{WebsocketResult, WsHandler} import cask.model._ import cask.internal.Router.EntryPoint import cask.internal.{DispatchTrie, Router, Util} @@ -73,6 +73,8 @@ abstract class BaseMain{ val (effectiveMethod, runner) = if (exchange.getRequestHeaders.getFirst("Upgrade") == "websocket") { "websocket" -> ((r: Any) => r.asInstanceOf[WebsocketResult] match{ + case l: WsHandler => + io.undertow.Handlers.websocket(l).handleRequest(exchange) case l: WebsocketResult.Listener => io.undertow.Handlers.websocket(l.value).handleRequest(exchange) case r: WebsocketResult.Response[_] => diff --git a/cask/src/cask/package.scala b/cask/src/cask/package.scala index d17cfe6..f55d33d 100644 --- a/cask/src/cask/package.scala +++ b/cask/src/cask/package.scala @@ -40,4 +40,9 @@ package object cask { type RawDecorator = main.RawDecorator type HttpEndpoint[InnerReturned, Input] = main.HttpEndpoint[InnerReturned, Input] + type WsHandler = cask.endpoints.WsHandler + val WsHandler = cask.endpoints.WsHandler + type WsActor = cask.endpoints.WsActor + val WsActor = cask.endpoints.WsActor + type WsChannelActor = cask.endpoints.WsChannelActor } diff --git a/example/websockets/app/src/Websockets.scala b/example/websockets/app/src/Websockets.scala index a6ceb73..6cada0f 100644 --- a/example/websockets/app/src/Websockets.scala +++ b/example/websockets/app/src/Websockets.scala @@ -1,26 +1,15 @@ package app -import io.undertow.websockets.WebSocketConnectionCallback -import io.undertow.websockets.core.{AbstractReceiveListener, BufferedTextMessage, WebSocketChannel, WebSockets} -import io.undertow.websockets.spi.WebSocketHttpExchange - +import concurrent.ExecutionContext.Implicits.global object Websockets extends cask.MainRoutes{ @cask.websocket("/connect/:userName") def showUserProfile(userName: String): cask.WebsocketResult = { if (userName != "haoyi") cask.Response("", statusCode = 403) - else new WebSocketConnectionCallback() { - override def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = { - channel.getReceiveSetter.set( - new AbstractReceiveListener() { - override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = { - message.getData match{ - case "" => channel.close() - case data => WebSockets.sendTextBlocking(userName + " " + data, channel) - } - } - } - ) - channel.resumeReceives() + else cask.WsHandler { channel => + cask.WsActor { + case cask.WsActor.Text("") => channel.send(cask.WsActor.Close()) + case cask.WsActor.Text(data) => + channel.send(cask.WsActor.Text(userName + " " + data)) } } } diff --git a/example/websockets2/app/src/Websockets2.scala b/example/websockets2/app/src/Websockets2.scala new file mode 100644 index 0000000..b78bee3 --- /dev/null +++ b/example/websockets2/app/src/Websockets2.scala @@ -0,0 +1,29 @@ +package app + +import io.undertow.websockets.WebSocketConnectionCallback +import io.undertow.websockets.core.{AbstractReceiveListener, BufferedTextMessage, WebSocketChannel, WebSockets} +import io.undertow.websockets.spi.WebSocketHttpExchange + +object Websockets2 extends cask.MainRoutes{ + @cask.websocket("/connect/:userName") + def showUserProfile(userName: String): cask.WebsocketResult = { + if (userName != "haoyi") cask.Response("", statusCode = 403) + else new WebSocketConnectionCallback() { + override def onConnect(exchange: WebSocketHttpExchange, channel: WebSocketChannel): Unit = { + channel.getReceiveSetter.set( + new AbstractReceiveListener() { + override def onFullTextMessage(channel: WebSocketChannel, message: BufferedTextMessage) = { + message.getData match{ + case "" => channel.close() + case data => WebSockets.sendTextBlocking(userName + " " + data, channel) + } + } + } + ) + channel.resumeReceives() + } + } + } + + initialize() +} diff --git a/example/websockets2/app/test/src/ExampleTests.scala b/example/websockets2/app/test/src/ExampleTests.scala new file mode 100644 index 0000000..27bff5e --- /dev/null +++ b/example/websockets2/app/test/src/ExampleTests.scala @@ -0,0 +1,123 @@ +package app + +import java.util.concurrent.atomic.AtomicInteger + +import org.asynchttpclient.ws.{WebSocket, WebSocketListener, WebSocketUpgradeHandler} +import utest._ + +object ExampleTests extends TestSuite{ + + + def withServer[T](example: cask.main.BaseMain)(f: String => T): T = { + val server = io.undertow.Undertow.builder + .addHttpListener(8080, "localhost") + .setHandler(example.defaultHandler) + .build + server.start() + val res = + try f("http://localhost:8080") + finally server.stop() + res + } + + val tests = Tests{ + test("Websockets") - withServer(Websockets2){ host => + @volatile var out = List.empty[String] + val client = org.asynchttpclient.Dsl.asyncHttpClient(); + try{ + + // 4. open websocket + val ws: WebSocket = client.prepareGet("ws://localhost:8080/connect/haoyi") + .execute(new WebSocketUpgradeHandler.Builder().addWebSocketListener( + new WebSocketListener() { + + override def onTextFrame(payload: String, finalFragment: Boolean, rsv: Int) { + out = payload :: out + } + + def onOpen(websocket: WebSocket) = () + + def onClose(websocket: WebSocket, code: Int, reason: String) = () + + def onError(t: Throwable) = () + }).build() + ).get() + + // 5. send messages + ws.sendTextFrame("hello") + ws.sendTextFrame("world") + ws.sendTextFrame("") + Thread.sleep(100) + out ==> List("haoyi world", "haoyi hello") + + var error: String = "" + val cli2 = client.prepareGet("ws://localhost:8080/connect/nobody") + .execute(new WebSocketUpgradeHandler.Builder().addWebSocketListener( + new WebSocketListener() { + + def onOpen(websocket: WebSocket) = () + + def onClose(websocket: WebSocket, code: Int, reason: String) = () + + def onError(t: Throwable) = { + error = t.toString + } + }).build() + ).get() + + assert(error.contains("403")) + + } finally{ + client.close() + } + } + + test("Websockets2000") - withServer(Websockets2){ host => + @volatile var out = List.empty[String] + val closed = new AtomicInteger(0) + val client = org.asynchttpclient.Dsl.asyncHttpClient(); + val ws = Seq.fill(2000)(client.prepareGet("ws://localhost:8080/connect/haoyi") + .execute(new WebSocketUpgradeHandler.Builder().addWebSocketListener( + new WebSocketListener() { + + override def onTextFrame(payload: String, finalFragment: Boolean, rsv: Int) = { + ExampleTests.synchronized { + out = payload :: out + } + } + + def onOpen(websocket: WebSocket) = () + + def onClose(websocket: WebSocket, code: Int, reason: String) = { + closed.incrementAndGet() + } + + def onError(t: Throwable) = () + }).build() + ).get()) + + try{ + // 5. send messages + ws.foreach(_.sendTextFrame("hello")) + + Thread.sleep(1500) + out.length ==> 2000 + + ws.foreach(_.sendTextFrame("world")) + + Thread.sleep(1500) + out.length ==> 4000 + closed.get() ==> 0 + + ws.foreach(_.sendTextFrame("")) + + Thread.sleep(1500) + closed.get() ==> 2000 + + }finally{ + client.close() + } + } + + } +} diff --git a/example/websockets2/build.sc b/example/websockets2/build.sc new file mode 100644 index 0000000..197e285 --- /dev/null +++ b/example/websockets2/build.sc @@ -0,0 +1,17 @@ +import mill._, scalalib._ + + +trait AppModule extends ScalaModule{ + def scalaVersion = "2.13.0" + def ivyDeps = Agg[Dep]( + ) + object test extends Tests{ + def testFrameworks = Seq("utest.runner.Framework") + + def ivyDeps = Agg( + ivy"com.lihaoyi::utest::0.7.1", + ivy"com.lihaoyi::requests::0.2.0", + ivy"org.asynchttpclient:async-http-client:2.5.2" + ) + } +}
\ No newline at end of file |