aboutsummaryrefslogtreecommitdiff
path: root/repl
diff options
context:
space:
mode:
authorWenchen Fan <wenchen@databricks.com>2016-03-21 10:37:24 -0700
committerYin Huai <yhuai@databricks.com>2016-03-21 10:37:24 -0700
commit43ebf7a9cbd70d6af75e140a6fc91bf0ffc2b877 (patch)
tree2c4561fc312dd29156c1406b3777cbb8abbd48fe /repl
parent060a28c633e559376976561248bcb30c4739b76d (diff)
downloadspark-43ebf7a9cbd70d6af75e140a6fc91bf0ffc2b877.tar.gz
spark-43ebf7a9cbd70d6af75e140a6fc91bf0ffc2b877.tar.bz2
spark-43ebf7a9cbd70d6af75e140a6fc91bf0ffc2b877.zip
[SPARK-13456][SQL] fix creating encoders for case classes defined in Spark shell
## What changes were proposed in this pull request? case classes defined in REPL are wrapped by line classes, and we have a trick for scala 2.10 REPL to automatically register the wrapper classes to `OuterScope` so that we can use when create encoders. However, this trick doesn't work right after we upgrade to scala 2.11, and unfortunately the tests are only in scala 2.10, which makes this bug hidden until now. This PR moves the encoder tests to scala 2.11 `ReplSuite`, and fixes this bug by another approach(the previous trick can't port to scala 2.11 REPL): make `OuterScope` smarter that can detect classes defined in REPL and load the singleton of line wrapper classes automatically. ## How was this patch tested? the migrated encoder tests in `ReplSuite` Author: Wenchen Fan <wenchen@databricks.com> Closes #11410 from cloud-fan/repl.
Diffstat (limited to 'repl')
-rw-r--r--repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala4
-rw-r--r--repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala68
2 files changed, 68 insertions, 4 deletions
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index cbcccb11f1..6b9aa5071e 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -288,7 +288,7 @@ class ReplSuite extends SparkFunSuite {
|import org.apache.spark.sql.Encoder
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
- |val simpleSum = new Aggregator[Int, Int, Int] with Serializable {
+ |val simpleSum = new Aggregator[Int, Int, Int] {
| def zero: Int = 0 // The initial value.
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
@@ -347,7 +347,7 @@ class ReplSuite extends SparkFunSuite {
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|/** An `Aggregator` that adds up any numeric type returned by the given function. */
- |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+ |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
| val numeric = implicitly[Numeric[N]]
| override def zero: N = numeric.zero
| override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 6bee880640..f148a6df47 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -249,10 +249,32 @@ class ReplSuite extends SparkFunSuite {
// We need to use local-cluster to test this case.
val output = runInterpreter("local-cluster[1,1,1024]",
"""
- |val sqlContext = new org.apache.spark.sql.SQLContext(sc)
- |import sqlContext.implicits._
|case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
+ |
+ |// Test Dataset Serialization in the REPL
+ |Seq(TestCaseClass(1)).toDS().collect()
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ }
+
+ test("Datasets and encoders") {
+ val output = runInterpreter("local",
+ """
+ |import org.apache.spark.sql.functions._
+ |import org.apache.spark.sql.Encoder
+ |import org.apache.spark.sql.expressions.Aggregator
+ |import org.apache.spark.sql.TypedColumn
+ |val simpleSum = new Aggregator[Int, Int, Int] {
+ | def zero: Int = 0 // The initial value.
+ | def reduce(b: Int, a: Int) = b + a // Add an element to the running total
+ | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
+ | def finish(b: Int) = b // Return the final result.
+ |}.toColumn
+ |
+ |val ds = Seq(1, 2, 3, 4).toDS()
+ |ds.select(simpleSum).collect
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
@@ -295,6 +317,31 @@ class ReplSuite extends SparkFunSuite {
}
}
+ test("Datasets agg type-inference") {
+ val output = runInterpreter("local",
+ """
+ |import org.apache.spark.sql.functions._
+ |import org.apache.spark.sql.Encoder
+ |import org.apache.spark.sql.expressions.Aggregator
+ |import org.apache.spark.sql.TypedColumn
+ |/** An `Aggregator` that adds up any numeric type returned by the given function. */
+ |class SumOf[I, N : Numeric](f: I => N) extends
+ | org.apache.spark.sql.expressions.Aggregator[I, N, N] {
+ | val numeric = implicitly[Numeric[N]]
+ | override def zero: N = numeric.zero
+ | override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+ | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
+ | override def finish(reduction: N): N = reduction
+ |}
+ |
+ |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
+ |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
+ |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ }
+
test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
@@ -317,4 +364,21 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("Exception", output)
assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
}
+
+ test("line wrapper only initialized once when used as encoder outer scope") {
+ val output = runInterpreter("local",
+ """
+ |val fileName = "repl-test-" + System.currentTimeMillis
+ |val tmpDir = System.getProperty("java.io.tmpdir")
+ |val file = new java.io.File(tmpDir, fileName)
+ |def createFile(): Unit = file.createNewFile()
+ |
+ |createFile();case class TestCaseClass(value: Int)
+ |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
+ |
+ |file.delete()
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ }
}