diff options
author | Prashant Sharma <prashant.s@imaginea.com> | 2013-04-19 13:51:16 +0530 |
---|---|---|
committer | Prashant Sharma <prashant.s@imaginea.com> | 2013-04-19 13:51:16 +0530 |
commit | bf5fc07379c083cdf6de66f28344997651009787 (patch) | |
tree | 4a529e35128ef82a0a19e97280a150e5c06f6381 /repl/src/test/scala | |
parent | 36ccb35371682d5b960e9cbcc80bca7c5db4ce49 (diff) | |
download | spark-bf5fc07379c083cdf6de66f28344997651009787.tar.gz spark-bf5fc07379c083cdf6de66f28344997651009787.tar.bz2 spark-bf5fc07379c083cdf6de66f28344997651009787.zip |
Added more tests
Diffstat (limited to 'repl/src/test/scala')
-rw-r--r-- | repl/src/test/scala/spark/repl/ReplSuiteMixin.scala | 12 | ||||
-rw-r--r-- | repl/src/test/scala/spark/repl/StandaloneClusterReplSuite.scala | 79 |
2 files changed, 83 insertions, 8 deletions
diff --git a/repl/src/test/scala/spark/repl/ReplSuiteMixin.scala b/repl/src/test/scala/spark/repl/ReplSuiteMixin.scala index 35429bf01f..fd1a1b1e7c 100644 --- a/repl/src/test/scala/spark/repl/ReplSuiteMixin.scala +++ b/repl/src/test/scala/spark/repl/ReplSuiteMixin.scala @@ -14,12 +14,15 @@ import spark.deploy.master.Master import spark.deploy.worker.Worker trait ReplSuiteMixin { + val localIp = "127.0.1.2" + val port = "7089" + val sparkUrl = s"spark://$localIp:$port" def setupStandaloneCluster() { - future { Master.main(Array("-i", "127.0.1.2", "-p", "7089")) } + future { Master.main(Array("-i", localIp, "-p", port, "--webui-port", "0")) } Thread.sleep(2000) - future { Worker.main(Array("spark://127.0.1.2:7089", "--webui-port", "0")) } + future { Worker.main(Array(sparkUrl, "--webui-port", "0")) } } - + def runInterpreter(master: String, input: String): String = { val in = new BufferedReader(new StringReader(input + "\n")) val out = new StringWriter() @@ -33,6 +36,7 @@ trait ReplSuiteMixin { } } } + val interp = new SparkILoop(in, new PrintWriter(out), master) spark.repl.Main.interp = interp val separator = System.getProperty("path.separator") @@ -53,4 +57,4 @@ trait ReplSuiteMixin { assert(!(output contains message), "Interpreter output contained '" + message + "':\n" + output) } -}
\ No newline at end of file +} diff --git a/repl/src/test/scala/spark/repl/StandaloneClusterReplSuite.scala b/repl/src/test/scala/spark/repl/StandaloneClusterReplSuite.scala index a0940e2166..0822770fe2 100644 --- a/repl/src/test/scala/spark/repl/StandaloneClusterReplSuite.scala +++ b/repl/src/test/scala/spark/repl/StandaloneClusterReplSuite.scala @@ -1,12 +1,16 @@ package spark.repl +import java.io.FileWriter + import org.scalatest.FunSuite +import com.google.common.io.Files + class StandaloneClusterReplSuite extends FunSuite with ReplSuiteMixin { setupStandaloneCluster test("simple collect") { - val output = runInterpreter("spark://127.0.1.2:7089", """ + val output = runInterpreter(sparkUrl, """ var x = 123 val data = sc.parallelize(1 to 3).map(_ + x) data.take(3) @@ -17,9 +21,9 @@ class StandaloneClusterReplSuite extends FunSuite with ReplSuiteMixin { assertContains("125", output) assertContains("126", output) } - + test("simple foreach with accumulator") { - val output = runInterpreter("spark://127.0.1.2:7089", """ + val output = runInterpreter(sparkUrl, """ val accum = sc.accumulator(0) sc.parallelize(1 to 10).foreach(x => accum += x) accum.value @@ -29,4 +33,71 @@ class StandaloneClusterReplSuite extends FunSuite with ReplSuiteMixin { assertContains("res1: Int = 55", output) } -}
\ No newline at end of file + test("external vars") { + val output = runInterpreter(sparkUrl, """ + var v = 7 + sc.parallelize(1 to 10).map(x => v).take(10).reduceLeft(_+_) + v = 10 + sc.parallelize(1 to 10).map(x => v).take(10).reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + } + + test("external classes") { + val output = runInterpreter(sparkUrl, """ + class C { + def foo = 5 + } + sc.parallelize(1 to 10).map(x => (new C).foo).take(10).reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 50", output) + } + + test("external functions") { + val output = runInterpreter(sparkUrl, """ + def double(x: Int) = x + x + sc.parallelize(1 to 10).map(x => double(x)).take(10).reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 110", output) + } + + test("external functions that access vars") { + val output = runInterpreter(sparkUrl, """ + var v = 7 + def getV() = v + sc.parallelize(1 to 10).map(x => getV()).take(10).reduceLeft(_+_) + v = 10 + sc.parallelize(1 to 10).map(x => getV()).take(10).reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + } + + test("broadcast vars") { + // Test that the value that a broadcast var had when it was created is used, + // even if that variable is then modified in the driver program + + val output = runInterpreter(sparkUrl, """ + var array = new Array[Int](5) + val broadcastArray = sc.broadcast(array) + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).take(5) + array(0) = 5 + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).take(5) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output) + } + + +} |