diff options
author | Cheng Hao <hao.cheng@intel.com> | 2015-04-21 15:11:15 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2015-04-21 15:11:15 -0700 |
commit | 7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04 (patch) | |
tree | d9d0b5821fa04de211f0072d860b57cf95c69a74 /sql/core/src/main | |
parent | 2a24bf92e6d36e876bad6a8b4e0ff12c407ebb8a (diff) | |
download | spark-7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04.tar.gz spark-7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04.tar.bz2 spark-7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04.zip |
[SPARK-5817] [SQL] Fix bug of udtf with column names
It's a bug while do query like:
```sql
select d from (select explode(array(1,1)) d from src limit 1) t
```
And it will throws exception like:
```
org.apache.spark.sql.AnalysisException: cannot resolve 'd' given input columns _c0; line 1 pos 7
at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:48)
at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:45)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250)
at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250)
at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:50)
at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:249)
at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$transformExpressionUp$1(QueryPlan.scala:103)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2$$anonfun$apply$2.apply(QueryPlan.scala:117)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244)
at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59)
at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47)
at scala.collection.TraversableLike$class.map(TraversableLike.scala:244)
at scala.collection.AbstractTraversable.map(Traversable.scala:105)
at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:116)
at scala.collection.Iterator$$anon$11.next(Iterator.scala:328)
```
To solve the bug, it requires code refactoring for UDTF
The major changes are about:
* Simplifying the UDTF development, UDTF will manage the output attribute names any more, instead, the `logical.Generate` will handle that properly.
* UDTF will be asked for the output schema (data types) during the logical plan analyzing.
Author: Cheng Hao <hao.cheng@intel.com>
Closes #4602 from chenghao-intel/explode_bug and squashes the following commits:
c2a5132 [Cheng Hao] add back resolved for Alias
556e982 [Cheng Hao] revert the unncessary change
002c361 [Cheng Hao] change the rule of resolved for Generate
04ae500 [Cheng Hao] add qualifier only for generator output
5ee5d2c [Cheng Hao] prepend the new qualifier
d2e8b43 [Cheng Hao] Update the code as feedback
ca5e7f4 [Cheng Hao] shrink the commits
Diffstat (limited to 'sql/core/src/main')
3 files changed, 24 insertions, 24 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala index 45f5da3876..03d9834d1d 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala @@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser} -import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar} +import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar} import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.catalyst.plans.{JoinType, Inner} import org.apache.spark.sql.catalyst.plans.logical._ @@ -711,12 +711,16 @@ class DataFrame private[sql]( */ def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = { val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType] - val attributes = schema.toAttributes + + val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) } + val names = schema.toAttributes.map(_.name) + val rowFunction = f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row])) - val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr)) + val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr)) - Generate(generator, join = true, outer = false, None, logicalPlan) + Generate(generator, join = true, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) } /** @@ -733,12 +737,17 @@ class DataFrame private[sql]( : DataFrame = { val dataType = ScalaReflection.schemaFor[B].dataType val attributes = AttributeReference(outputColumn, dataType)() :: Nil + // TODO handle the metadata? + val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) } + val names = attributes.map(_.name) + def rowFunction(row: Row): TraversableOnce[Row] = { f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType))) } - val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil) + val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil) - Generate(generator, join = true, outer = false, None, logicalPlan) + Generate(generator, join = true, outer = false, + qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan) } ///////////////////////////////////////////////////////////////////////////// diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala index 12271048bb..5201e20a10 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala @@ -27,44 +27,34 @@ import org.apache.spark.sql.catalyst.expressions._ * output of each into a new stream of rows. This operation is similar to a `flatMap` in functional * programming with one important additional feature, which allows the input rows to be joined with * their output. + * @param generator the generator expression * @param join when true, each output row is implicitly joined with the input tuple that produced * it. * @param outer when true, each input row will be output at least once, even if the output of the * given `generator` is empty. `outer` has no effect when `join` is false. + * @param output the output attributes of this node, which constructed in analysis phase, + * and we can not change it, as the parent node bound with it already. */ @DeveloperApi case class Generate( generator: Generator, join: Boolean, outer: Boolean, + output: Seq[Attribute], child: SparkPlan) extends UnaryNode { - // This must be a val since the generator output expr ids are not preserved by serialization. - protected val generatorOutput: Seq[Attribute] = { - if (join && outer) { - generator.output.map(_.withNullability(true)) - } else { - generator.output - } - } - - // This must be a val since the generator output expr ids are not preserved by serialization. - override val output = - if (join) child.output ++ generatorOutput else generatorOutput - val boundGenerator = BindReferences.bindReference(generator, child.output) override def execute(): RDD[Row] = { if (join) { child.execute().mapPartitions { iter => - val nullValues = Seq.fill(generator.output.size)(Literal(null)) + val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null)) // Used to produce rows with no matches when outer = true. val outerProjection = newProjection(child.output ++ nullValues, child.output) - val joinProjection = - newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput) + val joinProjection = newProjection(output, output) val joinedRow = new JoinedRow iter.flatMap {row => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala index e687d01f57..030ef118f7 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala @@ -312,8 +312,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] { execution.Except(planLater(left), planLater(right)) :: Nil case logical.Intersect(left, right) => execution.Intersect(planLater(left), planLater(right)) :: Nil - case logical.Generate(generator, join, outer, _, child) => - execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil + case g @ logical.Generate(generator, join, outer, _, _, child) => + execution.Generate( + generator, join = join, outer = outer, g.output, planLater(child)) :: Nil case logical.OneRowRelation => execution.PhysicalRDD(Nil, singleRowRdd) :: Nil case logical.Repartition(expressions, child) => |