aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorgatorsmile <gatorsmile@gmail.com>2015-12-28 12:48:30 -0800
committerMichael Armbrust <michael@databricks.com>2015-12-28 12:48:30 -0800
commit01ba95d8bfc16a2542c67b066b0a1d1e465f91da (patch)
tree416aa86355d55123ff54c42a8869f3f9bca410ea
parenta6a4812434c6f43cd4742437f957fecd86220255 (diff)
downloadspark-01ba95d8bfc16a2542c67b066b0a1d1e465f91da.tar.gz
spark-01ba95d8bfc16a2542c67b066b0a1d1e465f91da.tar.bz2
spark-01ba95d8bfc16a2542c67b066b0a1d1e465f91da.zip
[SPARK-12441][SQL] Fixing missingInput in Generate/MapPartitions/AppendColumns/MapGroups/CoGroup
When explain any plan with Generate, we will see an exclamation mark in the plan. Normally, when we see this mark, it means the plan has an error. This PR is to correct the `missingInput` in `Generate`. For example, ```scala val df = Seq((1, "a b c"), (2, "a b"), (3, "a")).toDF("number", "letters") val df2 = df.explode('letters) { case Row(letters: String) => letters.split(" ").map(Tuple1(_)).toSeq } df2.explain(true) ``` Before the fix, the plan is like ``` == Parsed Logical Plan == 'Generate UserDefinedGenerator('letters), true, false, None +- Project [_1#0 AS number#2,_2#1 AS letters#3] +- LocalRelation [_1#0,_2#1], [[1,a b c],[2,a b],[3,a]] == Analyzed Logical Plan == number: int, letters: string, _1: string Generate UserDefinedGenerator(letters#3), true, false, None, [_1#8] +- Project [_1#0 AS number#2,_2#1 AS letters#3] +- LocalRelation [_1#0,_2#1], [[1,a b c],[2,a b],[3,a]] == Optimized Logical Plan == Generate UserDefinedGenerator(letters#3), true, false, None, [_1#8] +- LocalRelation [number#2,letters#3], [[1,a b c],[2,a b],[3,a]] == Physical Plan == !Generate UserDefinedGenerator(letters#3), true, false, [number#2,letters#3,_1#8] +- LocalTableScan [number#2,letters#3], [[1,a b c],[2,a b],[3,a]] ``` **Updates**: The same issues are also found in the other four Dataset operators: `MapPartitions`/`AppendColumns`/`MapGroups`/`CoGroup`. Fixed all these four. Author: gatorsmile <gatorsmile@gmail.com> Author: xiaoli <lixiao1983@gmail.com> Author: Xiao Li <xiaoli@Xiaos-MacBook-Pro.local> Closes #10393 from gatorsmile/generateExplain.
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala11
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala4
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala1
-rw-r--r--sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala9
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala5
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala4
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala2
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala7
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala14
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala3
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala2
15 files changed, 63 insertions, 18 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d2626440b9..b43b7ee71e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -44,15 +44,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output))
/**
+ * The set of all attributes that are produced by this node.
+ */
+ def producedAttributes: AttributeSet = AttributeSet.empty
+
+ /**
* Attributes that are referenced by expressions but not provided by this nodes children.
* Subclasses should override this method if they produce attributes internally as it is used by
* assertions designed to prevent the construction of invalid plans.
- *
- * Note that virtual columns should be excluded. Currently, we only support the grouping ID
- * virtual column.
*/
- def missingInput: AttributeSet =
- (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName)
+ def missingInput: AttributeSet = references -- inputSet -- producedAttributes
/**
* Runs [[transform]] with `rule` on all expressions present in this query operator.
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index e3e7a11dba..572d7d2f0b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.types.{StructField, StructType}
object LocalRelation {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8f8747e105..6d859551f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -295,6 +295,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
abstract class LeafNode extends LogicalPlan {
override def children: Seq[LogicalPlan] = Nil
+ override def producedAttributes: AttributeSet = outputSet
}
/**
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 64ef4d7996..5f34d4a4eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -526,7 +526,7 @@ case class MapPartitions[T, U](
uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
- override def missingInput: AttributeSet = AttributeSet.empty
+ override def producedAttributes: AttributeSet = outputSet
}
/** Factory for constructing new `AppendColumn` nodes. */
@@ -552,7 +552,7 @@ case class AppendColumns[T, U](
newColumns: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output ++ newColumns
- override def missingInput: AttributeSet = super.missingInput -- newColumns
+ override def producedAttributes: AttributeSet = AttributeSet(newColumns)
}
/** Factory for constructing new `MapGroups` nodes. */
@@ -587,7 +587,7 @@ case class MapGroups[K, T, U](
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
- override def missingInput: AttributeSet = AttributeSet.empty
+ override def producedAttributes: AttributeSet = outputSet
}
/** Factory for constructing new `CoGroup` nodes. */
@@ -630,5 +630,5 @@ case class CoGroup[Key, Left, Right, Result](
rightGroup: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode {
- override def missingInput: AttributeSet = AttributeSet.empty
+ override def producedAttributes: AttributeSet = outputSet
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index ea5a9afe03..5c01af011d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -18,11 +18,11 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
-import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation}
+import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SQLContext}
@@ -84,6 +84,8 @@ private[sql] case class LogicalRDD(
case _ => false
}
+ override def producedAttributes: AttributeSet = outputSet
+
@transient override lazy val statistics: Statistics = Statistics(
// TODO: Instead of returning a default value here, find a way to return a meaningful size
// estimate for RDDs. See PR 1238 for more discussions.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 54b8cb5828..0c613e91b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -54,6 +54,8 @@ case class Generate(
child: SparkPlan)
extends UnaryNode {
+ override def expressions: Seq[Expression] = generator :: Nil
+
val boundGenerator = BindReferences.bindReference(generator, child.output)
protected override def doExecute(): RDD[InternalRow] = {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ec98f81041..fe9b2ad4a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -279,6 +279,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
private[sql] trait LeafNode extends SparkPlan {
override def children: Seq[SparkPlan] = Nil
+ override def producedAttributes: AttributeSet = outputSet
}
private[sql] trait UnaryNode extends SparkPlan {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index c5470a6989..c4587ba677 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -36,6 +36,15 @@ case class SortBasedAggregate(
child: SparkPlan)
extends UnaryNode {
+ private[this] val aggregateBufferAttributes = {
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ }
+
+ override def producedAttributes: AttributeSet =
+ AttributeSet(aggregateAttributes) ++
+ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+ AttributeSet(aggregateBufferAttributes)
+
override private[sql] lazy val metrics = Map(
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index b8849c8270..9d758eb3b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -55,6 +55,11 @@ case class TungstenAggregate(
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+ override def producedAttributes: AttributeSet =
+ AttributeSet(aggregateAttributes) ++
+ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+ AttributeSet(aggregateBufferAttributes)
+
override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6b7b3bbbf6..f19d72f067 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -369,6 +369,7 @@ case class MapPartitions[T, U](
uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
+ override def producedAttributes: AttributeSet = outputSet
override def canProcessSafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true
@@ -391,6 +392,7 @@ case class AppendColumns[T, U](
uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
+ override def producedAttributes: AttributeSet = AttributeSet(newColumns)
// We are using an unsafe combiner.
override def canProcessSafeRows: Boolean = false
@@ -424,6 +426,7 @@ case class MapGroups[K, T, U](
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
+ override def producedAttributes: AttributeSet = outputSet
override def canProcessSafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true
@@ -467,6 +470,7 @@ case class CoGroup[Key, Left, Right, Result](
rightGroup: Seq[Attribute],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
+ override def producedAttributes: AttributeSet = outputSet
override def canProcessSafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index 4afa5f8ec1..aa7a668e0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -66,6 +66,8 @@ private[sql] case class InMemoryRelation(
private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
extends LogicalPlan with MultiInstanceRelation {
+ override def producedAttributes: AttributeSet = outputSet
+
private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] =
if (_batchStats == null) {
child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index 78a98798ef..359a1e7f84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -15,16 +15,14 @@
* limitations under the License.
*/
-package test.org.apache.spark.sql
+package org.apache.spark.sql
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.unsafe.types.UTF8String
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
@@ -34,6 +32,7 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
sparkContext.parallelize(Seq(row))
}
+ override def producedAttributes: AttributeSet = outputSet
override def children: Seq[SparkPlan] = Nil
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 442ae79f4f..815372f192 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -130,6 +130,8 @@ abstract class QueryTest extends PlanTest {
checkJsonFormat(analyzedDF)
+ assertEmptyMissingInput(df)
+
QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
@@ -275,6 +277,18 @@ abstract class QueryTest extends PlanTest {
""".stripMargin)
}
}
+
+ /**
+ * Asserts that a given [[Queryable]] does not have missing inputs in all the analyzed plans.
+ */
+ def assertEmptyMissingInput(query: Queryable): Unit = {
+ assert(query.queryExecution.analyzed.missingInput.isEmpty,
+ s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}")
+ assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
+ s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}")
+ assert(query.queryExecution.executedPlan.missingInput.isEmpty,
+ s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}")
+ }
}
object QueryTest {
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index 806d2b9b0b..8141136de5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -51,6 +51,9 @@ case class HiveTableScan(
require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")
+ override def producedAttributes: AttributeSet = outputSet ++
+ AttributeSet(partitionPruningPred.flatMap(_.references))
+
// Retrieve the original attributes based on expression ID so that capitalization matches.
val attributes = requestedAttributes.map(relation.attributeMap)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index d9b9ba4bfd..a61e162f48 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -60,6 +60,8 @@ case class ScriptTransformation(
override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil
+ override def producedAttributes: AttributeSet = outputSet -- inputSet
+
private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)
protected override def doExecute(): RDD[InternalRow] = {