diff options
author | Wenchen Fan <wenchen@databricks.com> | 2015-11-13 11:13:09 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-11-13 11:13:09 -0800 |
commit | 23b8188f75d945ef70fbb1c4dc9720c2c5f8cbc3 (patch) | |
tree | 637f94c371b0f360eddc48837190015089fda784 /sql/core/src/main/scala/org/apache | |
parent | a24477996e936b0861819ffb420f763f80f0b1da (diff) | |
download | spark-23b8188f75d945ef70fbb1c4dc9720c2c5f8cbc3.tar.gz spark-23b8188f75d945ef70fbb1c4dc9720c2c5f8cbc3.tar.bz2 spark-23b8188f75d945ef70fbb1c4dc9720c2c5f8cbc3.zip |
[SPARK-11654][SQL][FOLLOW-UP] fix some mistakes and clean up
* rename `AppendColumn` to `AppendColumns` to be consistent with the physical plan name.
* clean up stale comments.
* always pass in a resolved encoder to `TypedColumn.withInputType` (test added)
* enable a mistakenly disabled Java test.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #9688 from cloud-fan/follow.
Diffstat (limited to 'sql/core/src/main/scala/org/apache')
4 files changed, 8 insertions, 11 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala index 929224460d..82e9cd7f50 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Column.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Column.scala @@ -58,10 +58,11 @@ class TypedColumn[-T, U]( private[sql] def withInputType( inputEncoder: ExpressionEncoder[_], schema: Seq[Attribute]): TypedColumn[T, U] = { + val boundEncoder = inputEncoder.bind(schema).asInstanceOf[ExpressionEncoder[Any]] new TypedColumn[T, U] (expr transform { case ta: TypedAggregateExpression if ta.aEncoder.isEmpty => ta.copy( - aEncoder = Some(inputEncoder.asInstanceOf[ExpressionEncoder[Any]]), + aEncoder = Some(boundEncoder), children = schema) }, encoder) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala index b930e4661c..4cc3aa2465 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/Dataset.scala @@ -299,7 +299,7 @@ class Dataset[T] private[sql]( */ def groupBy[K : Encoder](func: T => K): GroupedDataset[K, T] = { val inputPlan = queryExecution.analyzed - val withGroupingKey = AppendColumn(func, resolvedTEncoder, inputPlan) + val withGroupingKey = AppendColumns(func, resolvedTEncoder, inputPlan) val executed = sqlContext.executePlan(withGroupingKey) new GroupedDataset( @@ -364,13 +364,11 @@ class Dataset[T] private[sql]( * @since 1.6.0 */ def select[U1: Encoder](c1: TypedColumn[T, U1]): Dataset[U1] = { - // We use an unbound encoder since the expression will make up its own schema. - // TODO: This probably doesn't work if we are relying on reordering of the input class fields. 
new Dataset[U1]( sqlContext, Project( c1.withInputType( - resolvedTEncoder.bind(queryExecution.analyzed.output), + resolvedTEncoder, queryExecution.analyzed.output).named :: Nil, logicalPlan)) } @@ -382,10 +380,8 @@ class Dataset[T] private[sql]( */ protected def selectUntyped(columns: TypedColumn[_, _]*): Dataset[_] = { val encoders = columns.map(_.encoder) - // We use an unbound encoder since the expression will make up its own schema. - // TODO: This probably doesn't work if we are relying on reordering of the input class fields. val namedColumns = - columns.map(_.withInputType(unresolvedTEncoder, queryExecution.analyzed.output).named) + columns.map(_.withInputType(resolvedTEncoder, queryExecution.analyzed.output).named) val execution = new QueryExecution(sqlContext, Project(namedColumns, logicalPlan)) new Dataset(sqlContext, execution, ExpressionEncoder.tuple(encoders)) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala index ae1272ae53..9c16940707 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/GroupedDataset.scala @@ -89,7 +89,7 @@ class GroupedDataset[K, T] private[sql]( } /** - * Applies the given function to each group of data. For each unique group, the function will + * Applies the given function to each group of data. For each unique group, the function will * be passed the group key and an iterator that contains all of the elements in the group. The * function can return an iterator containing elements of an arbitrary type which will be returned * as a new [[Dataset]]. 
@@ -162,7 +162,7 @@ class GroupedDataset[K, T] private[sql]( val encoders = columns.map(_.encoder) val namedColumns = columns.map( - _.withInputType(resolvedTEncoder.bind(dataAttributes), dataAttributes).named) + _.withInputType(resolvedTEncoder, dataAttributes).named) val aggregate = Aggregate(groupingAttributes, groupingAttributes ++ namedColumns, logicalPlan) val execution = new QueryExecution(sqlContext, aggregate) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index a99ae4674b..67201a2c19 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -321,7 +321,7 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { case logical.MapPartitions(f, tEnc, uEnc, output, child) => execution.MapPartitions(f, tEnc, uEnc, output, planLater(child)) :: Nil - case logical.AppendColumn(f, tEnc, uEnc, newCol, child) => + case logical.AppendColumns(f, tEnc, uEnc, newCol, child) => execution.AppendColumns(f, tEnc, uEnc, newCol, planLater(child)) :: Nil case logical.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, child) => execution.MapGroups(f, kEnc, tEnc, uEnc, grouping, output, planLater(child)) :: Nil |