Diffstat (limited to 'repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala')
-rw-r--r--  repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala  56
1 file changed, 30 insertions(+), 26 deletions(-)
diff --git a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index dbfacba346..d3dafe9c42 100644
--- a/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.11/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -267,7 +267,7 @@ class ReplSuite extends SparkFunSuite {
val output = runInterpreter("local",
"""
|import org.apache.spark.sql.functions._
- |import org.apache.spark.sql.Encoder
+ |import org.apache.spark.sql.{Encoder, Encoders}
|import org.apache.spark.sql.expressions.Aggregator
|import org.apache.spark.sql.TypedColumn
|val simpleSum = new Aggregator[Int, Int, Int] {
@@ -275,6 +275,8 @@ class ReplSuite extends SparkFunSuite {
| def reduce(b: Int, a: Int) = b + a // Add an element to the running total
| def merge(b1: Int, b2: Int) = b1 + b2 // Merge intermediate values.
| def finish(b: Int) = b // Return the final result.
+ | def bufferEncoder: Encoder[Int] = Encoders.scalaInt
+ | def outputEncoder: Encoder[Int] = Encoders.scalaInt
|}.toColumn
|
|val ds = Seq(1, 2, 3, 4).toDS()
@@ -321,31 +323,6 @@ class ReplSuite extends SparkFunSuite {
}
}
- test("Datasets agg type-inference") {
- val output = runInterpreter("local",
- """
- |import org.apache.spark.sql.functions._
- |import org.apache.spark.sql.Encoder
- |import org.apache.spark.sql.expressions.Aggregator
- |import org.apache.spark.sql.TypedColumn
- |/** An `Aggregator` that adds up any numeric type returned by the given function. */
- |class SumOf[I, N : Numeric](f: I => N) extends
- | org.apache.spark.sql.expressions.Aggregator[I, N, N] {
- | val numeric = implicitly[Numeric[N]]
- | override def zero: N = numeric.zero
- | override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
- | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
- | override def finish(reduction: N): N = reduction
- |}
- |
- |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
- |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
- |ds.groupByKey(_._1).agg(sum(_._2), sum(_._3)).collect()
- """.stripMargin)
- assertDoesNotContain("error:", output)
- assertDoesNotContain("Exception", output)
- }
-
test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
@@ -396,4 +373,31 @@ class ReplSuite extends SparkFunSuite {
assertDoesNotContain("error:", output)
assertDoesNotContain("Exception", output)
}
+
+ test("should clone and clean line object in ClosureCleaner") {
+ val output = runInterpreterInPasteMode("local-cluster[1,4,4096]",
+ """
+ |import org.apache.spark.rdd.RDD
+ |
+ |val lines = sc.textFile("pom.xml")
+ |case class Data(s: String)
+ |val dataRDD = lines.map(line => Data(line.take(3)))
+ |dataRDD.cache.count
+ |val repartitioned = dataRDD.repartition(dataRDD.partitions.size)
+ |repartitioned.cache.count
+ |
+ |def getCacheSize(rdd: RDD[_]) = {
+ | sc.getRDDStorageInfo.filter(_.id == rdd.id).map(_.memSize).sum
+ |}
+ |val cacheSize1 = getCacheSize(dataRDD)
+ |val cacheSize2 = getCacheSize(repartitioned)
+ |
+ |// The cache size of dataRDD and the repartitioned one should be similar.
+ |val deviation = math.abs(cacheSize2 - cacheSize1).toDouble / cacheSize1
+ |assert(deviation < 0.2,
+ | s"deviation too large: $deviation, first size: $cacheSize1, second size: $cacheSize2")
+ """.stripMargin)
+ assertDoesNotContain("AssertionError", output)
+ assertDoesNotContain("Exception", output)
+ }
}
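
The first hunk tracks an API change in Aggregator: implementations must now supply explicit bufferEncoder and outputEncoder definitions alongside zero, reduce, merge, and finish. This is also why the removed "Datasets agg type-inference" test, whose SumOf relied on a Numeric : Encoder context bound to infer encoders, no longer applies. Below is a minimal, self-contained sketch of the same aggregator outside the REPL, assuming a Spark 2.x SparkSession; the object names and the local master are illustrative, not part of this patch:

import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

// Same shape as the aggregator in the test: sum Ints with an Int buffer.
object SimpleSum extends Aggregator[Int, Int, Int] {
  def zero: Int = 0                          // Initial value.
  def reduce(b: Int, a: Int): Int = b + a    // Add an element to the running total.
  def merge(b1: Int, b2: Int): Int = b1 + b2 // Merge intermediate values.
  def finish(b: Int): Int = b                // Return the final result.
  // The members this patch adds to the test: explicit encoders for the
  // buffer and output types instead of relying on implicit inference.
  def bufferEncoder: Encoder[Int] = Encoders.scalaInt
  def outputEncoder: Encoder[Int] = Encoders.scalaInt
}

object SimpleSumApp {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local").appName("SimpleSum").getOrCreate()
    import spark.implicits._
    val ds = Seq(1, 2, 3, 4).toDS()
    // toColumn turns the Aggregator into a TypedColumn usable with select/agg.
    println(ds.select(SimpleSum.toColumn).collect().mkString(", ")) // expected: 10
    spark.stop()
  }
}

The new test added in the last hunk exercises ClosureCleaner's cloning and cleaning of the REPL line object: it caches an RDD of a REPL-defined case class, repartitions and caches it again, and asserts that the two cached sizes stay within 20% of each other, which would presumably fail if the serialized records dragged the whole line object along.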