Diffstat (limited to 'repl')
-rw-r--r--  repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala |  4
-rw-r--r--  repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala | 68
2 files changed, 68 insertions(+), 4 deletions(-)
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index cbcccb11f1..6b9aa5071e 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -288,7 +288,7 @@ class ReplSuite extends SparkFunSuite {
|import org.apache.spark.sql.Encoder
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
- |val simpleSum = new Aggregator[Int, Int, Int] with Serializable {
+ |val simpleSum = new Aggregator[Int, Int, Int] {
| def zero: Int = 0 // The initial value.
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
@@ -347,7 +347,7 @@ class ReplSuite extends SparkFunSuite {
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|/** An `Aggregator` that adds up any numeric type returned by the given function. */
- |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+ |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] {
| val numeric = implicitly[Numeric[N]]
| override def zero: N = numeric.zero
| override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
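The only functional change on the Scala 2.10 side is dropping the redundant `with Serializable` mixin. A minimal sketch of why that is safe, assuming the `Aggregator` base class of this era already extends `Serializable` (the object and the round-trip below are illustrative, not part of the patch):

    import java.io.{ByteArrayOutputStream, ObjectOutputStream}
    import org.apache.spark.sql.expressions.Aggregator

    // A top-level object, so no non-serializable outer reference is captured.
    object SimpleSum extends Aggregator[Int, Int, Int] {
      def zero: Int = 0
      def reduce(b: Int, a: Int): Int = b + a
      def merge(b1: Int, b2: Int): Int = b1 + b2
      def finish(b: Int): Int = b
    }

    // Succeeds with no NotSerializableException because Aggregator itself
    // extends Serializable; the explicit mixin adds nothing.
    new ObjectOutputStream(new ByteArrayOutputStream()).writeObject(SimpleSum)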
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 6bee880640..f148a6df47 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -249,10 +249,32 @@ class ReplSuite extends SparkFunSuite {
// We need to use local-cluster to test this case.
val output = runInterpreter("local-cluster[1,1,1024]",
"""
- |val sqlContext = new org.apache.spark.sql.SQLContext(sc)
- |import sqlContext.implicits._
|case class TestCaseClass(value: Int)
|sc.parallelize(1 to 10).map(x => TestCaseClass(x)).toDF().collect()
+ |
+ |// Test Dataset Serialization in the REPL
+ |Seq(TestCaseClass(1)).toDS().collect()
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ }
+
+ test("Datasets and encoders") {
+ val output = runInterpreter("local",
+ """
+ |import org.apache.spark.sql.functions._
+ |import org.apache.spark.sql.Encoder
+ |import org.apache.spark.sql.expressions.Aggregator
+ |import org.apache.spark.sql.TypedColumn
+ |val simpleSum = new Aggregator[Int, Int, Int] {
+ | def zero: Int = 0 // The initial value.
+ | def reduce(b: Int, a: Int) = b + a // Add an element to the running total
+ | def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
+ | def finish(b: Int) = b // Return the final result.
+ |}.toColumn
+ |
+ |val ds = Seq(1, 2, 3, 4).toDS()
+ |ds.select(simpleSum).collect
""".stripMargin)
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
@@ -295,6 +317,31 @@ class ReplSuite extends SparkFunSuite {
}
}

+ test("Datasets agg type-inference") {
+ val output = runInterpreter("local",
+ """
+ |import org.apache.spark.sql.functions._
+ |import org.apache.spark.sql.Encoder
+ |import org.apache.spark.sql.expressions.Aggregator
+ |import org.apache.spark.sql.TypedColumn
+ |/** An `Aggregator` that adds up any numeric type returned by the given function. */
+ |class SumOf[I, N : Numeric](f: I => N) extends
+ | org.apache.spark.sql.expressions.Aggregator[I, N, N] {
+ | val numeric = implicitly[Numeric[N]]
+ | override def zero: N = numeric.zero
+ | override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+ | override def merge(b1: N, b2: N): N = numeric.plus(b1, b2)
+ | override def finish(reduction: N): N = reduction
+ |}
+ |
+ |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
+ |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
+ |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ }
+
test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
@@ -317,4 +364,21 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("Exception", output)
assertContains("ret: Array[(Int, Iterable[Foo])] = Array((1,", output)
}
+
+ test("line wrapper only initialized once when used as encoder outer scope") {
+ val output = runInterpreter("local",
+ """
+ |val fileName = "repl-test-" + System.currentTimeMillis
+ |val tmpDir = System.getProperty("java.io.tmpdir")
+ |val file = new java.io.File(tmpDir, fileName)
+ |def createFile(): Unit = file.createNewFile()
+ |
+ |createFile();case class TestCaseClass(value: Int)
+ |sc.parallelize(1 to 10).map(x => TestCaseClass(x)).collect()
+ |
+ |file.delete()
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ }
}
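The last test pins down the regression being guarded: when a case class defined on a REPL line is used as an encoder's outer scope, deserialization must reuse the already-initialized line wrapper rather than re-running it, which would execute `createFile()` a second time. Outside the REPL, the analogous registration looks roughly like the sketch below; `OuterScopes.addOuterScope` is the catalyst hook for reusing an existing outer instance, though wiring it this way is an illustrative assumption, not code from this patch:

    import org.apache.spark.sql.catalyst.encoders.OuterScopes

    class Wrapper {
      // Side effect that must run exactly once, like createFile() in the test.
      println("initializing Wrapper")

      case class TestCaseClass(value: Int)

      // Register this instance so encoders for TestCaseClass reuse it
      // instead of constructing (and re-initializing) a fresh outer.
      OuterScopes.addOuterScope(this)
    }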