summaryrefslogtreecommitdiff
path: root/shared
diff options
context:
space:
mode:
Diffstat (limited to 'shared')
-rw-r--r--shared/src/main/scala/escale/api.scala141
-rw-r--r--shared/src/main/scala/escale/syntax.scala53
-rw-r--r--shared/src/test/scala/escale/SelectTest.scala47
-rw-r--r--shared/src/test/scala/escale/SimpleTest.scala47
-rw-r--r--shared/src/test/scala/escale/SyntaxTest.scala44
5 files changed, 332 insertions, 0 deletions
diff --git a/shared/src/main/scala/escale/api.scala b/shared/src/main/scala/escale/api.scala
new file mode 100644
index 0000000..ac475ae
--- /dev/null
+++ b/shared/src/main/scala/escale/api.scala
@@ -0,0 +1,141 @@
+package escale
+
+import java.util.concurrent.atomic.AtomicBoolean
+import scala.annotation.tailrec
+import scala.collection.mutable
+import scala.concurrent.{Future, Promise}
+
+class Channel[A](capacity: Int) {
+ require(capacity >= 0, "capacity must be >= 0")
+ import Channel._
+
+ private val puts = mutable.Queue.empty[(Handler[Unit], A)]
+ private val takes = mutable.Queue.empty[Handler[A]]
+
+ private val buffer = mutable.Queue.empty[A]
+
+ @tailrec final def put(handler: Handler[Unit], value: A): Unit =
+ synchronized {
+ if (takes.size > 0) {
+ val th = takes.dequeue()
+ val callback = th.commit()
+ if (th.active) {
+
+ handler.commit()(())
+ callback(value)
+ } else {
+ put(handler, value)
+ }
+ } else if (buffer.size < capacity) {
+ buffer.enqueue(value)
+ handler.commit()(())
+ } else {
+ require(puts.size < MaxOps, "Too many pending put operations.")
+ puts.enqueue(handler -> value)
+ }
+ }
+ def put(value: A): Future[Unit] = {
+ val p = Promise[Unit]
+ put(new Handler[Unit](_ => p.success(())), value)
+ p.future
+ }
+
+ def take(handler: Handler[A]): Unit = synchronized {
+ if (puts.size > 0) {
+ val callback = handler.commit()
+ if (handler.active) {
+ val (ph, pd) = puts.dequeue()
+ val data = if (capacity == 0) {
+ pd
+ } else {
+ val d = buffer.dequeue()
+ buffer.enqueue(pd)
+ d
+ }
+ ph.commit()(())
+ callback(data)
+ }
+ } else if (buffer.isEmpty) {
+ require(takes.size < MaxOps, "Too many pending take operations")
+ takes.enqueue(handler)
+ } else {
+ val callback = handler.commit()
+ if (handler.active) {
+ callback(buffer.dequeue())
+ }
+ }
+ }
+ def take(): Future[A] = {
+ val p = Promise[A]
+ take(new Handler[A](a => p.success(a)))
+ p.future
+ }
+
+}
+
+object Channel {
+ final val MaxOps = 1024
+
+ def apply[A](capacity: Int = 0): Channel[A] = new Channel[A](capacity)
+
+ // TODO: this currently consumes a thread for every instance
+ def timeout(ms: Int): Channel[Unit] = {
+ val c = new Channel[Unit](0)
+ Future {
+ Thread.sleep(ms)
+ c.put(())
+ }(scala.concurrent.ExecutionContext.global)
+ c
+ }
+
+ //def select(ops: Op[_]*): Unit = ???
+
+ def select(channels: Channel[_]*): Future[(Channel[_], Any)] = {
+ val flag = new Flag
+ val result = Promise[(Channel[_], Any)]
+ for (ch <- channels) {
+ val handler = new SelectHandler[Any](flag, v => result.success((ch, v)))
+ ch.take(handler)
+ }
+ result.future
+ }
+
+ type Op[A] = (Channel[A], A => Unit)
+
+ def select2(reads: Op[_]*): Future[Unit] = {
+ val flag = new Flag
+ val done = Promise[Unit]
+ for ((ch, callback) <- reads) {
+ val c = callback.andThen { _ =>
+ done.success(())
+ ()
+ }
+ val handler = new SelectHandler(flag, c)
+ ch.take(handler)
+ }
+ done.future
+ }
+
+}
+class Handler[-A](callback: A => Unit) {
+ def active: Boolean = true
+ def commit(): A => Unit = callback
+}
+
+class Flag {
+ val active = new AtomicBoolean(true)
+}
+class SelectHandler[A](flag: Flag, callback: A => Unit)
+ extends Handler[A](callback) {
+ var _active = true
+ override def active = _active
+ override def commit(): A => Unit =
+ if (flag.active.compareAndSet(true, false)) {
+ callback
+ } else {
+ _active = false
+ _ =>
+ ()
+ }
+
+}
diff --git a/shared/src/main/scala/escale/syntax.scala b/shared/src/main/scala/escale/syntax.scala
new file mode 100644
index 0000000..fc4b789
--- /dev/null
+++ b/shared/src/main/scala/escale/syntax.scala
@@ -0,0 +1,53 @@
+package escale
+
+import scala.concurrent.{ExecutionContext, Future}
+import scala.language.experimental.macros
+
+object Macros {
+ import scala.reflect.macros.blackbox._
+
+ def goImpl[A: c.WeakTypeTag](c: Context)(body: c.Expr[A])(
+ execContext: c.Expr[ExecutionContext]): c.Tree = {
+ import c.universe._
+ val pkg = c.mirror.staticPackage("scala.async")
+ q"""$pkg.Async.async($body)($execContext)"""
+ }
+
+ def asyncTakeImpl[A: c.WeakTypeTag](c: Context)(
+ channel: c.Expr[Channel[A]]): c.Tree = {
+ import c.universe._
+ val pkg = c.mirror.staticPackage("scala.async")
+ q"""$pkg.Async.await($channel.take())"""
+ }
+
+ def asyncPutImpl[A: c.WeakTypeTag](c: Context)(value: c.Expr[A]): c.Tree = {
+ import c.universe._
+ val pkg = c.mirror.staticPackage("scala.async")
+ q"""$pkg.Async.await(${c.prefix}.channel.put($value))"""
+ }
+
+ def selectImpl(c: Context)(channels: c.Expr[Channel[_]]*): c.Tree = {
+ import c.universe._
+ val pkg = c.mirror.staticPackage("scala.async")
+ val Channel = c.mirror.staticModule("escale.Channel")
+ q"""($pkg.Async.await($Channel.select(..$channels)): @unchecked)"""
+ }
+
+}
+
+package object syntax {
+
+ def chan[A](capacity: Int = 0): Channel[A] = Channel[A](capacity)
+
+ def go[A](body: => A)(implicit execContext: ExecutionContext): Future[A] =
+ macro Macros.goImpl[A]
+
+ def !<[A](channel: Channel[A]): A = macro Macros.asyncTakeImpl[A]
+
+ implicit class ChannelOps[A](val channel: Channel[A]) extends AnyVal {
+ def !<(value: A): Unit = macro Macros.asyncPutImpl[A]
+ }
+
+ def select(channels: Channel[_]*): (Channel[_], Any) = macro Macros.selectImpl
+
+}
diff --git a/shared/src/test/scala/escale/SelectTest.scala b/shared/src/test/scala/escale/SelectTest.scala
new file mode 100644
index 0000000..3d41723
--- /dev/null
+++ b/shared/src/test/scala/escale/SelectTest.scala
@@ -0,0 +1,47 @@
+package escale
+
+import utest._
+import scala.async.Async._
+import scala.concurrent.ExecutionContext.Implicits.global
+import syntax._
+
+object SelectTest extends TestSuite {
+ val tests = Tests {
+ "select" - {
+ val ints = Channel[Int](0)
+ val strings = Channel[String](0)
+ val stop = Channel[Unit](0)
+ val cleaned = Channel[Int](10)
+
+ val p0 = async {
+ var done = false
+ do {
+ (await(Channel.select(ints, strings, stop)): @unchecked) match {
+ case (`ints`, value: Int) =>
+ cleaned !< value
+ case (`strings`, value: String) =>
+ cleaned !< value.toInt
+ case (`stop`, _) =>
+ done = true
+ }
+ } while (!done)
+ "done"
+ }
+
+ val p1 = async{
+ ints !< 2
+ }
+ val p2 = async{
+ strings !< "2"
+ ints !< 1
+ }
+ val p3 = async{
+ await(p1)
+ await(p2)
+ stop !< ()
+ }
+ p0
+ }
+ }
+
+}
diff --git a/shared/src/test/scala/escale/SimpleTest.scala b/shared/src/test/scala/escale/SimpleTest.scala
new file mode 100644
index 0000000..6f52b06
--- /dev/null
+++ b/shared/src/test/scala/escale/SimpleTest.scala
@@ -0,0 +1,47 @@
+package escale
+
+import scala.async.Async.{async, await}
+import utest._
+import scala.concurrent.ExecutionContext.Implicits.global
+
+object SimpleTest extends TestSuite {
+ val tests = Tests{
+ "put and take" - {
+ val ch = Channel[Int](0)
+ val p1 = async{
+ await(ch.put(1))
+ await(ch.put(2))
+ await(ch.put(3))
+ await(ch.put(4))
+ }
+ async{
+ await(ch.take())
+ await(ch.take())
+ await(ch.take())
+ await(ch.take())
+ }
+ p1
+ }
+ "put and take while"- {
+ val ch = Channel[Int](0)
+ ch.put(1)
+
+ val p1 = async {
+ await(ch.put(2))
+ await(ch.put(3))
+ await(ch.put(4))
+ await(ch.put(5))
+ }
+
+ val p2 = async {
+ var sum = 0
+ var a = 0
+ while ({ a = await(ch.take()); a } < 5) {
+ sum += a
+ }
+ assert(sum == 10)
+ }
+ p2
+ }
+ }
+} \ No newline at end of file
diff --git a/shared/src/test/scala/escale/SyntaxTest.scala b/shared/src/test/scala/escale/SyntaxTest.scala
new file mode 100644
index 0000000..fffcbe9
--- /dev/null
+++ b/shared/src/test/scala/escale/SyntaxTest.scala
@@ -0,0 +1,44 @@
+package escale
+
+import utest._
+import scala.concurrent.ExecutionContext.Implicits.global
+import escale.syntax._
+import scala.concurrent.Future
+
+object SyntaxTest extends TestSuite {
+ val tests = Tests {
+ "!< and !<" - {
+ val ch1 = chan[Int]()
+ val ch2 = chan[Int](1)
+ go {
+ ch1 !< 1
+ ch1 !< 2
+ ch1 !< 3
+ }
+ go {
+ var sum = 0
+ sum += !<(ch1)
+ sum += !<(ch1)
+ sum += !<(ch1)
+ ch2 !< 4
+ sum += !<(ch2)
+ assert(sum == 10)
+ }
+ }
+ "select syntax" - {
+ def run(): Future[String] = go {
+ val Ch1 = chan[Int]()
+ val Ch2 = chan[Int]()
+
+ go {/*Thread.sleep(1);*/ Ch1 !< 1}
+ go {/*Thread.sleep(1);*/ Ch2 !< 1}
+
+ select(Ch1, Ch2) match {
+ case (Ch1, _) => "ch1 was first"
+ case (Ch2, _) => "ch2 was first"
+ }
+ }
+ run()
+ }
+ }
+}