From 3c59af848c37e1530876e95f7321c8757855d030 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Tue, 13 Nov 2018 21:08:51 -0800 Subject: Various enhancements - add select support for takes - add syntax sugar - add support for JS and Native --- build.sbt | 27 ++++- example/src/main/scala/example/main.scala | 79 +++++++++++++-- project/plugins.sbt | 5 + shared/src/main/scala/escale/api.scala | 141 ++++++++++++++++++++++++++ shared/src/main/scala/escale/syntax.scala | 53 ++++++++++ shared/src/test/scala/escale/SelectTest.scala | 47 +++++++++ shared/src/test/scala/escale/SimpleTest.scala | 47 +++++++++ shared/src/test/scala/escale/SyntaxTest.scala | 44 ++++++++ src/main/scala/escale/api.scala | 73 ------------- 9 files changed, 431 insertions(+), 85 deletions(-) create mode 100644 shared/src/main/scala/escale/api.scala create mode 100644 shared/src/main/scala/escale/syntax.scala create mode 100644 shared/src/test/scala/escale/SelectTest.scala create mode 100644 shared/src/test/scala/escale/SimpleTest.scala create mode 100644 shared/src/test/scala/escale/SyntaxTest.scala delete mode 100644 src/main/scala/escale/api.scala diff --git a/build.sbt b/build.sbt index b7d8353..925044b 100644 --- a/build.sbt +++ b/build.sbt @@ -1,16 +1,33 @@ -lazy val escale = project +// shadow sbt-scalajs' crossProject and CrossType from Scala.js 0.6.x +import sbtcrossproject.CrossPlugin.autoImport.{CrossType, crossProject} + +lazy val escale = crossProject(JSPlatform, JVMPlatform, NativePlatform) + .crossType(CrossType.Full) .in(file(".")) .settings( - scalaVersion := "2.12.7", libraryDependencies ++= Seq( "org.scala-lang" % "scala-reflect" % scalaVersion.value, - "org.scala-lang.modules" %% "scala-async" % "0.9.7" - ) + "org.scala-lang.modules" %% "scala-async" % "0.9.7", + "com.lihaoyi" %%% "utest" % "0.6.6" % "test" + ), + testFrameworks += new TestFramework("utest.runner.Framework"), + scalaVersion := crossScalaVersions.value.head + ) + .jsSettings( + crossScalaVersions := "2.12.6" :: "2.11.12" :: Nil + ) + .jvmSettings( + crossScalaVersions := "2.12.7" :: "2.11.12" :: Nil ) + .nativeSettings( + crossScalaVersions := "2.11.12" :: Nil, + nativeLinkStubs := true + ) + lazy val example = project .in(file("example")) - .dependsOn(escale) + .dependsOn(escale.jvm) .settings( scalaVersion := "2.12.7" ) diff --git a/example/src/main/scala/example/main.scala b/example/src/main/scala/example/main.scala index f6d0a48..a8470a2 100644 --- a/example/src/main/scala/example/main.scala +++ b/example/src/main/scala/example/main.scala @@ -1,6 +1,7 @@ package example import escale.Channel +import scala.async.Async import scala.async.Async._ import scala.concurrent.Await import scala.concurrent.ExecutionContext.Implicits.global @@ -8,8 +9,23 @@ import scala.concurrent.duration._ object Main extends App { + //val t = Channel.timeout(300) + //Await.result(t.take(), 10.seconds) + val ch = Channel[Int](0) +// Channel.select( +// ch -> {(x: Int) => println("a")}, +// ch2 -> {(x: String) => println("a")} +// ) + + val p2 = async { + var a = 0 + while ({a = await(ch.take()); a} < 5) { + println(a) + } + } + val p1 = async { await(ch.put(1)) await(ch.put(2)) @@ -17,14 +33,63 @@ object Main extends App { await(ch.put(5)) } - val p2 = async { - await(ch.take()) - await(ch.take()) - await(ch.take()) - await(ch.take()) - } - val result = Await.result(p2, 3.seconds) println(result) } + +object SelectTest extends App { + + val ch = Channel[Int](0) + val t = Channel.timeout(100) + ch.put(2) + + val out = Channel[String](1) +// +// Await.result(Channel.select(ch, t), 10.seconds) match { +// case (`t`, _) => println("timeout") +// case (`ch`, value) => println(value) +// } + + val r = async { + await(Channel.select(ch, t)) match { + case (`t`, _) => println("timeout") + case (`ch`, value: Int) => await(out.put(value.toString)), + } + await(out.take()) + } + Await.result(r, 10.seconds) + println(r) + +} + + +object Select2Test extends App { + + val ch = Channel[Int](0) + val t = Channel.timeout(100) + ch.put(2) + + val out = Channel[String](0) + // + // Await.result(Channel.select(ch, t), 10.seconds) match { + // case (`t`, _) => println("timeout") + // case (`ch`, value) => println(value) + // } + + Channel.select2( + t -> {u: Unit => println("timeout")}, + ch -> {v: Int => println(v); out.put(v.toString); ()} + ) + + val r = async { + await(Channel.select2( + t -> {u: Unit => println("timeout")}, + out -> {s: String => println(s)} + )) + } + Await.result(r, 10.seconds) + println(r) + + +} diff --git a/project/plugins.sbt b/project/plugins.sbt index f86e373..79192b7 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -1 +1,6 @@ +addSbtPlugin("org.portable-scala" % "sbt-scalajs-crossproject" % "0.6.0") +addSbtPlugin("org.portable-scala" % "sbt-scala-native-crossproject" % "0.6.0") +addSbtPlugin("org.scala-js" % "sbt-scalajs" % "0.6.23") +addSbtPlugin("org.scala-native" % "sbt-scala-native" % "0.3.7") + addSbtPlugin("com.geirsson" % "sbt-scalafmt" % "1.5.1") diff --git a/shared/src/main/scala/escale/api.scala b/shared/src/main/scala/escale/api.scala new file mode 100644 index 0000000..ac475ae --- /dev/null +++ b/shared/src/main/scala/escale/api.scala @@ -0,0 +1,141 @@ +package escale + +import java.util.concurrent.atomic.AtomicBoolean +import scala.annotation.tailrec +import scala.collection.mutable +import scala.concurrent.{Future, Promise} + +class Channel[A](capacity: Int) { + require(capacity >= 0, "capacity must be >= 0") + import Channel._ + + private val puts = mutable.Queue.empty[(Handler[Unit], A)] + private val takes = mutable.Queue.empty[Handler[A]] + + private val buffer = mutable.Queue.empty[A] + + @tailrec final def put(handler: Handler[Unit], value: A): Unit = + synchronized { + if (takes.size > 0) { + val th = takes.dequeue() + val callback = th.commit() + if (th.active) { + + handler.commit()(()) + callback(value) + } else { + put(handler, value) + } + } else if (buffer.size < capacity) { + buffer.enqueue(value) + handler.commit()(()) + } else { + require(puts.size < MaxOps, "Too many pending put operations.") + puts.enqueue(handler -> value) + } + } + def put(value: A): Future[Unit] = { + val p = Promise[Unit] + put(new Handler[Unit](_ => p.success(())), value) + p.future + } + + def take(handler: Handler[A]): Unit = synchronized { + if (puts.size > 0) { + val callback = handler.commit() + if (handler.active) { + val (ph, pd) = puts.dequeue() + val data = if (capacity == 0) { + pd + } else { + val d = buffer.dequeue() + buffer.enqueue(pd) + d + } + ph.commit()(()) + callback(data) + } + } else if (buffer.isEmpty) { + require(takes.size < MaxOps, "Too many pending take operations") + takes.enqueue(handler) + } else { + val callback = handler.commit() + if (handler.active) { + callback(buffer.dequeue()) + } + } + } + def take(): Future[A] = { + val p = Promise[A] + take(new Handler[A](a => p.success(a))) + p.future + } + +} + +object Channel { + final val MaxOps = 1024 + + def apply[A](capacity: Int = 0): Channel[A] = new Channel[A](capacity) + + // TODO: this currently consumes a thread for every instance + def timeout(ms: Int): Channel[Unit] = { + val c = new Channel[Unit](0) + Future { + Thread.sleep(ms) + c.put(()) + }(scala.concurrent.ExecutionContext.global) + c + } + + //def select(ops: Op[_]*): Unit = ??? + + def select(channels: Channel[_]*): Future[(Channel[_], Any)] = { + val flag = new Flag + val result = Promise[(Channel[_], Any)] + for (ch <- channels) { + val handler = new SelectHandler[Any](flag, v => result.success((ch, v))) + ch.take(handler) + } + result.future + } + + type Op[A] = (Channel[A], A => Unit) + + def select2(reads: Op[_]*): Future[Unit] = { + val flag = new Flag + val done = Promise[Unit] + for ((ch, callback) <- reads) { + val c = callback.andThen { _ => + done.success(()) + () + } + val handler = new SelectHandler(flag, c) + ch.take(handler) + } + done.future + } + +} +class Handler[-A](callback: A => Unit) { + def active: Boolean = true + def commit(): A => Unit = callback +} + +class Flag { + val active = new AtomicBoolean(true) +} +class SelectHandler[A](flag: Flag, callback: A => Unit) + extends Handler[A](callback) { + var _active = true + override def active = _active + override def commit(): A => Unit = + if (flag.active.compareAndSet(true, false)) { + callback + } else { + _active = false + _ => + () + } + +} diff --git a/shared/src/main/scala/escale/syntax.scala b/shared/src/main/scala/escale/syntax.scala new file mode 100644 index 0000000..fc4b789 --- /dev/null +++ b/shared/src/main/scala/escale/syntax.scala @@ -0,0 +1,53 @@ +package escale + +import scala.concurrent.{ExecutionContext, Future} +import scala.language.experimental.macros + +object Macros { + import scala.reflect.macros.blackbox._ + + def goImpl[A: c.WeakTypeTag](c: Context)(body: c.Expr[A])( + execContext: c.Expr[ExecutionContext]): c.Tree = { + import c.universe._ + val pkg = c.mirror.staticPackage("scala.async") + q"""$pkg.Async.async($body)($execContext)""" + } + + def asyncTakeImpl[A: c.WeakTypeTag](c: Context)( + channel: c.Expr[Channel[A]]): c.Tree = { + import c.universe._ + val pkg = c.mirror.staticPackage("scala.async") + q"""$pkg.Async.await($channel.take())""" + } + + def asyncPutImpl[A: c.WeakTypeTag](c: Context)(value: c.Expr[A]): c.Tree = { + import c.universe._ + val pkg = c.mirror.staticPackage("scala.async") + q"""$pkg.Async.await(${c.prefix}.channel.put($value))""" + } + + def selectImpl(c: Context)(channels: c.Expr[Channel[_]]*): c.Tree = { + import c.universe._ + val pkg = c.mirror.staticPackage("scala.async") + val Channel = c.mirror.staticModule("escale.Channel") + q"""($pkg.Async.await($Channel.select(..$channels)): @unchecked)""" + } + +} + +package object syntax { + + def chan[A](capacity: Int = 0): Channel[A] = Channel[A](capacity) + + def go[A](body: => A)(implicit execContext: ExecutionContext): Future[A] = + macro Macros.goImpl[A] + + def !<[A](channel: Channel[A]): A = macro Macros.asyncTakeImpl[A] + + implicit class ChannelOps[A](val channel: Channel[A]) extends AnyVal { + def !<(value: A): Unit = macro Macros.asyncPutImpl[A] + } + + def select(channels: Channel[_]*): (Channel[_], Any) = macro Macros.selectImpl + +} diff --git a/shared/src/test/scala/escale/SelectTest.scala b/shared/src/test/scala/escale/SelectTest.scala new file mode 100644 index 0000000..3d41723 --- /dev/null +++ b/shared/src/test/scala/escale/SelectTest.scala @@ -0,0 +1,47 @@ +package escale + +import utest._ +import scala.async.Async._ +import scala.concurrent.ExecutionContext.Implicits.global +import syntax._ + +object SelectTest extends TestSuite { + val tests = Tests { + "select" - { + val ints = Channel[Int](0) + val strings = Channel[String](0) + val stop = Channel[Unit](0) + val cleaned = Channel[Int](10) + + val p0 = async { + var done = false + do { + (await(Channel.select(ints, strings, stop)): @unchecked) match { + case (`ints`, value: Int) => + cleaned !< value + case (`strings`, value: String) => + cleaned !< value.toInt + case (`stop`, _) => + done = true + } + } while (!done) + "done" + } + + val p1 = async{ + ints !< 2 + } + val p2 = async{ + strings !< "2" + ints !< 1 + } + val p3 = async{ + await(p1) + await(p2) + stop !< () + } + p0 + } + } + +} diff --git a/shared/src/test/scala/escale/SimpleTest.scala b/shared/src/test/scala/escale/SimpleTest.scala new file mode 100644 index 0000000..6f52b06 --- /dev/null +++ b/shared/src/test/scala/escale/SimpleTest.scala @@ -0,0 +1,47 @@ +package escale + +import scala.async.Async.{async, await} +import utest._ +import scala.concurrent.ExecutionContext.Implicits.global + +object SimpleTest extends TestSuite { + val tests = Tests{ + "put and take" - { + val ch = Channel[Int](0) + val p1 = async{ + await(ch.put(1)) + await(ch.put(2)) + await(ch.put(3)) + await(ch.put(4)) + } + async{ + await(ch.take()) + await(ch.take()) + await(ch.take()) + await(ch.take()) + } + p1 + } + "put and take while"- { + val ch = Channel[Int](0) + ch.put(1) + + val p1 = async { + await(ch.put(2)) + await(ch.put(3)) + await(ch.put(4)) + await(ch.put(5)) + } + + val p2 = async { + var sum = 0 + var a = 0 + while ({ a = await(ch.take()); a } < 5) { + sum += a + } + assert(sum == 10) + } + p2 + } + } +} \ No newline at end of file diff --git a/shared/src/test/scala/escale/SyntaxTest.scala b/shared/src/test/scala/escale/SyntaxTest.scala new file mode 100644 index 0000000..fffcbe9 --- /dev/null +++ b/shared/src/test/scala/escale/SyntaxTest.scala @@ -0,0 +1,44 @@ +package escale + +import utest._ +import scala.concurrent.ExecutionContext.Implicits.global +import escale.syntax._ +import scala.concurrent.Future + +object SyntaxTest extends TestSuite { + val tests = Tests { + "!< and !<" - { + val ch1 = chan[Int]() + val ch2 = chan[Int](1) + go { + ch1 !< 1 + ch1 !< 2 + ch1 !< 3 + } + go { + var sum = 0 + sum += !<(ch1) + sum += !<(ch1) + sum += !<(ch1) + ch2 !< 4 + sum += !<(ch2) + assert(sum == 10) + } + } + "select syntax" - { + def run(): Future[String] = go { + val Ch1 = chan[Int]() + val Ch2 = chan[Int]() + + go {/*Thread.sleep(1);*/ Ch1 !< 1} + go {/*Thread.sleep(1);*/ Ch2 !< 1} + + select(Ch1, Ch2) match { + case (Ch1, _) => "ch1 was first" + case (Ch2, _) => "ch2 was first" + } + } + run() + } + } +} diff --git a/src/main/scala/escale/api.scala b/src/main/scala/escale/api.scala deleted file mode 100644 index 57cb31d..0000000 --- a/src/main/scala/escale/api.scala +++ /dev/null @@ -1,73 +0,0 @@ -package escale - -import scala.collection.mutable -import scala.concurrent.{Future, Promise} - -class Channel[A](capacity: Int) { - require(capacity >= 0, "capacity must be >= 0") - import Channel._ - - private val puts = mutable.Queue.empty[(Handler[Unit], A)] - private val takes = mutable.Queue.empty[Handler[A]] - - private val buffer = mutable.Queue.empty[A] - - def put(value: A): Future[Unit] = synchronized { - val handler = new Handler[Unit] - - if (takes.size > 0) { - val th = takes.dequeue() - th.promise.success(value) - handler.promise.success(()) - } else if (buffer.size < capacity) { - buffer.enqueue(value) - handler.promise.success(()) - } else { - if (puts.size >= MaxOps) { - handler.promise.failure( - new IllegalArgumentException("Too many pending put operations.")) - } else { - puts.enqueue(handler -> value) - } - } - handler.promise.future - } - def take(): Future[A] = synchronized { - val handler = new Handler[A] - - if (puts.size > 0) { - val (ph, pd) = puts.dequeue() - val data = if (capacity == 0) { - pd - } else { - val d = buffer.dequeue() - buffer.enqueue(pd) - d - } - ph.promise.success(()) - handler.promise.success(data) - } else if (buffer.isEmpty) { - if (takes.size >= MaxOps) { - handler.promise.failure( - new IllegalArgumentException("Too many pending take operations.")) - } else { - takes.enqueue(handler) - } - } else { - handler.promise.success(buffer.dequeue()) - } - handler.promise.future - } - -} - -object Channel { - final val MaxOps = 2 - - def apply[A](capacity: Int): Channel[A] = new Channel[A](capacity) - -} - -class Handler[A] { - val promise = Promise[A] -} -- cgit v1.2.3