aboutsummaryrefslogtreecommitdiff
path: root/repl
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-11-20 15:36:30 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-20 15:36:30 -0800
commit968acf3bd9a502fcad15df3e53e359695ae702cc (patch)
tree186858602964d9dc672f6d4e8709fb6addd90ef4 /repl
parent58b4e4f88a330135c4cec04a30d24ef91bc61d91 (diff)
downloadspark-968acf3bd9a502fcad15df3e53e359695ae702cc.tar.gz
spark-968acf3bd9a502fcad15df3e53e359695ae702cc.tar.bz2
spark-968acf3bd9a502fcad15df3e53e359695ae702cc.zip
[SPARK-11889][SQL] Fix type inference for GroupedDataset.agg in REPL
In this PR I delete a method that breaks type inference for aggregators (only in the REPL) The error when this method is present is: ``` <console>:38: error: missing parameter type for expanded function ((x$2) => x$2._2) ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect() ``` Author: Michael Armbrust <michael@databricks.com> Closes #9870 from marmbrus/dataset-repl-agg.
Diffstat (limited to 'repl')
-rw-r--r--repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala24
1 files changed, 24 insertions, 0 deletions
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 081aa03002..cbcccb11f1 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite {
}
}
+ test("Datasets agg type-inference") {
+ val output = runInterpreter("local",
+ """
+ |import org.apache.spark.sql.functions._
+ |import org.apache.spark.sql.Encoder
+ |import org.apache.spark.sql.expressions.Aggregator
+ |import org.apache.spark.sql.TypedColumn
+ |/** An `Aggregator` that adds up any numeric type returned by the given function. */
+ |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+ | val numeric = implicitly[Numeric[N]]
+ | override def zero: N = numeric.zero
+ | override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+ | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
+ | override def finish(reduction: N): N = reduction
+ |}
+ |
+ |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
+ |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
+ |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ }
+
test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""