aboutsummaryrefslogtreecommitdiff
path: root/repl/src/test
diff options
context:
space:
mode:
authorPrashant Sharma <prashant.s@imaginea.com>2013-09-06 17:53:01 +0530
committerPrashant Sharma <prashant.s@imaginea.com>2013-09-06 17:53:01 +0530
commit4106ae9fbf8a582697deba2198b3b966dec00bfe (patch)
tree7c3046faee5f62f9ec4c4176125988d7cb5d70e2 /repl/src/test
parente0dd24dc858777904335218f3001a24bffe73b27 (diff)
parent5c7494d7c1b7301138fb3dc155a1b0c961126ec6 (diff)
downloadspark-4106ae9fbf8a582697deba2198b3b966dec00bfe.tar.gz
spark-4106ae9fbf8a582697deba2198b3b966dec00bfe.tar.bz2
spark-4106ae9fbf8a582697deba2198b3b966dec00bfe.zip
Merged with master
Diffstat (limited to 'repl/src/test')
-rw-r--r--repl/src/test/resources/log4j.properties19
-rw-r--r--repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala189
-rw-r--r--repl/src/test/scala/org/apache/spark/repl/ReplSuiteMixin.scala (renamed from repl/src/test/scala/spark/repl/ReplSuiteMixin.scala)4
-rw-r--r--repl/src/test/scala/spark/repl/ReplSuite.scala152
4 files changed, 208 insertions, 156 deletions
diff --git a/repl/src/test/resources/log4j.properties b/repl/src/test/resources/log4j.properties
index cfb1a390e6..a6d33e69d2 100644
--- a/repl/src/test/resources/log4j.properties
+++ b/repl/src/test/resources/log4j.properties
@@ -1,4 +1,21 @@
-# Set everything to be logged to the repl/target/unit-tests.log
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+# Set everything to be logged to the repl/target/unit-tests.log
log4j.rootCategory=INFO, file
log4j.appender.file=org.apache.log4j.FileAppender
log4j.appender.file.append=false
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
new file mode 100644
index 0000000000..b06999a42c
--- /dev/null
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -0,0 +1,189 @@
+package org.apache.spark.repl
+
+import java.io._
+import java.net.URLClassLoader
+
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+import com.google.common.io.Files
+
+class ReplSuite extends FunSuite {
+ def runInterpreter(master: String, input: String): String = {
+ val in = new BufferedReader(new StringReader(input + "\n"))
+ val out = new StringWriter()
+ val cl = getClass.getClassLoader
+ var paths = new ArrayBuffer[String]
+ if (cl.isInstanceOf[URLClassLoader]) {
+ val urlLoader = cl.asInstanceOf[URLClassLoader]
+ for (url <- urlLoader.getURLs) {
+ if (url.getProtocol == "file") {
+ paths += url.getFile
+ }
+ }
+ }
+ val interp = new SparkILoop(in, new PrintWriter(out), master)
+ org.apache.spark.repl.Main.interp = interp
+ val separator = System.getProperty("path.separator")
+ interp.process(Array("-classpath", paths.mkString(separator)))
+ org.apache.spark.repl.Main.interp = null
+ if (interp.sparkContext != null) {
+ interp.sparkContext.stop()
+ }
+ // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown
+ System.clearProperty("spark.driver.port")
+ System.clearProperty("spark.hostPort")
+ return out.toString
+ }
+
+ def assertContains(message: String, output: String) {
+ assert(output.contains(message),
+ "Interpreter output did not contain '" + message + "':\n" + output)
+ }
+
+ def assertDoesNotContain(message: String, output: String) {
+ assert(!output.contains(message),
+ "Interpreter output contained '" + message + "':\n" + output)
+ }
+
+ test("simple foreach with accumulator") {
+ val output = runInterpreter("local", """
+ |val accum = sc.accumulator(0)
+ |sc.parallelize(1 to 10).foreach(x => accum += x)
+ |accum.value
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res1: Int = 55", output)
+ }
+
+ test("external vars") {
+ val output = runInterpreter("local", """
+ |var v = 7
+ |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
+ |v = 10
+ |sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res1: Int = 100", output)
+ }
+
+ test("external classes") {
+ val output = runInterpreter("local", """
+ |class C {
+ |def foo = 5
+ |}
+ |sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 50", output)
+ }
+
+ test("external functions") {
+ val output = runInterpreter("local", """
+ |def double(x: Int) = x + x
+ |sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 110", output)
+ }
+
+ test("external functions that access vars") {
+ val output = runInterpreter("local", """
+ |var v = 7
+ |def getV() = v
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ |v = 10
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res1: Int = 100", output)
+ }
+
+ test("broadcast vars") {
+ // Test that the value that a broadcast var had when it was created is used,
+ // even if that variable is then modified in the driver program
+ // TODO: This doesn't actually work for arrays when we run in local mode!
+ val output = runInterpreter("local", """
+ |var array = new Array[Int](5)
+ |val broadcastArray = sc.broadcast(array)
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ |array(0) = 5
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
+ }
+
+ test("interacting with files") {
+ val tempDir = Files.createTempDir()
+ val out = new FileWriter(tempDir + "/input")
+ out.write("Hello world!\n")
+ out.write("What's up?\n")
+ out.write("Goodbye\n")
+ out.close()
+ val output = runInterpreter("local", """
+ |var file = sc.textFile("%s/input").cache()
+ |file.count()
+ |file.count()
+ |file.count()
+ """.stripMargin.format(tempDir.getAbsolutePath))
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Long = 3", output)
+ assertContains("res1: Long = 3", output)
+ assertContains("res2: Long = 3", output)
+ }
+
+ test("local-cluster mode") {
+ val output = runInterpreter("local-cluster[1,1,512]", """
+ |var v = 7
+ |def getV() = v
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ |v = 10
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ |var array = new Array[Int](5)
+ |val broadcastArray = sc.broadcast(array)
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ |array(0) = 5
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res1: Int = 100", output)
+ assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ }
+
+ if (System.getenv("MESOS_NATIVE_LIBRARY") != null) {
+ test("running on Mesos") {
+ val output = runInterpreter("localquiet", """
+ |var v = 7
+ |def getV() = v
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ |v = 10
+ |sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
+ |var array = new Array[Int](5)
+ |val broadcastArray = sc.broadcast(array)
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ |array(0) = 5
+ |sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
+ """)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ assertContains("res0: Int = 70", output)
+ assertContains("res1: Int = 100", output)
+ assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
+ }
+ }
+}
diff --git a/repl/src/test/scala/spark/repl/ReplSuiteMixin.scala b/repl/src/test/scala/org/apache/spark/repl/ReplSuiteMixin.scala
index d88e44ad19..ccfbf5193a 100644
--- a/repl/src/test/scala/spark/repl/ReplSuiteMixin.scala
+++ b/repl/src/test/scala/org/apache/spark/repl/ReplSuiteMixin.scala
@@ -1,4 +1,4 @@
-package spark.repl
+package org.apache.spark.repl
import java.io.BufferedReader
import java.io.PrintWriter
@@ -10,8 +10,6 @@ import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent.future
-import spark.deploy.master.Master
-import spark.deploy.worker.Worker
trait ReplSuiteMixin {
def runInterpreter(master: String, input: String): String = {
diff --git a/repl/src/test/scala/spark/repl/ReplSuite.scala b/repl/src/test/scala/spark/repl/ReplSuite.scala
deleted file mode 100644
index 8df0c3a7f4..0000000000
--- a/repl/src/test/scala/spark/repl/ReplSuite.scala
+++ /dev/null
@@ -1,152 +0,0 @@
-package spark.repl
-
-import java.io.FileWriter
-
-import org.scalatest.FunSuite
-
-import com.google.common.io.Files
-
-class ReplSuite extends FunSuite with ReplSuiteMixin {
-
- test("simple foreach with accumulator") {
- val output = runInterpreter("local", """
- val accum = sc.accumulator(0)
- sc.parallelize(1 to 10).foreach(x => accum += x)
- accum.value
- """)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- assertContains("res1: Int = 55", output)
- }
-
- test ("external vars") {
- val output = runInterpreter("local", """
- var v = 7
- sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
- v = 10
- sc.parallelize(1 to 10).map(x => v).collect.reduceLeft(_+_)
- """)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- assertContains("res0: Int = 70", output)
- assertContains("res1: Int = 100", output)
- }
-
- test("external classes") {
- val output = runInterpreter("local", """
- class C {
- def foo = 5
- }
- sc.parallelize(1 to 10).map(x => (new C).foo).collect.reduceLeft(_+_)
- """)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- assertContains("res0: Int = 50", output)
- }
-
- test("external functions") {
- val output = runInterpreter("local", """
- def double(x: Int) = x + x
- sc.parallelize(1 to 10).map(x => double(x)).collect.reduceLeft(_+_)
- """)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- assertContains("res0: Int = 110", output)
- }
-
- test("external functions that access vars") {
- val output = runInterpreter("local", """
- var v = 7
- def getV() = v
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- v = 10
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- """)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- assertContains("res0: Int = 70", output)
- assertContains("res1: Int = 100", output)
- }
-
- test("broadcast vars") {
- // Test that the value that a broadcast var had when it was created is used,
- // even if that variable is then modified in the driver program
- // TODO: This doesn't actually work for arrays when we run in local mode!
- val output = runInterpreter("local", """
- var array = new Array[Int](5)
- val broadcastArray = sc.broadcast(array)
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- array(0) = 5
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- """)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- assertContains("res0: Array[Int] = Array(0, 0, 0, 0, 0)", output)
- assertContains("res2: Array[Int] = Array(5, 0, 0, 0, 0)", output)
- }
-
- test("interacting with files") {
- val tempDir = Files.createTempDir()
- val out = new FileWriter(tempDir + "/input")
- out.write("Hello world!\n")
- out.write("What's up?\n")
- out.write("Goodbye\n")
- out.close()
- val output = runInterpreter("local", """
- var file = sc.textFile("%s/input").cache()
- file.count()
- file.count()
- file.count()
- """.format(tempDir.getAbsolutePath))
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- assertContains("res0: Long = 3", output)
- assertContains("res1: Long = 3", output)
- assertContains("res2: Long = 3", output)
- }
-
- test ("local-cluster mode") {
- val output = runInterpreter("local-cluster[1,1,512]", """
- var v = 7
- def getV() = v
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- v = 10
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- var array = new Array[Int](5)
- val broadcastArray = sc.broadcast(array)
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- array(0) = 5
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- """)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- assertContains("res0: Int = 70", output)
- assertContains("res1: Int = 100", output)
- assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
- assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
- }
-
- if (System.getenv("MESOS_NATIVE_LIBRARY") != null) {
- test("running on Mesos") {
- val output = runInterpreter("localquiet", """
- var v = 7
- def getV() = v
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- v = 10
- sc.parallelize(1 to 10).map(x => getV()).collect.reduceLeft(_+_)
- var array = new Array[Int](5)
- val broadcastArray = sc.broadcast(array)
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- array(0) = 5
- sc.parallelize(0 to 4).map(x => broadcastArray.value(x)).collect
- """)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- assertContains("res0: Int = 70", output)
- assertContains("res1: Int = 100", output)
- assertContains("res2: Array[Int] = Array(0, 0, 0, 0, 0)", output)
- assertContains("res4: Array[Int] = Array(0, 0, 0, 0, 0)", output)
- }
- }
-
-}