diff options
author | Prashant Sharma <prashant.s@imaginea.com> | 2013-09-06 17:53:01 +0530 |
---|---|---|
committer | Prashant Sharma <prashant.s@imaginea.com> | 2013-09-06 17:53:01 +0530 |
commit | 4106ae9fbf8a582697deba2198b3b966dec00bfe (patch) | |
tree | 7c3046faee5f62f9ec4c4176125988d7cb5d70e2 /repl/src/test | |
parent | e0dd24dc858777904335218f3001a24bffe73b27 (diff) | |
parent | 5c7494d7c1b7301138fb3dc155a1b0c961126ec6 (diff) | |
download | spark-4106ae9fbf8a582697deba2198b3b966dec00bfe.tar.gz spark-4106ae9fbf8a582697deba2198b3b966dec00bfe.tar.bz2 spark-4106ae9fbf8a582697deba2198b3b966dec00bfe.zip |
Merged with master
Diffstat (limited to 'repl/src/test')
-rw-r--r-- | repl/src/test/resources/log4j.properties | 19 | ||||
-rw-r--r-- | repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala | 189 | ||||
-rw-r--r-- | repl/src/test/scala/org/apache/spark/repl/ReplSuiteMixin.scala (renamed from repl/src/test/scala/spark/repl/ReplSuiteMixin.scala) | 4 | ||||
-rw-r--r-- | repl/src/test/scala/spark/repl/ReplSuite.scala | 152 |
4 files changed, 208 insertions, 156 deletions
diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties index cfb1a390e6..a6d33e69d2 100644 --- a/repl/src/test/resources/log4j.properties +++ b/repl/src/test/resources/log4j.properties @@ -1,4 +1,21 @@ -# Set everything to be logged to the repl/target/unit-tests.log +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +# Set everything to be logged to the repl/target/unit-tests.log log4j.rootCategory=INFO, file log4j.appender.file=org.apache.log4j.FileAppender log4j.appender.file.append=false diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala new file mode 100644 index 0000000000..b06999a42c --- /dev/null +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala @@ -0,0 +1,189 @@ +package org.apache.spark.repl + +import java.io._ +import java.net.URLClassLoader + +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite +import com.google.common.io.Files + +class ReplSuite extends FunSuite { + def runInterpreter(master: String, input: String): String = { + val in = new BufferedReader(new StringReader(input + "\n")) + val out = new StringWriter() + val cl = getClass.getClassLoader + var paths = new ArrayBuffer[String] + if (cl.isInstanceOf[URLClassLoader]) { + val urlLoader = cl.asInstanceOf[URLClassLoader] + for (url <- urlLoader.getURLs) { + if (url.getProtocol == "file") { + paths += url.getFile + } + } + } + val interp = new SparkILoop(in, new PrintWriter(out), master) + org.apache.spark.repl.Main.interp = interp + val separator = System.getProperty("path.separator") + interp.process(Array("-classpath", paths.mkString(separator))) + org.apache.spark.repl.Main.interp = null + if (interp.sparkContext != null) { + interp.sparkContext.stop() + } + // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown + System.clearProperty("spark.driver.port") + System.clearProperty("spark.hostPort") + return out.toString + } + + def assertContains(message: String, output: String) { + assert(output.contains(message), + "Interpreter output did not contain '" + message + "':\n" + output) + } + + def assertDoesNotContain(message: String, output: String) { + assert(!output.contains(message), + "Interpreter output contained '" + message + "':\n" + output) + } + + test("simple foreach with accumulator") { + val output = runInterpreter("local", """ + |val accum = sc.accumulator(0) + |sc.parallelize(1 to 10).foreach(x => accum += x) + |accum.value + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res1: Int = 55", output) + } + + test("external vars") { + val output = runInterpreter("local", """ + |var v = 7 + |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + |v = 10 + |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + } + + test("external classes") { + val output = runInterpreter("local", """ + |class C { + |def foo = 5 + |} + |sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 50", output) + } + + test("external functions") { + val output = runInterpreter("local", """ + |def double(x: Int) = x + x + |sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 110", output) + } + + test("external functions that access vars") { + val output = runInterpreter("local", """ + |var v = 7 + |def getV() = v + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |v = 10 + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + } + + test("broadcast vars") { + // Test that the value that a broadcast var had when it was created is used, + // even if that variable is then modified in the driver program + // TODO: This doesn't actually work for arrays when we run in local mode! + val output = runInterpreter("local", """ + |var array = new Array[Int](5) + |val broadcastArray = sc.broadcast(array) + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |array(0) = 5 + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output) + } + + test("interacting with files") { + val tempDir = Files.createTempDir() + val out = new FileWriter(tempDir + "/input") + out.write("Hello world!\n") + out.write("What's up?\n") + out.write("Goodbye\n") + out.close() + val output = runInterpreter("local", """ + |var file = sc.textFile("%s/input").cache() + |file.count() + |file.count() + |file.count() + """.stripMargin.format(tempDir.getAbsolutePath)) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Long = 3", output) + assertContains("res1: Long = 3", output) + assertContains("res2: Long = 3", output) + } + + test("local-cluster mode") { + val output = runInterpreter("local-cluster[1,1,512]", """ + |var v = 7 + |def getV() = v + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |v = 10 + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |var array = new Array[Int](5) + |val broadcastArray = sc.broadcast(array) + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |array(0) = 5 + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + """.stripMargin) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) + } + + if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { + test("running on Mesos") { + val output = runInterpreter("localquiet", """ + |var v = 7 + |def getV() = v + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |v = 10 + |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) + |var array = new Array[Int](5) + |val broadcastArray = sc.broadcast(array) + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + |array(0) = 5 + |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect + """) + assertDoesNotContain("error:", output) + assertDoesNotContain("Exception", output) + assertContains("res0: Int = 70", output) + assertContains("res1: Int = 100", output) + assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output) + assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) + } + } +} diff --git a/repl/src/test/scala/spark/repl/ReplSuiteMixin.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuiteMixin.scala index d88e44ad19..ccfbf5193a 100644 --- a/repl/src/test/scala/spark/repl/ReplSuiteMixin.scala +++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuiteMixin.scala @@ -1,4 +1,4 @@ -package spark.repl +package org.apache.spark.repl import java.io.BufferedReader import java.io.PrintWriter @@ -10,8 +10,6 @@ import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global import scala.concurrent.future -import spark.deploy.master.Master -import spark.deploy.worker.Worker trait ReplSuiteMixin { def runInterpreter(master: String, input: String): String = { diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala deleted file mode 100644 index 8df0c3a7f4..0000000000 --- a/repl/src/test/scala/spark/repl/ReplSuite.scala +++ /dev/null @@ -1,152 +0,0 @@ -package spark.repl - -import java.io.FileWriter - -import org.scalatest.FunSuite - -import com.google.common.io.Files - -class ReplSuite extends FunSuite with ReplSuiteMixin { - - test("simple foreach with accumulator") { - val output = runInterpreter("local", """ - val accum = sc.accumulator(0) - sc.parallelize(1 to 10).foreach(x => accum += x) - accum.value - """) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - assertContains("res1: Int = 55", output) - } - - test ("external vars") { - val output = runInterpreter("local", """ - var v = 7 - sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) - v = 10 - sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_) - """) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - assertContains("res0: Int = 70", output) - assertContains("res1: Int = 100", output) - } - - test("external classes") { - val output = runInterpreter("local", """ - class C { - def foo = 5 - } - sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_) - """) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - assertContains("res0: Int = 50", output) - } - - test("external functions") { - val output = runInterpreter("local", """ - def double(x: Int) = x + x - sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_) - """) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - assertContains("res0: Int = 110", output) - } - - test("external functions that access vars") { - val output = runInterpreter("local", """ - var v = 7 - def getV() = v - sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) - v = 10 - sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) - """) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - assertContains("res0: Int = 70", output) - assertContains("res1: Int = 100", output) - } - - test("broadcast vars") { - // Test that the value that a broadcast var had when it was created is used, - // even if that variable is then modified in the driver program - // TODO: This doesn't actually work for arrays when we run in local mode! - val output = runInterpreter("local", """ - var array = new Array[Int](5) - val broadcastArray = sc.broadcast(array) - sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect - array(0) = 5 - sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect - """) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output) - assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output) - } - - test("interacting with files") { - val tempDir = Files.createTempDir() - val out = new FileWriter(tempDir + "/input") - out.write("Hello world!\n") - out.write("What's up?\n") - out.write("Goodbye\n") - out.close() - val output = runInterpreter("local", """ - var file = sc.textFile("%s/input").cache() - file.count() - file.count() - file.count() - """.format(tempDir.getAbsolutePath)) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - assertContains("res0: Long = 3", output) - assertContains("res1: Long = 3", output) - assertContains("res2: Long = 3", output) - } - - test ("local-cluster mode") { - val output = runInterpreter("local-cluster[1,1,512]", """ - var v = 7 - def getV() = v - sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) - v = 10 - sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) - var array = new Array[Int](5) - val broadcastArray = sc.broadcast(array) - sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect - array(0) = 5 - sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect - """) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - assertContains("res0: Int = 70", output) - assertContains("res1: Int = 100", output) - assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output) - assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) - } - - if (System.getenv("MESOS_NATIVE_LIBRARY") != null) { - test("running on Mesos") { - val output = runInterpreter("localquiet", """ - var v = 7 - def getV() = v - sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) - v = 10 - sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_) - var array = new Array[Int](5) - val broadcastArray = sc.broadcast(array) - sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect - array(0) = 5 - sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect - """) - assertDoesNotContain("error:", output) - assertDoesNotContain("Exception", output) - assertContains("res0: Int = 70", output) - assertContains("res1: Int = 100", output) - assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output) - assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output) - } - } - -} |