-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala        | 36
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala | 28
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala   | 48
3 files changed, 106 insertions(+), 6 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index ed98a25415..7b75aeec4c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.Inner
@@ -78,9 +79,17 @@ class Dataset[T] private(
* ************* */
/**
- * Returns a new `Dataset` where each record has been mapped on to the specified type.
- * TODO: should bind here...
- * TODO: document binding rules
+ * Returns a new `Dataset` where each record has been mapped to the specified type. The
+ * method used to map columns depends on the type of `U`:
+ * - When `U` is a class, fields for the class will be mapped to columns of the same name
+ * (case sensitivity is determined by `spark.sql.caseSensitive`).
+ * - When `U` is a tuple, the columns will be mapped by ordinal (i.e. the first column will
+ * be assigned to `_1`).
+ * - When `U` is a primitive type (e.g. String, Int), the first column of the
+ * [[DataFrame]] will be used.
+ *
+ * If the schema of the [[DataFrame]] does not match the desired `U` type, you can use `select`
+ * along with `alias` or `as` to rearrange or rename as required.
* @since 1.6.0
*/
def as[U : Encoder]: Dataset[U] = {
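
As a usage sketch of the mapping rules above (assuming a 1.6-era `sqlContext` with `sqlContext.implicits._` imported; the `Person` class and the data are illustrative, not part of this patch):

    import sqlContext.implicits._

    case class Person(name: String, age: Int)

    val df = Seq(("Alice", 30), ("Bob", 25)).toDF("name", "age")

    // Class: columns are matched to fields by name.
    val people: Dataset[Person] = df.as[Person]

    // Tuple: columns are matched by ordinal, so column names do not matter.
    val pairs: Dataset[(String, Int)] = df.as[(String, Int)]

    // Primitive: the first column is used, so select it explicitly.
    val names: Dataset[String] = df.select($"name").as[String]

    // If the schema does not line up, rename with `as` (or `alias`) first.
    val raw = Seq(("Carol", 41)).toDF("first_name", "years")
    val fixed: Dataset[Person] =
      raw.select($"first_name".as("name"), $"years".as("age")).as[Person]
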
@@ -225,6 +234,27 @@ class Dataset[T] private(
withGroupingKey.newColumns)
}
+ /**
+ * Returns a [[GroupedDataset]] where the data is grouped by the given [[Column]] expressions.
+ * @since 1.6.0
+ */
+ @scala.annotation.varargs
+ def groupBy(cols: Column*): GroupedDataset[Row, T] = {
+ val withKeyColumns = logicalPlan.output ++ cols.map(_.expr).map(UnresolvedAlias)
+ val withKey = Project(withKeyColumns, logicalPlan)
+ val executed = sqlContext.executePlan(withKey)
+
+ val dataAttributes = executed.analyzed.output.dropRight(cols.size)
+ val keyAttributes = executed.analyzed.output.takeRight(cols.size)
+
+ new GroupedDataset(
+ RowEncoder(keyAttributes.toStructType),
+ encoderFor[T],
+ executed,
+ dataAttributes,
+ keyAttributes)
+ }
+
/* ****************** *
* Typed Relational *
* ****************** */
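
A minimal sketch of the new column-based `groupBy` (data is illustrative; the key stays an untyped `Row` until `asKey`, added in the next file, is applied):

    import sqlContext.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()

    // Grouping by a Column yields GroupedDataset[Row, T]: the key is untyped.
    val grouped: GroupedDataset[Row, (String, Int)] = ds.groupBy($"_1")

    // Keys are generic Rows, so fields are accessed positionally.
    val sums = grouped.mapGroups { case (key, rows) =>
      Iterator((key.getString(0), rows.map(_._2).sum))
    }
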
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
index 612f2b60cd..96d6e9dd54 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql
-import org.apache.spark.sql.catalyst.encoders.Encoder
+import org.apache.spark.sql.catalyst.encoders.{ExpressionEncoder, encoderFor, Encoder}
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.QueryExecution
@@ -34,12 +34,34 @@ class GroupedDataset[K, T] private[sql](
private val dataAttributes: Seq[Attribute],
private val groupingAttributes: Seq[Attribute]) extends Serializable {
- private implicit def kEnc = kEncoder
- private implicit def tEnc = tEncoder
+ private implicit val kEnc = kEncoder match {
+ case e: ExpressionEncoder[K] => e.resolve(groupingAttributes)
+ case other =>
+ throw new UnsupportedOperationException("Only expression encoders are currently supported")
+ }
+
+ private implicit val tEnc = tEncoder match {
+ case e: ExpressionEncoder[T] => e.resolve(dataAttributes)
+ case other =>
+ throw new UnsupportedOperationException("Only expression encoders are currently supported")
+ }
+
private def logicalPlan = queryExecution.analyzed
private def sqlContext = queryExecution.sqlContext
/**
+ * Returns a new [[GroupedDataset]] where the type of the key has been mapped to the specified
+ * type. The mapping of key columns to the type follows the same rules as `as` on [[Dataset]].
+ */
+ def asKey[L : Encoder]: GroupedDataset[L, T] =
+ new GroupedDataset(
+ encoderFor[L],
+ tEncoder,
+ queryExecution,
+ dataAttributes,
+ groupingAttributes)
+
+ /**
* Returns a [[Dataset]] that contains each unique key.
*/
def keys: Dataset[K] = {
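
A sketch of `asKey` upgrading the untyped `Row` key produced by the column-based `groupBy` (mirroring the tests below; `lit` comes from `org.apache.spark.sql.functions`):

    import org.apache.spark.sql.functions.lit
    import sqlContext.implicits._

    val ds = Seq(("a", 10), ("a", 20), ("b", 1)).toDS()

    // A single key column mapped to a primitive key type.
    val byString: GroupedDataset[String, (String, Int)] =
      ds.groupBy($"_1").asKey[String]

    val totals = byString.mapGroups { case (k, rows) =>
      Iterator((k, rows.map(_._2).sum))
    }

    // Multiple key columns map by ordinal to a tuple (or by name to a class).
    val byPair = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
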
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
index 95b8d05cf4..5973fa7f2a 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetSuite.scala
@@ -203,6 +203,54 @@ class DatasetSuite extends QueryTest with SharedSQLContext {
("a", 30), ("b", 3), ("c", 1))
}
+ test("groupBy columns, mapGroups") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ val grouped = ds.groupBy($"_1")
+ val agged = grouped.mapGroups { case (g, iter) =>
+ Iterator((g.getString(0), iter.map(_._2).sum))
+ }
+
+ checkAnswer(
+ agged,
+ ("a", 30), ("b", 3), ("c", 1))
+ }
+
+ test("groupBy columns asKey, mapGroups") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ val grouped = ds.groupBy($"_1").asKey[String]
+ val agged = grouped.mapGroups { case (g, iter) =>
+ Iterator((g, iter.map(_._2).sum))
+ }
+
+ checkAnswer(
+ agged,
+ ("a", 30), ("b", 3), ("c", 1))
+ }
+
+ test("groupBy columns asKey tuple, mapGroups") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ val grouped = ds.groupBy($"_1", lit(1)).asKey[(String, Int)]
+ val agged = grouped.mapGroups { case (g, iter) =>
+ Iterator((g, iter.map(_._2).sum))
+ }
+
+ checkAnswer(
+ agged,
+ (("a", 1), 30), (("b", 1), 3), (("c", 1), 1))
+ }
+
+ test("groupBy columns asKey class, mapGroups") {
+ val ds = Seq(("a", 10), ("a", 20), ("b", 1), ("b", 2), ("c", 1)).toDS()
+ val grouped = ds.groupBy($"_1".as("a"), lit(1).as("b")).asKey[ClassData]
+ val agged = grouped.mapGroups { case (g, iter) =>
+ Iterator((g, iter.map(_._2).sum))
+ }
+
+ checkAnswer(
+ agged,
+ (ClassData("a", 1), 30), (ClassData("b", 1), 3), (ClassData("c", 1), 1))
+ }
+
test("cogroup") {
val ds1 = Seq(1 -> "a", 3 -> "abc", 5 -> "hello", 3 -> "foo").toDS()
val ds2 = Seq(2 -> "q", 3 -> "w", 5 -> "e", 5 -> "r").toDS()