about summary refs log tree commit diff
path: root/sql
diff options
context:
space:
mode:
author	Wenchen Fan <wenchen@databricks.com>	2015-11-11 10:21:53 -0800
committer	Michael Armbrust <michael@databricks.com>	2015-11-11 10:21:53 -0800
commit	9c57bc0efce0ac37d8319666f5a8d3e8dce7651c (patch)
tree	444138b2d7cb55b4007689dbe68f18c4af15c809 /sql
parent	c964fc101585171aee76996981fe2c9fdafc614e (diff)
downloadspark-9c57bc0efce0ac37d8319666f5a8d3e8dce7651c.tar.gz
spark-9c57bc0efce0ac37d8319666f5a8d3e8dce7651c.tar.bz2
spark-9c57bc0efce0ac37d8319666f5a8d3e8dce7651c.zip
[SPARK-11656][SQL] support typed aggregate in project list
insert `aEncoder` like we do in `agg` Author: Wenchen Fan <wenchen@databricks.com> Closes #9630 from cloud-fan/select.
Diffstat (limited to 'sql')
-rw-r--r--	sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala	20
-rw-r--r--	sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala	11
2 files changed, 27 insertions, 4 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index a7e5ab19bf..87dae6b331 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -21,14 +21,15 @@ import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.api.java.function._
import org.apache.spark.sql.catalyst.encoders._
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.analysis.UnresolvedAlias
import org.apache.spark.sql.catalyst.plans.Inner
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.execution.{Queryable, QueryExecution}
+import org.apache.spark.sql.execution.aggregate.TypedAggregateExpression
import org.apache.spark.sql.types.StructType
/**
@@ -359,7 +360,7 @@ class Dataset[T] private[sql](
* @since 1.6.0
*/
def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = {
- new Dataset[U1](sqlContext, Project(Alias(c1.expr, "_1")() :: Nil, logicalPlan))
+ new Dataset[U1](sqlContext, Project(Alias(withEncoder(c1).expr, "_1")() :: Nil, logicalPlan))
}
/**
@@ -368,11 +369,12 @@ class Dataset[T] private[sql](
* that cast appropriately for the user facing interface.
*/
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
- val aliases = columns.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
+ val withEncoders = columns.map(withEncoder)
+ val aliases = withEncoders.zipWithIndex.map { case (c, i) => Alias(c.expr, s"_${i + 1}")() }
val unresolvedPlan = Project(aliases, logicalPlan)
val execution = new QueryExecution(sqlContext, unresolvedPlan)
// Rebind the encoders to the nested schema that will be produced by the select.
- val encoders = columns.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
+ val encoders = withEncoders.map(_.encoder.asInstanceOf[ExpressionEncoder[_]]).zip(aliases).map {
case (e: ExpressionEncoder[_], a) if !e.flat =>
e.nested(a.toAttribute).resolve(execution.analyzed.output)
case (e, a) =>
@@ -381,6 +383,16 @@ class Dataset[T] private[sql](
new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders))
}
+ private def withEncoder(c: TypedColumn[_, _]): TypedColumn[_, _] = {
+ val e = c.expr transform {
+ case ta: TypedAggregateExpression if ta.aEncoder.isEmpty =>
+ ta.copy(
+ aEncoder = Some(encoder.asInstanceOf[ExpressionEncoder[Any]]),
+ children = queryExecution.analyzed.output)
+ }
+ new TypedColumn(e, c.encoder)
+ }
+
/**
* Returns a new [[Dataset]] by computing the given [[Column]] expressions for each element.
* @since 1.6.0
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
index 002d5c18f0..d4f0ab76cf 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DatasetAggregatorSuite.scala
@@ -114,4 +114,15 @@ class DatasetAggregatorSuite extends QueryTest with SharedSQLContext {
ComplexResultAgg.toColumn),
("a", 2.0, (2L, 4L)), ("b", 3.0, (1L, 3L)))
}
+
+ test("typed aggregation: in project list") {
+ val ds = Seq(1, 3, 2, 5).toDS()
+
+ checkAnswer(
+ ds.select(sum((i: Int) => i)),
+ 11)
+ checkAnswer(
+ ds.select(sum((i: Int) => i), sum((i: Int) => i * 2)),
+ 11 -> 22)
+ }
}