aboutsummaryrefslogtreecommitdiff
path: root/sql/core
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-04-21 15:11:15 -0700
committerMichael Armbrust <michael@databricks.com>2015-04-21 15:11:15 -0700
commit7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04 (patch)
treed9d0b5821fa04de211f0072d860b57cf95c69a74 /sql/core
parent2a24bf92e6d36e876bad6a8b4e0ff12c407ebb8a (diff)
downloadspark-7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04.tar.gz
spark-7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04.tar.bz2
spark-7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04.zip
[SPARK-5817] [SQL] Fix bug of udtf with column names
It's a bug while do query like: ```sql select d from (select explode(array(1,1)) d from src limit 1) t ``` And it will throws exception like: ``` org.apache.spark.sql.AnalysisException: cannot resolve 'd' given input columns _c0; line 1 pos 7 at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:48) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:45) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:50) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:249) at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$transformExpressionUp$1(QueryPlan.scala:103) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2$$anonfun$apply$2.apply(QueryPlan.scala:117) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47) at scala.collection.TraversableLike$class.map(TraversableLike.scala:244) at scala.collection.AbstractTraversable.map(Traversable.scala:105) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:116) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) ``` To solve the bug, it requires code refactoring for UDTF The major changes are about: * Simplifying the UDTF development, UDTF will manage the output attribute names any more, instead, the `logical.Generate` will handle that properly. * UDTF will be asked for the output schema (data types) during the logical plan analyzing. Author: Cheng Hao <hao.cheng@intel.com> Closes #4602 from chenghao-intel/explode_bug and squashes the following commits: c2a5132 [Cheng Hao] add back resolved for Alias 556e982 [Cheng Hao] revert the unncessary change 002c361 [Cheng Hao] change the rule of resolved for Generate 04ae500 [Cheng Hao] add qualifier only for generator output 5ee5d2c [Cheng Hao] prepend the new qualifier d2e8b43 [Cheng Hao] Update the code as feedback ca5e7f4 [Cheng Hao] shrink the commits
Diffstat (limited to 'sql/core')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala21
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala5
3 files changed, 24 insertions, 24 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 45f5da3876..03d9834d1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
@@ -711,12 +711,16 @@ class DataFrame private[sql](
*/
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
- val attributes = schema.toAttributes
+
+ val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) }
+ val names = schema.toAttributes.map(_.name)
+
val rowFunction =
f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
- val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
+ val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))
- Generate(generator, join = true, outer = false, None, logicalPlan)
+ Generate(generator, join = true, outer = false,
+ qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
}
/**
@@ -733,12 +737,17 @@ class DataFrame private[sql](
: DataFrame = {
val dataType = ScalaReflection.schemaFor[B].dataType
val attributes = AttributeReference(outputColumn, dataType)() :: Nil
+ // TODO handle the metadata?
+ val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) }
+ val names = attributes.map(_.name)
+
def rowFunction(row: Row): TraversableOnce[Row] = {
f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType)))
}
- val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
+ val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)
- Generate(generator, join = true, outer = false, None, logicalPlan)
+ Generate(generator, join = true, outer = false,
+ qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
}
/////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 12271048bb..5201e20a10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -27,44 +27,34 @@ import org.apache.spark.sql.catalyst.expressions._
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
+ * @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty. `outer` has no effect when `join` is false.
+ * @param output the output attributes of this node, which constructed in analysis phase,
+ * and we can not change it, as the parent node bound with it already.
*/
@DeveloperApi
case class Generate(
generator: Generator,
join: Boolean,
outer: Boolean,
+ output: Seq[Attribute],
child: SparkPlan)
extends UnaryNode {
- // This must be a val since the generator output expr ids are not preserved by serialization.
- protected val generatorOutput: Seq[Attribute] = {
- if (join && outer) {
- generator.output.map(_.withNullability(true))
- } else {
- generator.output
- }
- }
-
- // This must be a val since the generator output expr ids are not preserved by serialization.
- override val output =
- if (join) child.output ++ generatorOutput else generatorOutput
-
val boundGenerator = BindReferences.bindReference(generator, child.output)
override def execute(): RDD[Row] = {
if (join) {
child.execute().mapPartitions { iter =>
- val nullValues = Seq.fill(generator.output.size)(Literal(null))
+ val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null))
// Used to produce rows with no matches when outer = true.
val outerProjection =
newProjection(child.output ++ nullValues, child.output)
- val joinProjection =
- newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput)
+ val joinProjection = newProjection(output, output)
val joinedRow = new JoinedRow
iter.flatMap {row =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e687d01f57..030ef118f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -312,8 +312,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Except(planLater(left), planLater(right)) :: Nil
case logical.Intersect(left, right) =>
execution.Intersect(planLater(left), planLater(right)) :: Nil
- case logical.Generate(generator, join, outer, _, child) =>
- execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
+ case g @ logical.Generate(generator, join, outer, _, _, child) =>
+ execution.Generate(
+ generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
case logical.OneRowRelation =>
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>