aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorphaller <hallerp@gmail.com>2012-10-31 02:11:20 +0100
committerphaller <hallerp@gmail.com>2012-10-31 02:11:20 +0100
commit5c87d7d1ecf528360d74f219d8a857a516514cdd (patch)
treef86ac1dbcd243107f3df6d6cd465ae1a7d9bb803
parent1584de89b122fbea98b95bdb4a3a45205c932842 (diff)
downloadscala-async-5c87d7d1ecf528360d74f219d8a857a516514cdd.tar.gz
scala-async-5c87d7d1ecf528360d74f219d8a857a516514cdd.tar.bz2
scala-async-5c87d7d1ecf528360d74f219d8a857a516514cdd.zip
Fix #2
-rw-r--r--src/async/library/scala/async/ExprBuilder.scala6
-rw-r--r--test/files/run/await0/MinimalScalaTest.scala102
-rw-r--r--test/files/run/await0/await0.scala76
3 files changed, 183 insertions, 1 deletions
diff --git a/src/async/library/scala/async/ExprBuilder.scala b/src/async/library/scala/async/ExprBuilder.scala
index 32af1b3..776cc7b 100644
--- a/src/async/library/scala/async/ExprBuilder.scala
+++ b/src/async/library/scala/async/ExprBuilder.scala
@@ -267,12 +267,16 @@ class ExprBuilder[C <: Context with Singleton](val c: C) extends AsyncUtils {
mkHandlerTree(state, Block((stats :+ mkOnCompleteStateTree(nextState)): _*))
}
- //TODO: complete for other primitive types, how to handle value classes?
override def varDefForResult: Option[c.Tree] = {
val rhs =
if (resultType <:< definitions.IntTpe) Literal(Constant(0))
else if (resultType <:< definitions.LongTpe) Literal(Constant(0L))
else if (resultType <:< definitions.BooleanTpe) Literal(Constant(false))
+ else if (resultType <:< definitions.FloatTpe) Literal(Constant(0.0f))
+ else if (resultType <:< definitions.DoubleTpe) Literal(Constant(0.0d))
+ else if (resultType <:< definitions.CharTpe) Literal(Constant(0.toChar))
+ else if (resultType <:< definitions.ShortTpe) Literal(Constant(0.toShort))
+ else if (resultType <:< definitions.ByteTpe) Literal(Constant(0.toByte))
else Literal(Constant(null))
Some(
ValDef(Modifiers(Flag.MUTABLE), resultName, TypeTree(resultType), rhs)
diff --git a/test/files/run/await0/MinimalScalaTest.scala b/test/files/run/await0/MinimalScalaTest.scala
new file mode 100644
index 0000000..91de1fc
--- /dev/null
+++ b/test/files/run/await0/MinimalScalaTest.scala
@@ -0,0 +1,102 @@
+import language.reflectiveCalls
+import language.postfixOps
+import language.implicitConversions
+
+import scala.reflect.{ ClassTag, classTag }
+
+import scala.collection.mutable
+import scala.concurrent.{ Future, Awaitable, CanAwait }
+import java.util.concurrent.{ TimeoutException, CountDownLatch, TimeUnit }
+import scala.concurrent.duration.Duration
+
+
+
+trait Output {
+ val buffer = new StringBuilder
+
+ def bufferPrintln(a: Any): Unit = buffer.synchronized {
+ buffer.append(a.toString + "\n")
+ }
+}
+
+
+trait MinimalScalaTest extends Output {
+
+ val throwables = mutable.ArrayBuffer[Throwable]()
+
+ def check() {
+ if (throwables.nonEmpty) println(buffer.toString)
+ }
+
+ implicit def stringops(s: String) = new {
+
+ def should[U](snippets: =>U): U = {
+ bufferPrintln(s + " should:")
+ snippets
+ }
+
+ def in[U](snippet: =>U): Unit = {
+ try {
+ bufferPrintln("- " + s)
+ snippet
+ bufferPrintln("[OK] Test passed.")
+ } catch {
+ case e: Throwable =>
+ bufferPrintln("[FAILED] " + e)
+ bufferPrintln(e.getStackTrace().mkString("\n"))
+ throwables += e
+ }
+ }
+
+ }
+
+ implicit def objectops(obj: Any) = new {
+
+ def mustBe(other: Any) = assert(obj == other, obj + " is not " + other)
+ def mustEqual(other: Any) = mustBe(other)
+
+ }
+
+ def intercept[T <: Throwable: ClassTag](body: =>Any): T = {
+ try {
+ body
+ throw new Exception("Exception of type %s was not thrown".format(classTag[T]))
+ } catch {
+ case t: Throwable =>
+ if (classTag[T].runtimeClass != t.getClass) throw t
+ else t.asInstanceOf[T]
+ }
+ }
+
+ def checkType[T: ClassTag, S](in: Future[T], refclasstag: ClassTag[S]): Boolean = classTag[T] == refclasstag
+}
+
+
+object TestLatch {
+ val DefaultTimeout = Duration(5, TimeUnit.SECONDS)
+
+ def apply(count: Int = 1) = new TestLatch(count)
+}
+
+
+class TestLatch(count: Int = 1) extends Awaitable[Unit] {
+ private var latch = new CountDownLatch(count)
+
+ def countDown() = latch.countDown()
+ def isOpen: Boolean = latch.getCount == 0
+ def open() = while (!isOpen) countDown()
+ def reset() = latch = new CountDownLatch(count)
+
+ @throws(classOf[TimeoutException])
+ def ready(atMost: Duration)(implicit permit: CanAwait) = {
+ val opened = latch.await(atMost.toNanos, TimeUnit.NANOSECONDS)
+ if (!opened) throw new TimeoutException("Timeout of %s." format (atMost.toString))
+ this
+ }
+
+ @throws(classOf[Exception])
+ def result(atMost: Duration)(implicit permit: CanAwait): Unit = {
+ ready(atMost)
+ }
+
+}
diff --git a/test/files/run/await0/await0.scala b/test/files/run/await0/await0.scala
new file mode 100644
index 0000000..dfa3370
--- /dev/null
+++ b/test/files/run/await0/await0.scala
@@ -0,0 +1,76 @@
+/**
+ * Copyright (C) 2012 Typesafe Inc. <http://www.typesafe.com>
+ */
+
+import language.{ reflectiveCalls, postfixOps }
+
+import scala.concurrent.{ Future, ExecutionContext, future, Await }
+import scala.concurrent.duration._
+import scala.async.Async.{ async, await }
+
+object Test extends App {
+
+ Await0Spec.check()
+
+}
+
+class Await0Class {
+ import ExecutionContext.Implicits.global
+
+ def m1(x: Double): Future[Double] = future {
+ Thread.sleep(200)
+ x + 2.0
+ }
+
+ def m2(x: Float): Future[Float] = future {
+ Thread.sleep(200)
+ x + 2.0f
+ }
+
+ def m3(x: Char): Future[Char] = future {
+ Thread.sleep(200)
+ (x.toInt + 2).toChar
+ }
+
+ def m4(x: Short): Future[Short] = future {
+ Thread.sleep(200)
+ (x + 2).toShort
+ }
+
+ def m5(x: Byte): Future[Byte] = future {
+ Thread.sleep(200)
+ (x + 2).toByte
+ }
+
+ def m0(y: Int): Future[Double] = async {
+ val f1 = m1(y.toDouble)
+ val x1: Double = await(f1)
+
+ val f2 = m2(y.toFloat)
+ val x2: Float = await(f2)
+
+ val f3 = m3(y.toChar)
+ val x3: Char = await(f3)
+
+ val f4 = m4(y.toShort)
+ val x4: Short = await(f4)
+
+ val f5 = m5(y.toByte)
+ val x5: Byte = await(f5)
+
+ x1 + x2 + 2.0
+ }
+}
+
+object Await0Spec extends MinimalScalaTest {
+
+ "An async method" should {
+ "support a simple await" in {
+ val o = new Await0Class
+ val fut = o.m0(10)
+ val res = Await.result(fut, 10 seconds)
+ res mustBe(26.0)
+ }
+ }
+
+}