diff options
author | Ismael Juma <ismael@juma.me.uk> | 2011-05-27 09:37:34 +0100 |
---|---|---|
committer | Ismael Juma <ismael@juma.me.uk> | 2011-05-27 11:22:50 +0100 |
commit | 1396678baa0a0b9b47e50bb2da4970aca1351b2e (patch) | |
tree | 6bedbe61370dfba294ae7043c3d5615914b79d85 /repl/src/test | |
parent | 3e8114ddbdf598724d0e6cee8507b6afa111f7f3 (diff) | |
download | spark-1396678baa0a0b9b47e50bb2da4970aca1351b2e.tar.gz spark-1396678baa0a0b9b47e50bb2da4970aca1351b2e.tar.bz2 spark-1396678baa0a0b9b47e50bb2da4970aca1351b2e.zip |
Move REPL classes to separate module.
Diffstat (limited to 'repl/src/test')
-rw-r--r-- | repl/src/test/scala/spark/repl/ReplSuite.scala | 144 |
1 files changed, 144 insertions, 0 deletions
diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala new file mode 100644 index 0000000000..829b1d934e --- /dev/null +++ b/repl/src/test/scala/spark/repl/ReplSuite.scala @@ -0,0 +1,144 @@ +package spark.repl + +import java.io._ +import java.net.URLClassLoader + +import scala.collection.mutable.ArrayBuffer +import scala.collection.JavaConversions._ + +import org.scalatest.FunSuite + +class ReplSuite extends FunSuite { + def runInterpreter(master: String, input: String): String = { + val in = new BufferedReader(new StringReader(input + "\n")) + val out = new StringWriter() + val cl = getClass.getClassLoader + var paths = new ArrayBuffer[String] + if (cl.isInstanceOf[URLClassLoader]) { + val urlLoader = cl.asInstanceOf[URLClassLoader] + for (url <- urlLoader.getURLs) { + if (url.getProtocol == "file") { + paths += url.getFile + } + } + } + val interp = new SparkInterpreterLoop(in, new PrintWriter(out), master) + spark.repl.Main.interp = interp + val separator = System.getProperty("path.separator") + interp.main(Array("-classpath", paths.mkString(separator))) + spark.repl.Main.interp = null + if (interp.sparkContext != null) + interp.sparkContext.stop() + return out.toString + } + + def assertContains(message: String, output: String) { + assert(output contains message, + "Interpreter output did not contain '" + message + "':\n" + output) + } + + def assertDoesNotContain(message: String, output: String) { + assert(!(output contains message), + "Interpreter output contained '" + message + "':\n" + output) + } + + test ("simple foreach with accumulator") { + val output = runInterpreter("local", """ + val accum = sc.accumulator(0) + sc.parallelize(1 to 10).foreach(x => accum += x) + accum.value + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res1: Int = 55", output) + } + + test ("external vars") { + val output = runInterpreter("local", """ + var v = 7 + sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + v = 10 + sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + } + + test ("external classes") { + val output = runInterpreter("local", """ + class C { + def foo = 5 + } + sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 50", output) + } + + test ("external functions") { + val output = runInterpreter("local", """ + def double(x: Int) = x + x + sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 110", output) + } + + test ("external functions that access vars") { + val output = runInterpreter("local", """ + var v = 7 + def getV() = v + sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + v = 10 + sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + } + + test ("broadcast vars") { + // Test that the value that a broadcast var had when it was created is used, + // even if that variable is then modified in the driver program + // TODO: This doesn't actually work for arrays when we run in local mode! + val output = runInterpreter("local", """ + var array = new Array[Int](5) + val broadcastArray = sc.broadcast(array) + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + array(0) = 5 + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output) + } + + if (System.getenv("MESOS_HOME") != null) { + test ("running on Mesos") { + val output = runInterpreter("localquiet", """ + var v = 7 + def getV() = v + sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + v = 10 + sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + var array = new Array[Int](5) + val broadcastArray = sc.broadcast(array) + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + array(0) = 5 + sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) + } + } +} |