From 3c59af848c37e1530876e95f7321c8757855d030 Mon Sep 17 00:00:00 2001 From: Jakob Odersky Date: Tue, 13 Nov 2018 21:08:51 -0800 Subject: Various enhancements - add select support for takes - add syntax sugar - add support for JS and Native --- shared/src/main/scala/escale/api.scala | 141 ++++++++++++++++++++++++++ shared/src/main/scala/escale/syntax.scala | 53 ++++++++++ shared/src/test/scala/escale/SelectTest.scala | 47 +++++++++ shared/src/test/scala/escale/SimpleTest.scala | 47 +++++++++ shared/src/test/scala/escale/SyntaxTest.scala | 44 ++++++++ 5 files changed, 332 insertions(+) create mode 100644 shared/src/main/scala/escale/api.scala create mode 100644 shared/src/main/scala/escale/syntax.scala create mode 100644 shared/src/test/scala/escale/SelectTest.scala create mode 100644 shared/src/test/scala/escale/SimpleTest.scala create mode 100644 shared/src/test/scala/escale/SyntaxTest.scala (limited to 'shared') diff --git a/shared/src/main/scala/escale/api.scala b/shared/src/main/scala/escale/api.scala new file mode 100644 index 0000000..ac475ae --- /dev/null +++ b/shared/src/main/scala/escale/api.scala @@ -0,0 +1,141 @@ +package escale + +import java.util.concurrent.atomic.AtomicBoolean +import scala.annotation.tailrec +import scala.collection.mutable +import scala.concurrent.{Future, Promise} + +class Channel[A](capacity: Int) { + require(capacity >= 0, "capacity must be >= 0") + import Channel._ + + private val puts = mutable.Queue.empty[(Handler[Unit], A)] + private val takes = mutable.Queue.empty[Handler[A]] + + private val buffer = mutable.Queue.empty[A] + + @tailrec final def put(handler: Handler[Unit], value: A): Unit = + synchronized { + if (takes.size > 0) { + val th = takes.dequeue() + val callback = th.commit() + if (th.active) { + + handler.commit()(()) + callback(value) + } else { + put(handler, value) + } + } else if (buffer.size < capacity) { + buffer.enqueue(value) + handler.commit()(()) + } else { + require(puts.size < MaxOps, "Too many pending put operations.") + puts.enqueue(handler -> value) + } + } + def put(value: A): Future[Unit] = { + val p = Promise[Unit] + put(new Handler[Unit](_ => p.success(())), value) + p.future + } + + def take(handler: Handler[A]): Unit = synchronized { + if (puts.size > 0) { + val callback = handler.commit() + if (handler.active) { + val (ph, pd) = puts.dequeue() + val data = if (capacity == 0) { + pd + } else { + val d = buffer.dequeue() + buffer.enqueue(pd) + d + } + ph.commit()(()) + callback(data) + } + } else if (buffer.isEmpty) { + require(takes.size < MaxOps, "Too many pending take operations") + takes.enqueue(handler) + } else { + val callback = handler.commit() + if (handler.active) { + callback(buffer.dequeue()) + } + } + } + def take(): Future[A] = { + val p = Promise[A] + take(new Handler[A](a => p.success(a))) + p.future + } + +} + +object Channel { + final val MaxOps = 1024 + + def apply[A](capacity: Int = 0): Channel[A] = new Channel[A](capacity) + + // TODO: this currently consumes a thread for every instance + def timeout(ms: Int): Channel[Unit] = { + val c = new Channel[Unit](0) + Future { + Thread.sleep(ms) + c.put(()) + }(scala.concurrent.ExecutionContext.global) + c + } + + //def select(ops: Op[_]*): Unit = ??? + + def select(channels: Channel[_]*): Future[(Channel[_], Any)] = { + val flag = new Flag + val result = Promise[(Channel[_], Any)] + for (ch <- channels) { + val handler = new SelectHandler[Any](flag, v => result.success((ch, v))) + ch.take(handler) + } + result.future + } + + type Op[A] = (Channel[A], A => Unit) + + def select2(reads: Op[_]*): Future[Unit] = { + val flag = new Flag + val done = Promise[Unit] + for ((ch, callback) <- reads) { + val c = callback.andThen { _ => + done.success(()) + () + } + val handler = new SelectHandler(flag, c) + ch.take(handler) + } + done.future + } + +} +class Handler[-A](callback: A => Unit) { + def active: Boolean = true + def commit(): A => Unit = callback +} + +class Flag { + val active = new AtomicBoolean(true) +} +class SelectHandler[A](flag: Flag, callback: A => Unit) + extends Handler[A](callback) { + var _active = true + override def active = _active + override def commit(): A => Unit = + if (flag.active.compareAndSet(true, false)) { + callback + } else { + _active = false + _ => + () + } + +} diff --git a/shared/src/main/scala/escale/syntax.scala b/shared/src/main/scala/escale/syntax.scala new file mode 100644 index 0000000..fc4b789 --- /dev/null +++ b/shared/src/main/scala/escale/syntax.scala @@ -0,0 +1,53 @@ +package escale + +import scala.concurrent.{ExecutionContext, Future} +import scala.language.experimental.macros + +object Macros { + import scala.reflect.macros.blackbox._ + + def goImpl[A: c.WeakTypeTag](c: Context)(body: c.Expr[A])( + execContext: c.Expr[ExecutionContext]): c.Tree = { + import c.universe._ + val pkg = c.mirror.staticPackage("scala.async") + q"""$pkg.Async.async($body)($execContext)""" + } + + def asyncTakeImpl[A: c.WeakTypeTag](c: Context)( + channel: c.Expr[Channel[A]]): c.Tree = { + import c.universe._ + val pkg = c.mirror.staticPackage("scala.async") + q"""$pkg.Async.await($channel.take())""" + } + + def asyncPutImpl[A: c.WeakTypeTag](c: Context)(value: c.Expr[A]): c.Tree = { + import c.universe._ + val pkg = c.mirror.staticPackage("scala.async") + q"""$pkg.Async.await(${c.prefix}.channel.put($value))""" + } + + def selectImpl(c: Context)(channels: c.Expr[Channel[_]]*): c.Tree = { + import c.universe._ + val pkg = c.mirror.staticPackage("scala.async") + val Channel = c.mirror.staticModule("escale.Channel") + q"""($pkg.Async.await($Channel.select(..$channels)): @unchecked)""" + } + +} + +package object syntax { + + def chan[A](capacity: Int = 0): Channel[A] = Channel[A](capacity) + + def go[A](body: => A)(implicit execContext: ExecutionContext): Future[A] = + macro Macros.goImpl[A] + + def !<[A](channel: Channel[A]): A = macro Macros.asyncTakeImpl[A] + + implicit class ChannelOps[A](val channel: Channel[A]) extends AnyVal { + def !<(value: A): Unit = macro Macros.asyncPutImpl[A] + } + + def select(channels: Channel[_]*): (Channel[_], Any) = macro Macros.selectImpl + +} diff --git a/shared/src/test/scala/escale/SelectTest.scala b/shared/src/test/scala/escale/SelectTest.scala new file mode 100644 index 0000000..3d41723 --- /dev/null +++ b/shared/src/test/scala/escale/SelectTest.scala @@ -0,0 +1,47 @@ +package escale + +import utest._ +import scala.async.Async._ +import scala.concurrent.ExecutionContext.Implicits.global +import syntax._ + +object SelectTest extends TestSuite { + val tests = Tests { + "select" - { + val ints = Channel[Int](0) + val strings = Channel[String](0) + val stop = Channel[Unit](0) + val cleaned = Channel[Int](10) + + val p0 = async { + var done = false + do { + (await(Channel.select(ints, strings, stop)): @unchecked) match { + case (`ints`, value: Int) => + cleaned !< value + case (`strings`, value: String) => + cleaned !< value.toInt + case (`stop`, _) => + done = true + } + } while (!done) + "done" + } + + val p1 = async{ + ints !< 2 + } + val p2 = async{ + strings !< "2" + ints !< 1 + } + val p3 = async{ + await(p1) + await(p2) + stop !< () + } + p0 + } + } + +} diff --git a/shared/src/test/scala/escale/SimpleTest.scala b/shared/src/test/scala/escale/SimpleTest.scala new file mode 100644 index 0000000..6f52b06 --- /dev/null +++ b/shared/src/test/scala/escale/SimpleTest.scala @@ -0,0 +1,47 @@ +package escale + +import scala.async.Async.{async, await} +import utest._ +import scala.concurrent.ExecutionContext.Implicits.global + +object SimpleTest extends TestSuite { + val tests = Tests{ + "put and take" - { + val ch = Channel[Int](0) + val p1 = async{ + await(ch.put(1)) + await(ch.put(2)) + await(ch.put(3)) + await(ch.put(4)) + } + async{ + await(ch.take()) + await(ch.take()) + await(ch.take()) + await(ch.take()) + } + p1 + } + "put and take while"- { + val ch = Channel[Int](0) + ch.put(1) + + val p1 = async { + await(ch.put(2)) + await(ch.put(3)) + await(ch.put(4)) + await(ch.put(5)) + } + + val p2 = async { + var sum = 0 + var a = 0 + while ({ a = await(ch.take()); a } < 5) { + sum += a + } + assert(sum == 10) + } + p2 + } + } +} \ No newline at end of file diff --git a/shared/src/test/scala/escale/SyntaxTest.scala b/shared/src/test/scala/escale/SyntaxTest.scala new file mode 100644 index 0000000..fffcbe9 --- /dev/null +++ b/shared/src/test/scala/escale/SyntaxTest.scala @@ -0,0 +1,44 @@ +package escale + +import utest._ +import scala.concurrent.ExecutionContext.Implicits.global +import escale.syntax._ +import scala.concurrent.Future + +object SyntaxTest extends TestSuite { + val tests = Tests { + "!< and !<" - { + val ch1 = chan[Int]() + val ch2 = chan[Int](1) + go { + ch1 !< 1 + ch1 !< 2 + ch1 !< 3 + } + go { + var sum = 0 + sum += !<(ch1) + sum += !<(ch1) + sum += !<(ch1) + ch2 !< 4 + sum += !<(ch2) + assert(sum == 10) + } + } + "select syntax" - { + def run(): Future[String] = go { + val Ch1 = chan[Int]() + val Ch2 = chan[Int]() + + go {/*Thread.sleep(1);*/ Ch1 !< 1} + go {/*Thread.sleep(1);*/ Ch2 !< 1} + + select(Ch1, Ch2) match { + case (Ch1, _) => "ch1 was first" + case (Ch2, _) => "ch2 was first" + } + } + run() + } + } +} -- cgit v1.2.3