From 94a9d11ed1f61205af8067bf17d14dc93935ddf8 Mon Sep 17 00:00:00 2001
From: Sean Zhong
Date: Mon, 8 Aug 2016 22:20:54 +0800
Subject: [SPARK-16906][SQL] Adds auxiliary info like input class and input
 schema in TypedAggregateExpression

## What changes were proposed in this pull request?

This PR adds auxiliary info like input class and input schema in
TypedAggregateExpression.

## How was this patch tested?

Manual test.

Author: Sean Zhong

Closes #14501 from clockfly/typed_aggregation.
---
 sql/core/src/main/scala/org/apache/spark/sql/Column.scala        | 9 ++++++---
 sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala       | 4 ++--
 .../main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala | 2 +-
 .../scala/org/apache/spark/sql/RelationalGroupedDataset.scala    | 2 +-
 .../spark/sql/execution/aggregate/TypedAggregateExpression.scala | 4 ++++
 5 files changed, 14 insertions(+), 7 deletions(-)

diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index a46d1949e9..844ca7a8e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -69,12 +69,15 @@ class TypedColumn[-T, U](
    * on a decoded object.
    */
   private[sql] def withInputType(
-      inputDeserializer: Expression,
+      inputEncoder: ExpressionEncoder[_],
       inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
-    val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes)
+    val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
     val newExpr = expr transform {
       case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
-        ta.copy(inputDeserializer = Some(unresolvedDeserializer))
+        ta.copy(
+          inputDeserializer = Some(unresolvedDeserializer),
+          inputClass = Some(inputEncoder.clsTag.runtimeClass),
+          inputSchema = Some(inputEncoder.schema))
     }
     new TypedColumn[T, U](newExpr, encoder)
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9eef5cc5fe..c119df83b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1059,7 +1059,7 @@ class Dataset[T] private[sql](
   @Experimental
   def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
     implicit val encoder = c1.encoder
-    val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+    val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil,
       logicalPlan)
 
     if (encoder.flat) {
@@ -1078,7 +1078,7 @@ class Dataset[T] private[sql](
   protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named)
+      columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
     val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
     new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index a6867a67ee..65a725f3d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -201,7 +201,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
   protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
     val encoders = columns.map(_.encoder)
     val namedColumns =
-      columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named)
+      columns.map(_.withInputType(vExprEnc, dataAttributes).named)
     val keyColumn = if (kExprEnc.flat) {
       assert(groupingAttributes.length == 1)
       groupingAttributes.head
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 1aa5767038..7cfd1cdc7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -219,7 +219,7 @@ class RelationalGroupedDataset protected[sql](
   def agg(expr: Column, exprs: Column*): DataFrame = {
     toDF((expr +: exprs).map {
       case typed: TypedColumn[_, _] =>
-        typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr
+        typed.withInputType(df.exprEnc, df.logicalPlan.output).expr
       case c => c.expr
     })
   }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 2cdf4703a5..6f7f2f842c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -47,6 +47,8 @@ object TypedAggregateExpression {
     new TypedAggregateExpression(
       aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
       None,
+      None,
+      None,
       bufferSerializer,
       bufferDeserializer,
       outputEncoder.serializer,
@@ -62,6 +64,8 @@ object TypedAggregateExpression {
 case class TypedAggregateExpression(
     aggregator: Aggregator[Any, Any, Any],
     inputDeserializer: Option[Expression],
+    inputClass: Option[Class[_]],
+    inputSchema: Option[StructType],
     bufferSerializer: Seq[NamedExpression],
     bufferDeserializer: Expression,
     outputSerializer: Seq[Expression],
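
For context, here is a minimal sketch of the code path this patch touches. The `Data` class and the `SumI` aggregator are hypothetical examples; everything else uses the Spark 2.x APIs that appear in the diff. When the typed column below is resolved, `withInputType` now records the input class and input schema on the underlying `TypedAggregateExpression` in addition to the deserializer:

```scala
import org.apache.spark.sql.{Encoder, Encoders, SparkSession}
import org.apache.spark.sql.expressions.Aggregator

// Hypothetical input type and aggregator, for illustration only.
case class Data(i: Int, s: String)

object SumI extends Aggregator[Data, Long, Long] {
  def zero: Long = 0L
  def reduce(b: Long, a: Data): Long = b + a.i
  def merge(b1: Long, b2: Long): Long = b1 + b2
  def finish(b: Long): Long = b
  def bufferEncoder: Encoder[Long] = Encoders.scalaLong
  def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

val spark = SparkSession.builder().master("local[*]").getOrCreate()
import spark.implicits._

val ds = Seq(Data(1, "a"), Data(2, "b")).toDS()

// Dataset.select calls c1.withInputType(exprEnc, logicalPlan.output), which
// after this patch also fills in inputClass = classOf[Data] and
// inputSchema = Data's encoder schema on the TypedAggregateExpression.
val total = ds.select(SumI.toColumn).head()  // 3L
```

With the input class and schema carried on the expression itself, later planning code can presumably inspect the aggregate's input side directly instead of re-deriving it from the deserializer expression.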