about summary refs log tree commit diff
path: root/sql/core
diff options
context:
space:
mode:
authorSean Zhong <seanzhong@databricks.com>2016-08-08 22:20:54 +0800
committerWenchen Fan <wenchen@databricks.com>2016-08-08 22:20:54 +0800
commit94a9d11ed1f61205af8067bf17d14dc93935ddf8 (patch)
tree019e8403fe4dec7d62eb25384b0da55511809905 /sql/core
parent06f5dc841517e7156f5f445655d97ba541ebbd7e (diff)
downloadspark-94a9d11ed1f61205af8067bf17d14dc93935ddf8.tar.gz
spark-94a9d11ed1f61205af8067bf17d14dc93935ddf8.tar.bz2
spark-94a9d11ed1f61205af8067bf17d14dc93935ddf8.zip
[SPARK-16906][SQL] Adds auxiliary info like input class and input schema in TypedAggregateExpression
## What changes were proposed in this pull request?

This PR adds auxiliary info like input class and input schema in TypedAggregateExpression.

## How was this patch tested?

Manual test.

Author: Sean Zhong <seanzhong@databricks.com>

Closes #14501 from clockfly/typed_aggregation.
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Column.scala9
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala4
5 files changed, 14 insertions, 7 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
index a46d1949e9..844ca7a8e9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala
@@ -69,12 +69,15 @@ class TypedColumn[-T, U](
* on a decoded object.
*/
private[sql] def withInputType(
- inputDeserializer: Expression,
+ inputEncoder: ExpressionEncoder[_],
inputAttributes: Seq[Attribute]): TypedColumn[T, U] = {
- val unresolvedDeserializer = UnresolvedDeserializer(inputDeserializer, inputAttributes)
+ val unresolvedDeserializer = UnresolvedDeserializer(inputEncoder.deserializer, inputAttributes)
val newExpr = expr transform {
case ta: TypedAggregateExpression if ta.inputDeserializer.isEmpty =>
- ta.copy(inputDeserializer = Some(unresolvedDeserializer))
+ ta.copy(
+ inputDeserializer = Some(unresolvedDeserializer),
+ inputClass = Some(inputEncoder.clsTag.runtimeClass),
+ inputSchema = Some(inputEncoder.schema))
}
new TypedColumn[T, U](newExpr, encoder)
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
index 9eef5cc5fe..c119df83b3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala
@@ -1059,7 +1059,7 @@ class Dataset[T] private[sql](
@Experimental
def select[U1](c1: TypedColumn[T, U1]): Dataset[U1] = {
implicit val encoder = c1.encoder
- val project = Project(c1.withInputType(exprEnc.deserializer, logicalPlan.output).named :: Nil,
+ val project = Project(c1.withInputType(exprEnc, logicalPlan.output).named :: Nil,
logicalPlan)
if (encoder.flat) {
@@ -1078,7 +1078,7 @@ class Dataset[T] private[sql](
protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
- columns.map(_.withInputType(exprEnc.deserializer, logicalPlan.output).named)
+ columns.map(_.withInputType(exprEnc, logicalPlan.output).named)
val execution = new QueryExecution(sparkSession, Project(namedColumns, logicalPlan))
new Dataset(sparkSession, execution, ExpressionEncoder.tuple(encoders))
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
index a6867a67ee..65a725f3d4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/KeyValueGroupedDataset.scala
@@ -201,7 +201,7 @@ class KeyValueGroupedDataset[K, V] private[sql](
protected def aggUntyped(columns: TypedColumn[_, _]*): Dataset[_] = {
val encoders = columns.map(_.encoder)
val namedColumns =
- columns.map(_.withInputType(vExprEnc.deserializer, dataAttributes).named)
+ columns.map(_.withInputType(vExprEnc, dataAttributes).named)
val keyColumn = if (kExprEnc.flat) {
assert(groupingAttributes.length == 1)
groupingAttributes.head
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
index 1aa5767038..7cfd1cdc7d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/RelationalGroupedDataset.scala
@@ -219,7 +219,7 @@ class RelationalGroupedDataset protected[sql](
def agg(expr: Column, exprs: Column*): DataFrame = {
toDF((expr +: exprs).map {
case typed: TypedColumn[_, _] =>
- typed.withInputType(df.exprEnc.deserializer, df.logicalPlan.output).expr
+ typed.withInputType(df.exprEnc, df.logicalPlan.output).expr
case c => c.expr
})
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
index 2cdf4703a5..6f7f2f842c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TypedAggregateExpression.scala
@@ -47,6 +47,8 @@ object TypedAggregateExpression {
new TypedAggregateExpression(
aggregator.asInstanceOf[Aggregator[Any, Any, Any]],
None,
+ None,
+ None,
bufferSerializer,
bufferDeserializer,
outputEncoder.serializer,
@@ -62,6 +64,8 @@ object TypedAggregateExpression {
case class TypedAggregateExpression(
aggregator: Aggregator[Any, Any, Any],
inputDeserializer: Option[Expression],
+ inputClass: Option[Class[_]],
+ inputSchema: Option[StructType],
bufferSerializer: Seq[NamedExpression],
bufferDeserializer: Expression,
outputSerializer: Seq[Expression],