about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
-rw-r--r-- repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala 24
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala 27
-rw-r--r-- sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java 8
3 files changed, 30 insertions, 29 deletions
diff --git a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
index 081aa03002..cbcccb11f1 100644
--- a/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
+++ b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
@@ -339,6 +339,30 @@ class ReplSuite extends SparkFunSuite {
}
}
+ test("Datasets agg type-inference") {
+ val output = runInterpreter("local",
+ """
+ |import org.apache.spark.sql.functions._
+ |import org.apache.spark.sql.Encoder
+ |import org.apache.spark.sql.expressions.Aggregator
+ |import org.apache.spark.sql.TypedColumn
+ |/** An `Aggregator` that adds up any numeric type returned by the given function. */
+ |class SumOf[I, N : Numeric](f: I => N) extends Aggregator[I, N, N] with Serializable {
+ | val numeric = implicitly[Numeric[N]]
+ | override def zero: N = numeric.zero
+ | override def reduce(b: N, a: I): N = numeric.plus(b, f(a))
+ | override def merge(b1: N,b2: N): N = numeric.plus(b1, b2)
+ | override def finish(reduction: N): N = reduction
+ |}
+ |
+ |def sum[I, N : Numeric : Encoder](f: I => N): TypedColumn[I, N] = new SumOf(f).toColumn
+ |val ds = Seq((1, 1, 2L), (1, 2, 3L), (1, 3, 4L), (2, 1, 5L)).toDS()
+ |ds.groupBy(_._1).agg(sum(_._2), sum(_._3)).collect()
+ """.stripMargin)
+ assertDoesNotContain("error:", output)
+ assertDoesNotContain("Exception", output)
+ }
+
test("collecting objects of class defined in repl") {
val output = runInterpreter("local[2]",
"""
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 6de3dd6265..263f049104 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -146,31 +146,10 @@ class GroupedDataset[K, T] private[sql](
reduce(f.call _)
}
- /**
- * Compute aggregates by specifying a series of aggregate columns, and return a [[DataFrame]].
- * We can call `as[T : Encoder]` to turn the returned [[DataFrame]] to [[Dataset]] again.
- *
- * The available aggregate methods are defined in [[org.apache.spark.sql.functions]].
- *
- * {{{
- * // Selects the age of the oldest employee and the aggregate expense for each department
- *
- * // Scala:
- * import org.apache.spark.sql.functions._
- * df.groupBy("department").agg(max("age"), sum("expense"))
- *
- * // Java:
- * import static org.apache.spark.sql.functions.*;
- * df.groupBy("department").agg(max("age"), sum("expense"));
- * }}}
- *
- * We can also use `Aggregator.toColumn` to pass in typed aggregate functions.
- *
- * @since 1.6.0
- */
+ // This is here to prevent us from adding overloads that would be ambiguous.
@scala.annotation.varargs
- def agg(expr: Column, exprs: Column*): DataFrame =
- groupedData.agg(withEncoder(expr), exprs.map(withEncoder): _*)
+ private def agg(exprs: Column*): DataFrame =
+ groupedData.agg(withEncoder(exprs.head), exprs.tail.map(withEncoder): _*)
private def withEncoder(c: Column): Column = c match {
case tc: TypedColumn[_, _] =>
diff --git a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
index ce40dd856f..f7249b8945 100644
--- a/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
+++ b/sql/core/src/test/java/test/org/apache/spark/sql/JavaDatasetSuite.java
@@ -404,11 +404,9 @@ public class JavaDatasetSuite implements Serializable {
grouped.agg(new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()));
Assert.assertEquals(Arrays.asList(tuple2("a", 3), tuple2("b", 3)), agged.collectAsList());
- Dataset<Tuple4<String, Integer, Long, Long>> agged2 = grouped.agg(
- new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()),
- expr("sum(_2)"),
- count("*"))
- .as(Encoders.tuple(Encoders.STRING(), Encoders.INT(), Encoders.LONG(), Encoders.LONG()));
+ Dataset<Tuple2<String, Integer>> agged2 = grouped.agg(
+ new IntSumOf().toColumn(Encoders.INT(), Encoders.INT()))
+ .as(Encoders.tuple(Encoders.STRING(), Encoders.INT()));
Assert.assertEquals(
Arrays.asList(
new Tuple4<>("a", 3, 3L, 2L),