aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMichael Armbrust <michael@databricks.com>2015-11-05 21:42:32 -0800
committerMichael Armbrust <michael@databricks.com>2015-11-05 21:42:32 -0800
commit363a476c3fefb0263e63fd24df0b2779a64f79ec (patch)
tree74a240205a164f87b8b66a3deefb94c21a25b367
parenteec74ba8bde7f9446cc38e687bda103e85669d35 (diff)
downloadspark-363a476c3fefb0263e63fd24df0b2779a64f79ec.tar.gz
spark-363a476c3fefb0263e63fd24df0b2779a64f79ec.tar.bz2
spark-363a476c3fefb0263e63fd24df0b2779a64f79ec.zip
[SPARK-11528] [SQL] Typed aggregations for Datasets
This PR adds the ability to do typed SQL aggregations. We will likely also want to provide an interface to allow users to do aggregations on objects, but this is deferred to another PR. ```scala val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS() ds.groupBy(_._1).agg(sum("_2").as[Int]).collect() res0: Array(("a", 30), ("b", 3), ("c", 1)) ``` Author: Michael Armbrust <michael@databricks.com> Closes #9499 from marmbrus/dataset-agg.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala93
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala36
4 files changed, 132 insertions, 3 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index 8957df0be6..9ab5c299d0 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -254,6 +254,10 @@ case class AttributeReference(
}
override def toString: String = s"$name#${exprId.id}$typeSuffix"
+
+ // Since the expression id is not in the first constructor it is missing from the default
+ // tree string.
+ override def simpleString: String = s"$name#${exprId.id}: ${dataType.simpleString}"
}
/**
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 500227e93a..4bca9c3b3f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -55,7 +55,7 @@ import org.apache.spark.sql.types.StructType
* @since 1.6.0
*/
@Experimental
-class Dataset[T] private(
+class Dataset[T] private[sql](
@transient val sqlContext: SQLContext,
@transient val queryExecution: QueryExecution,
unresolvedEncoder: Encoder[T]) extends Serializable {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 96d6e9dd54..b8fc373dff 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,16 +17,25 @@
package org.apache.spark.sql
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAlias, UnresolvedAttribute}
import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
-import org.apache.spark.sql.catalyst.expressions.Attribute
+import org.apache.spark.sql.catalyst.expressions.{Expression, NamedExpression, Alias, Attribute}
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
/**
+ * :: Experimental ::
* A [[Dataset]] has been logically grouped by a user specified grouping key. Users should not
* construct a [[GroupedDataset]] directly, but should instead call `groupBy` on an existing
* [[Dataset]].
+ *
+ * COMPATIBILITY NOTE: Long term we plan to make [[GroupedDataset)]] extend `GroupedData`. However,
+ * making this change to the class hierarchy would break some function signatures. As such, this
+ * class should be considered a preview of the final API. Changes will be made to the interface
+ * after Spark 1.6.
*/
+@Experimental
class GroupedDataset[K, T] private[sql](
private val kEncoder: Encoder[K],
private val tEncoder: Encoder[T],
@@ -35,7 +44,7 @@ class GroupedDataset[K, T] private[sql](
private val groupingAttributes: Seq[Attribute]) extends Serializable {
private implicit val kEnc = kEncoder match {
- case e: ExpressionEncoder[K] => e.resolve(groupingAttributes)
+ case e: ExpressionEncoder[K] => e.unbind(groupingAttributes).resolve(groupingAttributes)
case other =>
throw new UnsupportedOperationException("Only expression encoders are currently supported")
}
@@ -46,9 +55,16 @@ class GroupedDataset[K, T] private[sql](
throw new UnsupportedOperationException("Only expression encoders are currently supported")
}
+ /** Encoders for built in aggregations. */
+ private implicit def newLongEncoder: Encoder[Long] = ExpressionEncoder[Long](flat = true)
+
private def logicalPlan = queryExecution.analyzed
private def sqlContext = queryExecution.sqlContext
+ private def groupedData =
+ new GroupedData(
+ new DataFrame(sqlContext, logicalPlan), groupingAttributes, GroupedData.GroupByType)
+
/**
* Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
* type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
@@ -88,6 +104,79 @@ class GroupedDataset[K, T] private[sql](
MapGroups(f, groupingAttributes, logicalPlan))
}
+ // To ensure valid overloading.
+ protected def agg(expr: Column, exprs: Column*): DataFrame =
+ groupedData.agg(expr, exprs: _*)
+
+ /**
+ * Internal helper function for building typed aggregations that return tuples. For simplicity
+ * and code reuse, we do this without the help of the type system and then use helper functions
+ * that cast appropriately for the user facing interface.
+ * TODO: does not handle aggrecations that return nonflat results,
+ */
+ protected def aggUntyped(columns: TypedColumn[_]*): Dataset[_] = {
+ val aliases = (groupingAttributes ++ columns.map(_.expr)).map {
+ case u: UnresolvedAttribute => UnresolvedAlias(u)
+ case expr: NamedExpression => expr
+ case expr: Expression => Alias(expr, expr.prettyString)()
+ }
+
+ val unresolvedPlan = Aggregate(groupingAttributes, aliases, logicalPlan)
+ val execution = new QueryExecution(sqlContext, unresolvedPlan)
+
+ val columnEncoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]])
+
+ // Rebind the encoders to the nested schema that will be produced by the aggregation.
+ val encoders = (kEnc +: columnEncoders).zip(execution.analyzed.output).map {
+ case (e: ExpressionEncoder[_], a) if !e.flat =>
+ e.nested(a).resolve(execution.analyzed.output)
+ case (e, a) =>
+ e.unbind(a :: Nil).resolve(execution.analyzed.output)
+ }
+ new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
+ }
+
+ /**
+ * Computes the given aggregation, returning a [[Dataset]] of tuples for each unique key
+ * and the result of computing this aggregation over all elements in the group.
+ */
+ def agg[A1](col1: TypedColumn[A1]): Dataset[(K, A1)] =
+ aggUntyped(col1).asInstanceOf[Dataset[(K, A1)]]
+
+ /**
+ * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+ * and the result of computing these aggregations over all elements in the group.
+ */
+ def agg[A1, A2](col1: TypedColumn[A1], col2: TypedColumn[A2]): Dataset[(K, A1, A2)] =
+ aggUntyped(col1, col2).asInstanceOf[Dataset[(K, A1, A2)]]
+
+ /**
+ * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+ * and the result of computing these aggregations over all elements in the group.
+ */
+ def agg[A1, A2, A3](
+ col1: TypedColumn[A1],
+ col2: TypedColumn[A2],
+ col3: TypedColumn[A3]): Dataset[(K, A1, A2, A3)] =
+ aggUntyped(col1, col2, col3).asInstanceOf[Dataset[(K, A1, A2, A3)]]
+
+ /**
+ * Computes the given aggregations, returning a [[Dataset]] of tuples for each unique key
+ * and the result of computing these aggregations over all elements in the group.
+ */
+ def agg[A1, A2, A3, A4](
+ col1: TypedColumn[A1],
+ col2: TypedColumn[A2],
+ col3: TypedColumn[A3],
+ col4: TypedColumn[A4]): Dataset[(K, A1, A2, A3, A4)] =
+ aggUntyped(col1, col2, col3, col4).asInstanceOf[Dataset[(K, A1, A2, A3, A4)]]
+
+ /**
+ * Returns a [[Dataset]] that contains a tuple with each key and the number of items present
+ * for that key.
+ */
+ def count(): Dataset[(K, Long)] = agg(functions.count("*").as[Long])
+
/**
* Applies the given function to each cogrouped data. For each unique group, the function will
* be passed the grouping key and 2 iterators containing all elements in the group from
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 3e9b621cfd..d61e17edc6 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -258,6 +258,42 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
(ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
}
+ test("typed aggregation: expr") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+ checkAnswer(
+ ds.groupBy(_._1).agg(sum("_2").as[Int]),
+ ("a", 30), ("b", 3), ("c", 1))
+ }
+
+ test("typed aggregation: expr, expr") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+ checkAnswer(
+ ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long]),
+ ("a", 30, 32L), ("b", 3, 5L), ("c", 1, 2L))
+ }
+
+ test("typed aggregation: expr, expr, expr") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+ checkAnswer(
+ ds.groupBy(_._1).agg(sum("_2").as[Int], sum($"_2" + 1).as[Long], count("*").as[Long]),
+ ("a", 30, 32L, 2L), ("b", 3, 5L, 2L), ("c", 1, 2L, 1L))
+ }
+
+ test("typed aggregation: expr, expr, expr, expr") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+
+ checkAnswer(
+ ds.groupBy(_._1).agg(
+ sum("_2").as[Int],
+ sum($"_2" + 1).as[Long],
+ count("*").as[Long],
+ avg("_2").as[Double]),
+ ("a", 30, 32L, 2L, 15.0), ("b", 3, 5L, 2L, 1.5), ("c", 1, 2L, 1L, 1.0))
+ }
+
test("cogroup") {
val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()