aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorCheng Hao <hao.cheng@intel.com>2015-04-21 15:11:15 -0700
committerMichael Armbrust <michael@databricks.com>2015-04-21 15:11:15 -0700
commit7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04 (patch)
treed9d0b5821fa04de211f0072d860b57cf95c69a74 /sql
parent2a24bf92e6d36e876bad6a8b4e0ff12c407ebb8a (diff)
downloadspark-7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04.tar.gz
spark-7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04.tar.bz2
spark-7662ec23bb6c4d4fe4c857b6928eaed0a97d3c04.zip
[SPARK-5817] [SQL] Fix bug of udtf with column names
It's a bug while do query like: ```sql select d from (select explode(array(1,1)) d from src limit 1) t ``` And it will throws exception like: ``` org.apache.spark.sql.AnalysisException: cannot resolve 'd' given input columns _c0; line 1 pos 7 at org.apache.spark.sql.catalyst.analysis.package$AnalysisErrorAt.failAnalysis(package.scala:42) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:48) at org.apache.spark.sql.catalyst.analysis.CheckAnalysis$$anonfun$apply$3$$anonfun$apply$1.applyOrElse(CheckAnalysis.scala:45) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:250) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:50) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:249) at org.apache.spark.sql.catalyst.plans.QueryPlan.org$apache$spark$sql$catalyst$plans$QueryPlan$$transformExpressionUp$1(QueryPlan.scala:103) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2$$anonfun$apply$2.apply(QueryPlan.scala:117) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) at scala.collection.TraversableLike$$anonfun$map$1.apply(TraversableLike.scala:244) at scala.collection.mutable.ResizableArray$class.foreach(ResizableArray.scala:59) at scala.collection.mutable.ArrayBuffer.foreach(ArrayBuffer.scala:47) at scala.collection.TraversableLike$class.map(TraversableLike.scala:244) at scala.collection.AbstractTraversable.map(Traversable.scala:105) at org.apache.spark.sql.catalyst.plans.QueryPlan$$anonfun$2.apply(QueryPlan.scala:116) at scala.collection.Iterator$$anon$11.next(Iterator.scala:328) ``` To solve the bug, it requires code refactoring for UDTF The major changes are about: * Simplifying the UDTF development, UDTF will manage the output attribute names any more, instead, the `logical.Generate` will handle that properly. * UDTF will be asked for the output schema (data types) during the logical plan analyzing. Author: Cheng Hao <hao.cheng@intel.com> Closes #4602 from chenghao-intel/explode_bug and squashes the following commits: c2a5132 [Cheng Hao] add back resolved for Alias 556e982 [Cheng Hao] revert the unncessary change 002c361 [Cheng Hao] change the rule of resolved for Generate 04ae500 [Cheng Hao] add qualifier only for generator output 5ee5d2c [Cheng Hao] prepend the new qualifier d2e8b43 [Cheng Hao] Update the code as feedback ca5e7f4 [Cheng Hao] shrink the commits
Diffstat (limited to 'sql')
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala57
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala12
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala3
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala49
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala2
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala8
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala37
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala2
-rw-r--r--sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala21
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala22
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala5
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala1
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala37
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala38
-rw-r--r--sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b6321
-rw-r--r--sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df6943480
-rw-r--r--sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f30
-rw-r--r--sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e43
-rw-r--r--sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b820
-rw-r--r--sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc730
-rw-r--r--sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc53
-rw-r--r--sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df6943480
-rw-r--r--sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d82920
-rw-r--r--sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e43
-rw-r--r--sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala40
26 files changed, 207 insertions, 145 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index cb49e5ad55..5e42b409dc 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -19,7 +19,6 @@ package org.apache.spark.sql.catalyst.analysis
import org.apache.spark.util.collection.OpenHashSet
import org.apache.spark.sql.AnalysisException
-import org.apache.spark.sql.catalyst.errors.TreeNodeException
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules._
@@ -59,6 +58,7 @@ class Analyzer(
ResolveReferences ::
ResolveGroupingAnalytics ::
ResolveSortReferences ::
+ ResolveGenerate ::
ImplicitGenerate ::
ResolveFunctions ::
GlobalAggregates ::
@@ -474,8 +474,59 @@ class Analyzer(
*/
object ImplicitGenerate extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case Project(Seq(Alias(g: Generator, _)), child) =>
- Generate(g, join = false, outer = false, None, child)
+ case Project(Seq(Alias(g: Generator, name)), child) =>
+ Generate(g, join = false, outer = false,
+ qualifier = None, UnresolvedAttribute(name) :: Nil, child)
+ case Project(Seq(MultiAlias(g: Generator, names)), child) =>
+ Generate(g, join = false, outer = false,
+ qualifier = None, names.map(UnresolvedAttribute(_)), child)
+ }
+ }
+
+ /**
+ * Resolve the Generate, if the output names specified, we will take them, otherwise
+ * we will try to provide the default names, which follow the same rule with Hive.
+ */
+ object ResolveGenerate extends Rule[LogicalPlan] {
+ // Construct the output attributes for the generator,
+ // The output attribute names can be either specified or
+ // auto generated.
+ private def makeGeneratorOutput(
+ generator: Generator,
+ generatorOutput: Seq[Attribute]): Seq[Attribute] = {
+ val elementTypes = generator.elementTypes
+
+ if (generatorOutput.length == elementTypes.length) {
+ generatorOutput.zip(elementTypes).map {
+ case (a, (t, nullable)) if !a.resolved =>
+ AttributeReference(a.name, t, nullable)()
+ case (a, _) => a
+ }
+ } else if (generatorOutput.length == 0) {
+ elementTypes.zipWithIndex.map {
+ // keep the default column names as Hive does _c0, _c1, _cN
+ case ((t, nullable), i) => AttributeReference(s"_c$i", t, nullable)()
+ }
+ } else {
+ throw new AnalysisException(
+ s"""
+ |The number of aliases supplied in the AS clause does not match
+ |the number of columns output by the UDTF expected
+ |${elementTypes.size} aliases but got ${generatorOutput.size}
+ """.stripMargin)
+ }
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan transform {
+ case p: Generate if !p.child.resolved || !p.generator.resolved => p
+ case p: Generate if p.resolved == false =>
+ // if the generator output names are not specified, we will use the default ones.
+ Generate(
+ p.generator,
+ join = p.join,
+ outer = p.outer,
+ p.qualifier,
+ makeGeneratorOutput(p.generator, p.generatorOutput), p.child)
}
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
index a986dd5387..2381689e17 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/CheckAnalysis.scala
@@ -38,6 +38,12 @@ trait CheckAnalysis {
throw new AnalysisException(msg)
}
+ def containsMultipleGenerators(exprs: Seq[Expression]): Boolean = {
+ exprs.flatMap(_.collect {
+ case e: Generator => true
+ }).length >= 1
+ }
+
def checkAnalysis(plan: LogicalPlan): Unit = {
// We transform up and order the rules so as to catch the first possible failure instead
// of the result of cascading resolution failures.
@@ -110,6 +116,12 @@ trait CheckAnalysis {
failAnalysis(
s"unresolved operator ${operator.simpleString}")
+ case p @ Project(exprs, _) if containsMultipleGenerators(exprs) =>
+ failAnalysis(
+ s"""Only a single table generating function is allowed in a SELECT clause, found:
+ | ${exprs.map(_.prettyString).mkString(",")}""".stripMargin)
+
+
case _ => // Analysis successful!
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
index 21c15ad14f..4e5c64bb63 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/dsl/package.scala
@@ -284,12 +284,13 @@ package object dsl {
seed: Int = (math.random * 1000).toInt): LogicalPlan =
Sample(fraction, withReplacement, seed, logicalPlan)
+ // TODO specify the output column names
def generate(
generator: Generator,
join: Boolean = false,
outer: Boolean = false,
alias: Option[String] = None): LogicalPlan =
- Generate(generator, join, outer, None, logicalPlan)
+ Generate(generator, join = join, outer = outer, alias, Nil, logicalPlan)
def insertInto(tableName: String, overwrite: Boolean = false): LogicalPlan =
InsertIntoTable(
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
index 67caadb839..9a6cb048af 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/generators.scala
@@ -42,47 +42,30 @@ abstract class Generator extends Expression {
override type EvaluatedType = TraversableOnce[Row]
- override lazy val dataType =
- ArrayType(StructType(output.map(a => StructField(a.name, a.dataType, a.nullable, a.metadata))))
+ // TODO ideally we should return the type of ArrayType(StructType),
+ // however, we don't keep the output field names in the Generator.
+ override def dataType: DataType = throw new UnsupportedOperationException
override def nullable: Boolean = false
/**
- * Should be overridden by specific generators. Called only once for each instance to ensure
- * that rule application does not change the output schema of a generator.
+ * The output element data types in structure of Seq[(DataType, Nullable)]
+ * TODO we probably need to add more information like metadata etc.
*/
- protected def makeOutput(): Seq[Attribute]
-
- private var _output: Seq[Attribute] = null
-
- def output: Seq[Attribute] = {
- if (_output == null) {
- _output = makeOutput()
- }
- _output
- }
+ def elementTypes: Seq[(DataType, Boolean)]
/** Should be implemented by child classes to perform specific Generators. */
override def eval(input: Row): TraversableOnce[Row]
-
- /** Overridden `makeCopy` also copies the attributes that are produced by this generator. */
- override def makeCopy(newArgs: Array[AnyRef]): this.type = {
- val copy = super.makeCopy(newArgs)
- copy._output = _output
- copy
- }
}
/**
* A generator that produces its output using the provided lambda function.
*/
case class UserDefinedGenerator(
- schema: Seq[Attribute],
+ elementTypes: Seq[(DataType, Boolean)],
function: Row => TraversableOnce[Row],
children: Seq[Expression])
- extends Generator{
-
- override protected def makeOutput(): Seq[Attribute] = schema
+ extends Generator {
override def eval(input: Row): TraversableOnce[Row] = {
// TODO(davies): improve this
@@ -98,30 +81,18 @@ case class UserDefinedGenerator(
/**
* Given an input array produces a sequence of rows for each value in the array.
*/
-case class Explode(attributeNames: Seq[String], child: Expression)
+case class Explode(child: Expression)
extends Generator with trees.UnaryNode[Expression] {
override lazy val resolved =
child.resolved &&
(child.dataType.isInstanceOf[ArrayType] || child.dataType.isInstanceOf[MapType])
- private lazy val elementTypes = child.dataType match {
+ override def elementTypes: Seq[(DataType, Boolean)] = child.dataType match {
case ArrayType(et, containsNull) => (et, containsNull) :: Nil
case MapType(kt, vt, valueContainsNull) => (kt, false) :: (vt, valueContainsNull) :: Nil
}
- // TODO: Move this pattern into Generator.
- protected def makeOutput() =
- if (attributeNames.size == elementTypes.size) {
- attributeNames.zip(elementTypes).map {
- case (n, (t, nullable)) => AttributeReference(n, t, nullable)()
- }
- } else {
- elementTypes.zipWithIndex.map {
- case ((t, nullable), i) => AttributeReference(s"c_$i", t, nullable)()
- }
- }
-
override def eval(input: Row): TraversableOnce[Row] = {
child.dataType match {
case ArrayType(_, _) =>
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
index bcbcbeb31c..afcb2ce8b9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/namedExpressions.scala
@@ -112,6 +112,8 @@ case class Alias(child: Expression, name: String)(
extends NamedExpression with trees.UnaryNode[Expression] {
override type EvaluatedType = Any
+ // Alias(Generator, xx) need to be transformed into Generate(generator, ...)
+ override lazy val resolved = childrenResolved && !child.isInstanceOf[Generator]
override def eval(input: Row): Any = child.eval(input)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 7c80634d2c..2d03fbfb0d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -482,16 +482,16 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
object PushPredicateThroughGenerate extends Rule[LogicalPlan] with PredicateHelper {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case filter @ Filter(condition,
- generate @ Generate(generator, join, outer, alias, grandChild)) =>
+ case filter @ Filter(condition, g: Generate) =>
// Predicates that reference attributes produced by the `Generate` operator cannot
// be pushed below the operator.
val (pushDown, stayUp) = splitConjunctivePredicates(condition).partition {
- conjunct => conjunct.references subsetOf grandChild.outputSet
+ conjunct => conjunct.references subsetOf g.child.outputSet
}
if (pushDown.nonEmpty) {
val pushDownPredicate = pushDown.reduce(And)
- val withPushdown = generate.copy(child = Filter(pushDownPredicate, grandChild))
+ val withPushdown = Generate(g.generator, join = g.join, outer = g.outer,
+ g.qualifier, g.generatorOutput, Filter(pushDownPredicate, g.child))
stayUp.reduceOption(And).map(Filter(_, withPushdown)).getOrElse(withPushdown)
} else {
filter
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 17522976dc..bbc94a7ab3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -40,34 +40,43 @@ case class Project(projectList: Seq[NamedExpression], child: LogicalPlan) extend
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
+ * @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty. `outer` has no effect when `join` is false.
- * @param alias when set, this string is applied to the schema of the output of the transformation
- * as a qualifier.
+ * @param qualifier Qualifier for the attributes of generator(UDTF)
+ * @param generatorOutput The output schema of the Generator.
+ * @param child Children logical plan node
*/
case class Generate(
generator: Generator,
join: Boolean,
outer: Boolean,
- alias: Option[String],
+ qualifier: Option[String],
+ generatorOutput: Seq[Attribute],
child: LogicalPlan)
extends UnaryNode {
- protected def generatorOutput: Seq[Attribute] = {
- val output = alias
- .map(a => generator.output.map(_.withQualifiers(a :: Nil)))
- .getOrElse(generator.output)
- if (join && outer) {
- output.map(_.withNullability(true))
- } else {
- output
- }
+ override lazy val resolved: Boolean = {
+ generator.resolved &&
+ childrenResolved &&
+ generator.elementTypes.length == generatorOutput.length &&
+ !generatorOutput.exists(!_.resolved)
}
- override def output: Seq[Attribute] =
- if (join) child.output ++ generatorOutput else generatorOutput
+ // we don't want the gOutput to be taken as part of the expressions
+ // as that will cause exceptions like unresolved attributes etc.
+ override def expressions: Seq[Expression] = generator :: Nil
+
+ def output: Seq[Attribute] = {
+ val qualified = qualifier.map(q =>
+ // prepend the new qualifier to the existed one
+ generatorOutput.map(a => a.withQualifiers(q +: a.qualifiers))
+ ).getOrElse(generatorOutput)
+
+ if (join) child.output ++ qualified else qualified
+ }
}
case class Filter(condition: Expression, child: LogicalPlan) extends UnaryNode {
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
index e10ddfdf51..7c249215bd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/analysis/AnalysisSuite.scala
@@ -90,7 +90,7 @@ class AnalysisSuite extends FunSuite with BeforeAndAfter {
assert(!Project(Seq(UnresolvedAttribute("a")), testRelation).resolved)
- val explode = Explode(Nil, AttributeReference("a", IntegerType, nullable = true)())
+ val explode = Explode(AttributeReference("a", IntegerType, nullable = true)())
assert(!Project(Seq(Alias(explode, "explode")()), testRelation).resolved)
assert(!Project(Seq(Alias(Count(Literal(1)), "count")()), testRelation).resolved)
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 1448098c77..45cf695d20 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -454,21 +454,21 @@ class FilterPushdownSuite extends PlanTest {
test("generate: predicate referenced no generated column") {
val originalQuery = {
testRelationWithArrayType
- .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+ .generate(Explode('c_arr), true, false, Some("arr"))
.where(('b >= 5) && ('a > 6))
}
val optimized = Optimize(originalQuery.analyze)
val correctAnswer = {
testRelationWithArrayType
.where(('b >= 5) && ('a > 6))
- .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr")).analyze
+ .generate(Explode('c_arr), true, false, Some("arr")).analyze
}
comparePlans(optimized, correctAnswer)
}
test("generate: part of conjuncts referenced generated column") {
- val generator = Explode(Seq("c"), 'c_arr)
+ val generator = Explode('c_arr)
val originalQuery = {
testRelationWithArrayType
.generate(generator, true, false, Some("arr"))
@@ -499,7 +499,7 @@ class FilterPushdownSuite extends PlanTest {
test("generate: all conjuncts referenced generated column") {
val originalQuery = {
testRelationWithArrayType
- .generate(Explode(Seq("c"), 'c_arr), true, false, Some("arr"))
+ .generate(Explode('c_arr), true, false, Some("arr"))
.where(('c > 6) || ('b > 5)).analyze
}
val optimized = Optimize(originalQuery)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
index 45f5da3876..03d9834d1d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/DataFrame.scala
@@ -34,7 +34,7 @@ import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.sql.catalyst.{CatalystTypeConverters, ScalaReflection, SqlParser}
-import org.apache.spark.sql.catalyst.analysis.{UnresolvedRelation, ResolvedStar}
+import org.apache.spark.sql.catalyst.analysis.{UnresolvedAttribute, UnresolvedRelation, ResolvedStar}
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.{JoinType, Inner}
import org.apache.spark.sql.catalyst.plans.logical._
@@ -711,12 +711,16 @@ class DataFrame private[sql](
*/
def explode[A <: Product : TypeTag](input: Column*)(f: Row => TraversableOnce[A]): DataFrame = {
val schema = ScalaReflection.schemaFor[A].dataType.asInstanceOf[StructType]
- val attributes = schema.toAttributes
+
+ val elementTypes = schema.toAttributes.map { attr => (attr.dataType, attr.nullable) }
+ val names = schema.toAttributes.map(_.name)
+
val rowFunction =
f.andThen(_.map(CatalystTypeConverters.convertToCatalyst(_, schema).asInstanceOf[Row]))
- val generator = UserDefinedGenerator(attributes, rowFunction, input.map(_.expr))
+ val generator = UserDefinedGenerator(elementTypes, rowFunction, input.map(_.expr))
- Generate(generator, join = true, outer = false, None, logicalPlan)
+ Generate(generator, join = true, outer = false,
+ qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
}
/**
@@ -733,12 +737,17 @@ class DataFrame private[sql](
: DataFrame = {
val dataType = ScalaReflection.schemaFor[B].dataType
val attributes = AttributeReference(outputColumn, dataType)() :: Nil
+ // TODO handle the metadata?
+ val elementTypes = attributes.map { attr => (attr.dataType, attr.nullable) }
+ val names = attributes.map(_.name)
+
def rowFunction(row: Row): TraversableOnce[Row] = {
f(row(0).asInstanceOf[A]).map(o => Row(CatalystTypeConverters.convertToCatalyst(o, dataType)))
}
- val generator = UserDefinedGenerator(attributes, rowFunction, apply(inputColumn).expr :: Nil)
+ val generator = UserDefinedGenerator(elementTypes, rowFunction, apply(inputColumn).expr :: Nil)
- Generate(generator, join = true, outer = false, None, logicalPlan)
+ Generate(generator, join = true, outer = false,
+ qualifier = None, names.map(UnresolvedAttribute(_)), logicalPlan)
}
/////////////////////////////////////////////////////////////////////////////
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 12271048bb..5201e20a10 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -27,44 +27,34 @@ import org.apache.spark.sql.catalyst.expressions._
* output of each into a new stream of rows. This operation is similar to a `flatMap` in functional
* programming with one important additional feature, which allows the input rows to be joined with
* their output.
+ * @param generator the generator expression
* @param join when true, each output row is implicitly joined with the input tuple that produced
* it.
* @param outer when true, each input row will be output at least once, even if the output of the
* given `generator` is empty. `outer` has no effect when `join` is false.
+ * @param output the output attributes of this node, which constructed in analysis phase,
+ * and we can not change it, as the parent node bound with it already.
*/
@DeveloperApi
case class Generate(
generator: Generator,
join: Boolean,
outer: Boolean,
+ output: Seq[Attribute],
child: SparkPlan)
extends UnaryNode {
- // This must be a val since the generator output expr ids are not preserved by serialization.
- protected val generatorOutput: Seq[Attribute] = {
- if (join && outer) {
- generator.output.map(_.withNullability(true))
- } else {
- generator.output
- }
- }
-
- // This must be a val since the generator output expr ids are not preserved by serialization.
- override val output =
- if (join) child.output ++ generatorOutput else generatorOutput
-
val boundGenerator = BindReferences.bindReference(generator, child.output)
override def execute(): RDD[Row] = {
if (join) {
child.execute().mapPartitions { iter =>
- val nullValues = Seq.fill(generator.output.size)(Literal(null))
+ val nullValues = Seq.fill(generator.elementTypes.size)(Literal(null))
// Used to produce rows with no matches when outer = true.
val outerProjection =
newProjection(child.output ++ nullValues, child.output)
- val joinProjection =
- newProjection(child.output ++ generatorOutput, child.output ++ generatorOutput)
+ val joinProjection = newProjection(output, output)
val joinedRow = new JoinedRow
iter.flatMap {row =>
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e687d01f57..030ef118f7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -312,8 +312,9 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
execution.Except(planLater(left), planLater(right)) :: Nil
case logical.Intersect(left, right) =>
execution.Intersect(planLater(left), planLater(right)) :: Nil
- case logical.Generate(generator, join, outer, _, child) =>
- execution.Generate(generator, join = join, outer = outer, planLater(child)) :: Nil
+ case g @ logical.Generate(generator, join, outer, _, _, child) =>
+ execution.Generate(
+ generator, join = join, outer = outer, g.output, planLater(child)) :: Nil
case logical.OneRowRelation =>
execution.PhysicalRDD(Nil, singleRowRdd) :: Nil
case logical.Repartition(expressions, child) =>
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index 7c6a7df2bd..c4a73b3004 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -249,7 +249,6 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
catalog.CreateTables ::
catalog.PreInsertionCasts ::
ExtractPythonUdfs ::
- ResolveUdtfsAlias ::
sources.PreInsertCastAndRename ::
Nil
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
index fd305eb480..85061f2277 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveQl.scala
@@ -725,12 +725,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val alias =
getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
- Generate(
- nodesToGenerator(clauses),
- join = true,
- outer = false,
- Some(alias.toLowerCase),
- withWhere)
+ val (generator, attributes) = nodesToGenerator(clauses)
+ Generate(
+ generator,
+ join = true,
+ outer = false,
+ Some(alias.toLowerCase),
+ attributes.map(UnresolvedAttribute(_)),
+ withWhere)
}.getOrElse(withWhere)
// The projection of the query can either be a normal projection, an aggregation
@@ -833,12 +835,14 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val alias = getClause("TOK_TABALIAS", clauses).getChildren.head.asInstanceOf[ASTNode].getText
- Generate(
- nodesToGenerator(clauses),
- join = true,
- outer = isOuter.nonEmpty,
- Some(alias.toLowerCase),
- nodeToRelation(relationClause))
+ val (generator, attributes) = nodesToGenerator(clauses)
+ Generate(
+ generator,
+ join = true,
+ outer = isOuter.nonEmpty,
+ Some(alias.toLowerCase),
+ attributes.map(UnresolvedAttribute(_)),
+ nodeToRelation(relationClause))
/* All relations, possibly with aliases or sampling clauses. */
case Token("TOK_TABREF", clauses) =>
@@ -1311,7 +1315,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
val explode = "(?i)explode".r
- def nodesToGenerator(nodes: Seq[Node]): Generator = {
+ def nodesToGenerator(nodes: Seq[Node]): (Generator, Seq[String]) = {
val function = nodes.head
val attributes = nodes.flatMap {
@@ -1321,7 +1325,7 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
function match {
case Token("TOK_FUNCTION", Token(explode(), Nil) :: child :: Nil) =>
- Explode(attributes, nodeToExpr(child))
+ (Explode(nodeToExpr(child)), attributes)
case Token("TOK_FUNCTION", Token(functionName, Nil) :: children) =>
val functionInfo: FunctionInfo =
@@ -1329,10 +1333,9 @@ https://cwiki.apache.org/confluence/display/Hive/Enhanced+Aggregation%2C+Cube%2C
sys.error(s"Couldn't find function $functionName"))
val functionClassName = functionInfo.getFunctionClass.getName
- HiveGenericUdtf(
+ (HiveGenericUdtf(
new HiveFunctionWrapper(functionClassName),
- attributes,
- children.map(nodeToExpr))
+ children.map(nodeToExpr)), attributes)
case a: ASTNode =>
throw new NotImplementedError(
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
index 47305571e5..4b6f0ad75f 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala
@@ -66,7 +66,7 @@ private[hive] abstract class HiveFunctionRegistry
} else if (classOf[UDAF].isAssignableFrom(functionInfo.getFunctionClass)) {
HiveUdaf(new HiveFunctionWrapper(functionClassName), children)
} else if (classOf[GenericUDTF].isAssignableFrom(functionInfo.getFunctionClass)) {
- HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), Nil, children)
+ HiveGenericUdtf(new HiveFunctionWrapper(functionClassName), children)
} else {
sys.error(s"No handler for udf ${functionInfo.getFunctionClass}")
}
@@ -266,7 +266,6 @@ private[hive] case class HiveUdaf(
*/
private[hive] case class HiveGenericUdtf(
funcWrapper: HiveFunctionWrapper,
- aliasNames: Seq[String],
children: Seq[Expression])
extends Generator with HiveInspectors {
@@ -282,23 +281,8 @@ private[hive] case class HiveGenericUdtf(
@transient
protected lazy val udtInput = new Array[AnyRef](children.length)
- protected lazy val outputDataTypes = outputInspector.getAllStructFieldRefs.map {
- field => inspectorToDataType(field.getFieldObjectInspector)
- }
-
- override protected def makeOutput() = {
- // Use column names when given, otherwise _c1, _c2, ... _cn.
- if (aliasNames.size == outputDataTypes.size) {
- aliasNames.zip(outputDataTypes).map {
- case (attrName, attrDataType) =>
- AttributeReference(attrName, attrDataType, nullable = true)()
- }
- } else {
- outputDataTypes.zipWithIndex.map {
- case (attrDataType, i) =>
- AttributeReference(s"_c$i", attrDataType, nullable = true)()
- }
- }
+ lazy val elementTypes = outputInspector.getAllStructFieldRefs.map {
+ field => (inspectorToDataType(field.getFieldObjectInspector), true)
}
override def eval(input: Row): TraversableOnce[Row] = {
@@ -333,22 +317,6 @@ private[hive] case class HiveGenericUdtf(
}
}
-/**
- * Resolve Udtfs Alias.
- */
-private[spark] object ResolveUdtfsAlias extends Rule[LogicalPlan] {
- def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case p @ Project(projectList, _)
- if projectList.exists(_.isInstanceOf[MultiAlias]) && projectList.size != 1 =>
- throw new TreeNodeException(p, "only single Generator supported for SELECT clause")
-
- case Project(Seq(Alias(udtf @ HiveGenericUdtf(_, _, _), name)), child) =>
- Generate(udtf.copy(aliasNames = Seq(name)), join = false, outer = false, None, child)
- case Project(Seq(MultiAlias(udtf @ HiveGenericUdtf(_, _, _), names)), child) =>
- Generate(udtf.copy(aliasNames = names), join = false, outer = false, None, child)
- }
-}
-
private[hive] case class HiveUdafFunction(
funcWrapper: HiveFunctionWrapper,
exprs: Seq[Expression],
diff --git a/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632 b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632
new file mode 100644
index 0000000000..d00491fd7e
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/Specify the udtf output-0-d1f244bce64f22b34ad5bf9fd360b632
@@ -0,0 +1 @@
+1
diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert table with generator with column name-0-7ac701cf43e73e9e416888e4df694348
diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3 b/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert table with generator with column name-1-5cdf9d51fc0e105e365d82e7611e37f3
diff --git a/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4
new file mode 100644
index 0000000000..01e79c32a8
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert table with generator with column name-2-f963396461294e06cb7cafe22a1419e4
@@ -0,0 +1,3 @@
+1
+2
+3
diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-0-46bdb27b3359dc81d8c246b9f69d4b82
diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-1-cdf6989f3b055257f1692c3bbd80dc73
diff --git a/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5 b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5
new file mode 100644
index 0000000000..0c7520f209
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert table with generator with multiple column names-2-ab3954b69d7a991bc801a509c3166cc5
@@ -0,0 +1,3 @@
+86 val_86
+238 val_238
+311 val_311
diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348 b/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert table with generator without column name-0-7ac701cf43e73e9e416888e4df694348
diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292 b/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292
new file mode 100644
index 0000000000..e69de29bb2
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert table with generator without column name-1-26599718c322ff4f9740040c066d8292
diff --git a/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4 b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4
new file mode 100644
index 0000000000..01e79c32a8
--- /dev/null
+++ b/sql/hive/src/test/resources/golden/insert table with generator without column name-2-f963396461294e06cb7cafe22a1419e4
@@ -0,0 +1,3 @@
+1
+2
+3
diff --git a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
index 300b1f7920..ac10b17330 100644
--- a/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
+++ b/sql/hive/src/test/scala/org/apache/spark/sql/hive/execution/HiveQuerySuite.scala
@@ -27,7 +27,7 @@ import scala.util.Try
import org.apache.hadoop.hive.conf.HiveConf.ConfVars
import org.apache.spark.{SparkFiles, SparkException}
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{AnalysisException, DataFrame, Row}
import org.apache.spark.sql.catalyst.plans.logical.Project
import org.apache.spark.sql.functions._
import org.apache.spark.sql.hive._
@@ -67,6 +67,40 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
}
}
+ createQueryTest("insert table with generator with column name",
+ """
+ | CREATE TABLE gen_tmp (key Int);
+ | INSERT OVERWRITE TABLE gen_tmp
+ | SELECT explode(array(1,2,3)) AS val FROM src LIMIT 3;
+ | SELECT key FROM gen_tmp ORDER BY key ASC;
+ """.stripMargin)
+
+ createQueryTest("insert table with generator with multiple column names",
+ """
+ | CREATE TABLE gen_tmp (key Int, value String);
+ | INSERT OVERWRITE TABLE gen_tmp
+ | SELECT explode(map(key, value)) as (k1, k2) FROM src LIMIT 3;
+ | SELECT key, value FROM gen_tmp ORDER BY key, value ASC;
+ """.stripMargin)
+
+ createQueryTest("insert table with generator without column name",
+ """
+ | CREATE TABLE gen_tmp (key Int);
+ | INSERT OVERWRITE TABLE gen_tmp
+ | SELECT explode(array(1,2,3)) FROM src LIMIT 3;
+ | SELECT key FROM gen_tmp ORDER BY key ASC;
+ """.stripMargin)
+
+ test("multiple generator in projection") {
+ intercept[AnalysisException] {
+ sql("SELECT explode(map(key, value)), key FROM src").collect()
+ }
+
+ intercept[AnalysisException] {
+ sql("SELECT explode(map(key, value)) as k1, k2, key FROM src").collect()
+ }
+ }
+
createQueryTest("! operator",
"""
|SELECT a FROM (
@@ -456,7 +490,6 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
createQueryTest("lateral view2",
"SELECT * FROM src LATERAL VIEW explode(array(1,2)) tbl")
-
createQueryTest("lateral view3",
"FROM src SELECT key, D.* lateral view explode(array(key+3, key+4)) D as CX")
@@ -478,6 +511,9 @@ class HiveQuerySuite extends HiveComparisonTest with BeforeAndAfter {
createQueryTest("lateral view6",
"SELECT * FROM src LATERAL VIEW explode(map(key+3,key+4)) D as k, v")
+ createQueryTest("Specify the udtf output",
+ "SELECT d FROM (SELECT explode(array(1,1)) d FROM src LIMIT 1) t")
+
test("sampling") {
sql("SELECT * FROM src TABLESAMPLE(0.1 PERCENT) s")
sql("SELECT * FROM src TABLESAMPLE(100 PERCENT) s")