diff options
author | Wenchen Fan <wenchen@databricks.com> | 2015-11-11 10:21:53 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-11-11 10:21:53 -0800 |
commit | 9c57bc0efce0ac37d8319666f5a8d3e8dce7651c (patch) | |
tree | 444138b2d7cb55b4007689dbe68f18c4af15c809 | |
parent | c964fc101585171aee76996981fe2c9fdafc614e (diff) | |
download | spark-9c57bc0efce0ac37d8319666f5a8d3e8dce7651c.tar.gz spark-9c57bc0efce0ac37d8319666f5a8d3e8dce7651c.tar.bz2 spark-9c57bc0efce0ac37d8319666f5a8d3e8dce7651c.zip |
[SPARK-11656][SQL] support typed aggregate in project list
Insert `aEncoder` like we do in `agg`.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #9630 from cloud-fan/select.
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala | 20 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala | 11 |
2 files changed, 27 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index a7e5ab19bf..87dae6b331 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -21,14 +21,15 @@ import scala.collection.JavaConverters._ import org.apache.spark.annotation.Experimental import org.apache.spark.rdd.RDD -import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.api.java.function._ import org.apache.spark.sql.catalyst.encoders._ import org.apache.spark.sql.catalyst.expressions._ +import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias import org.apache.spark.sql.catalyst.plans.Inner import org.apache.spark.sql.catalyst.plans.logical._ import org.apache.spark.sql.execution.{Queryable, QueryExecution} +import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression import org.apache.spark.sql.types.StructType /** @@ -359,7 +360,7 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { - new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan)) + new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan)) } /** @@ -368,11 +369,12 @@ class Dataset[T] private[sql]( * that cast appropriately for the user facing interface. */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { - val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } + val withEncoders = columns.map(withEncoder) + val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() } val unresolvedPlan = Project(aliases, logicalPlan) val execution = new QueryExecution(sqlContext, unresolvedPlan) // Rebind the encoders to the nested schema that will be produced by the select. 
- val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { + val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map { case (e: ExpressionEncoder[_], a) if !e.flat => e.nested(a.toAttribute).resolve(execution.analyzed.output) case (e, a) => @@ -381,6 +383,16 @@ class Dataset[T] private[sql]( new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) } + private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = { + val e = c.expr transform { + case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => + ta.copy( + aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]), + children = queryExecution.analyzed.output) + } + new TypedColumn(e, c.encoder) + } + /** * Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element. * @since 1.6.0 diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala index 002d5c18f0..d4f0ab76cf 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala @@ -114,4 +114,15 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext { ComplexResultAgg.toColumn), ("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L))) } + + test("typed aggregation: in project list") { + val ds = Seq(1, 3, 2, 5).toDS() + + checkAnswer( + ds.select(sum((i: Int) => i)), + 11) + checkAnswer( + ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)), + 11 -> 22) + } } |