-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala                11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala     4
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala       1
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala    8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala                        8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala                           2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala                          1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala       9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala        5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                     4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala 2
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala                         7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala                                   14
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala                 3
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala          2
15 files changed, 63 insertions(+), 18 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index d2626440b9..b43b7ee71e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -44,15 +44,16 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
AttributeSet(children.flatMap(_.asInstanceOf[QueryPlan[PlanType]].output))
/**
+ * The set of all attributes that are produced by this node.
+ */
+ def producedAttributes: AttributeSet = AttributeSet.empty
+
+ /**
* Attributes that are referenced by expressions but not provided by this node's children.
* Subclasses should override this method if they produce attributes internally as it is used by
* assertions designed to prevent the construction of invalid plans.
- *
- * Note that virtual columns should be excluded. Currently, we only support the grouping ID
- * virtual column.
*/
- def missingInput: AttributeSet =
- (references -- inputSet).filter(_.name != VirtualColumn.groupingIdName)
+ def missingInput: AttributeSet = references -- inputSet -- producedAttributes
/**
* Runs [[transform]] with `rule` on all expressions present in this query operator.
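The net effect is that `missingInput` now subtracts a per-operator allowance instead of special-casing the grouping ID virtual column. A minimal sketch of the arithmetic using `AttributeSet` directly (the attributes `a` and `k` are purely illustrative):

    import org.apache.spark.sql.catalyst.expressions.{AttributeReference, AttributeSet}
    import org.apache.spark.sql.types.IntegerType

    val a = AttributeReference("a", IntegerType)()   // supplied by a child
    val k = AttributeReference("k", IntegerType)()   // invented by the operator

    val references = AttributeSet(a :: k :: Nil)
    val inputSet   = AttributeSet(a :: Nil)          // the children's output
    val produced   = AttributeSet(k :: Nil)          // producedAttributes

    // The new definition: nothing is missing as long as every reference is
    // either supplied by a child or declared as produced by the node itself.
    assert((references -- inputSet -- produced).isEmpty)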
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index e3e7a11dba..572d7d2f0b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -18,8 +18,8 @@
package org.apache.spark.sql.catalyst.plans.logical
import org.apache.spark.sql.Row
-import org.apache.spark.sql.catalyst.expressions.Attribute
-import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow, analysis}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet}
+import org.apache.spark.sql.catalyst.{analysis, CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.types.{StructField, StructType}
object LocalRelation {
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 8f8747e105..6d859551f8 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -295,6 +295,7 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
abstract class LeafNode extends LogicalPlan {
override def children: Seq[LogicalPlan] = Nil
+ override def producedAttributes: AttributeSet = outputSet
}
/**
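Leaf nodes hold their output attributes as constructor arguments, so those attributes show up in `references` even though there is no child to supply them; defaulting `producedAttributes` to `outputSet` here (and in the physical `LeafNode` trait further down) makes such plans pass the check without per-class overrides. Relations that extend `LogicalPlan` directly, such as `LogicalRDD` and `InMemoryRelation`, get the same declaration by hand later in this patch. A toy illustration, assuming the usual catalyst imports:

    // Everything a leaf emits counts as produced by the leaf itself.
    case class ToyRelation(output: Seq[Attribute]) extends LeafNode

    val rel = ToyRelation(Seq(AttributeReference("id", LongType)()))
    assert(rel.producedAttributes == rel.outputSet)
    assert(rel.missingInput.isEmpty)   // `id` is referenced but also produced here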
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
index 64ef4d7996..5f34d4a4eb 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicOperators.scala
@@ -526,7 +526,7 @@ case class MapPartitions[T, U](
uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
- override def missingInput: AttributeSet = AttributeSet.empty
+ override def producedAttributes: AttributeSet = outputSet
}
/** Factory for constructing new `AppendColumn` nodes. */
@@ -552,7 +552,7 @@ case class AppendColumns[T, U](
newColumns: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output ++ newColumns
- override def missingInput: AttributeSet = super.missingInput -- newColumns
+ override def producedAttributes: AttributeSet = AttributeSet(newColumns)
}
/** Factory for constructing new `MapGroups` nodes. */
@@ -587,7 +587,7 @@ case class MapGroups[K, T, U](
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: LogicalPlan) extends UnaryNode {
- override def missingInput: AttributeSet = AttributeSet.empty
+ override def producedAttributes: AttributeSet = outputSet
}
/** Factory for constructing new `CoGroup` nodes. */
@@ -630,5 +630,5 @@ case class CoGroup[Key, Left, Right, Result](
rightGroup: Seq[Attribute],
left: LogicalPlan,
right: LogicalPlan) extends BinaryNode {
- override def missingInput: AttributeSet = AttributeSet.empty
+ override def producedAttributes: AttributeSet = outputSet
}
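With the Dataset API that drives these typed operators, `groupBy(func)` plans an `AppendColumns` that invents only the grouping-key column, while `mapGroups` plans a `MapGroups` whose entire output is self-produced. A rough end-to-end sketch, assuming a `SQLContext` named `sqlContext` and `import sqlContext.implicits._`:

    val ds = Seq(("a", 1), ("a", 2), ("b", 3)).toDS()
    // AppendColumns produces only the appended key column;
    // MapGroups produces its entire, encoder-derived output.
    val counts = ds.groupBy(_._1).mapGroups { (key, rows) => (key, rows.size) }
    assert(counts.queryExecution.analyzed.missingInput.isEmpty)
    assert(counts.queryExecution.optimizedPlan.missingInput.isEmpty)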
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index ea5a9afe03..5c01af011d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -18,11 +18,11 @@
package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.catalyst.{InternalRow, CatalystTypeConverters}
+import org.apache.spark.sql.catalyst.{CatalystTypeConverters, InternalRow}
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
-import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericMutableRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, GenericMutableRow}
import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
-import org.apache.spark.sql.sources.{HadoopFsRelation, BaseRelation}
+import org.apache.spark.sql.sources.{BaseRelation, HadoopFsRelation}
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.{Row, SQLContext}
@@ -84,6 +84,8 @@ private[sql] case class LogicalRDD(
case _ => false
}
+ override def producedAttributes: AttributeSet = outputSet
+
@transient override lazy val statistics: Statistics = Statistics(
// TODO: Instead of returning a default value here, find a way to return a meaningful size
// estimate for RDDs. See PR 1238 for more discussions.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 54b8cb5828..0c613e91b9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -54,6 +54,8 @@ case class Generate(
child: SparkPlan)
extends UnaryNode {
+ override def expressions: Seq[Expression] = generator :: Nil
+
val boundGenerator = BindReferences.bindReference(generator, child.output)
protected override def doExecute(): RDD[InternalRow] = {
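`QueryPlan` collects `references` from `expressions`, and the default `expressions` walks the constructor arguments, so the physical `Generate` would otherwise count its own output attributes as references with nothing to satisfy them; restricting `expressions` to the generator leaves only the generator's real inputs to be checked against the child. A quick way to exercise it, assuming a registered temporary table `t` with an array column `arr`:

    // explode(arr) plans a physical Generate whose generator references `arr`;
    // the child scan must supply it, and only it, for the plan to be valid.
    val exploded = sqlContext.sql("SELECT explode(arr) AS elem FROM t")
    assert(exploded.queryExecution.executedPlan.missingInput.isEmpty)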
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index ec98f81041..fe9b2ad4a0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -279,6 +279,7 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
private[sql] trait LeafNode extends SparkPlan {
override def children: Seq[SparkPlan] = Nil
+ override def producedAttributes: AttributeSet = outputSet
}
private[sql] trait UnaryNode extends SparkPlan {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
index c5470a6989..c4587ba677 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregate.scala
@@ -36,6 +36,15 @@ case class SortBasedAggregate(
child: SparkPlan)
extends UnaryNode {
+ private[this] val aggregateBufferAttributes = {
+ aggregateExpressions.flatMap(_.aggregateFunction.aggBufferAttributes)
+ }
+
+ override def producedAttributes: AttributeSet =
+ AttributeSet(aggregateAttributes) ++
+ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+ AttributeSet(aggregateBufferAttributes)
+
override private[sql] lazy val metrics = Map(
"numInputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of input rows"),
"numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
index b8849c8270..9d758eb3b7 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala
@@ -55,6 +55,11 @@ case class TungstenAggregate(
override def output: Seq[Attribute] = resultExpressions.map(_.toAttribute)
+ override def producedAttributes: AttributeSet =
+ AttributeSet(aggregateAttributes) ++
+ AttributeSet(resultExpressions.diff(groupingExpressions).map(_.toAttribute)) ++
+ AttributeSet(aggregateBufferAttributes)
+
override def requiredChildDistribution: List[Distribution] = {
requiredChildDistributionExpressions match {
case Some(exprs) if exprs.length == 0 => AllTuples :: Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 6b7b3bbbf6..f19d72f067 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -369,6 +369,7 @@ case class MapPartitions[T, U](
uEncoder: ExpressionEncoder[U],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
+ override def producedAttributes: AttributeSet = outputSet
override def canProcessSafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true
@@ -391,6 +392,7 @@ case class AppendColumns[T, U](
uEncoder: ExpressionEncoder[U],
newColumns: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
+ override def producedAttributes: AttributeSet = AttributeSet(newColumns)
// We are using an unsafe combiner.
override def canProcessSafeRows: Boolean = false
@@ -424,6 +426,7 @@ case class MapGroups[K, T, U](
groupingAttributes: Seq[Attribute],
output: Seq[Attribute],
child: SparkPlan) extends UnaryNode {
+ override def producedAttributes: AttributeSet = outputSet
override def canProcessSafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true
@@ -467,6 +470,7 @@ case class CoGroup[Key, Left, Right, Result](
rightGroup: Seq[Attribute],
left: SparkPlan,
right: SparkPlan) extends BinaryNode {
+ override def producedAttributes: AttributeSet = outputSet
override def canProcessSafeRows: Boolean = true
override def canProcessUnsafeRows: Boolean = true
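The physical counterparts need the same declarations because the new `QueryTest` assertion (below) also walks the executed plan. A small sketch for `MapPartitions`, assuming `import sqlContext.implicits._` and a `Dataset[String]` named `ds`:

    // The physical MapPartitions declares its encoder-derived output as
    // produced, so the executed plan passes the missingInput check too.
    val upper = ds.mapPartitions(_.map(_.toUpperCase))
    assert(upper.queryExecution.executedPlan.missingInput.isEmpty)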
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index 4afa5f8ec1..aa7a668e0e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -66,6 +66,8 @@ private[sql] case class InMemoryRelation(
private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
extends LogicalPlan with MultiInstanceRelation {
+ override def producedAttributes: AttributeSet = outputSet
+
private val batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] =
if (_batchStats == null) {
child.sqlContext.sparkContext.accumulableCollection(ArrayBuffer.empty[InternalRow])
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
index 78a98798ef..359a1e7f84 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/ExtraStrategiesSuite.scala
@@ -15,16 +15,14 @@
* limitations under the License.
*/
-package test.org.apache.spark.sql
+package org.apache.spark.sql
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Literal, GenericInternalRow, Attribute}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.logical.{Project, LogicalPlan}
import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.{Row, Strategy, QueryTest}
import org.apache.spark.sql.test.SharedSQLContext
-import org.apache.spark.unsafe.types.UTF8String
case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
@@ -34,6 +32,7 @@ case class FastOperator(output: Seq[Attribute]) extends SparkPlan {
sparkContext.parallelize(Seq(row))
}
+ override def producedAttributes: AttributeSet = outputSet
override def children: Seq[SparkPlan] = Nil
}
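`FastOperator` extends `SparkPlan` directly rather than the physical `LeafNode` trait, so it does not inherit the new default; its output attribute is a constructor argument and would otherwise surface in `references` with no child to satisfy it. A quick check of the override, purely illustrative and assuming the usual catalyst imports:

    val op = FastOperator(Seq(AttributeReference("a", StringType)()))
    assert(op.producedAttributes == op.outputSet)
    assert(op.missingInput.isEmpty)   // `a` is referenced but produced locally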
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
index 442ae79f4f..815372f192 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/QueryTest.scala
@@ -130,6 +130,8 @@ abstract class QueryTest extends PlanTest {
checkJsonFormat(analyzedDF)
+ assertEmptyMissingInput(df)
+
QueryTest.checkAnswer(analyzedDF, expectedAnswer) match {
case Some(errorMessage) => fail(errorMessage)
case None =>
@@ -275,6 +277,18 @@ abstract class QueryTest extends PlanTest {
""".stripMargin)
}
}
+
+ /**
+ * Asserts that a given [[Queryable]] does not have missing inputs in all the analyzed plans.
+ */
+ def assertEmptyMissingInput(query: Queryable): Unit = {
+ assert(query.queryExecution.analyzed.missingInput.isEmpty,
+ s"The analyzed logical plan has missing inputs: ${query.queryExecution.analyzed}")
+ assert(query.queryExecution.optimizedPlan.missingInput.isEmpty,
+ s"The optimized logical plan has missing inputs: ${query.queryExecution.optimizedPlan}")
+ assert(query.queryExecution.executedPlan.missingInput.isEmpty,
+ s"The physical plan has missing inputs: ${query.queryExecution.executedPlan}")
+ }
}
object QueryTest {
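With `checkAnswer` calling it for every verified query, the invariant is enforced on the analyzed, optimized, and physical plan of everything the SQL test suites run; it can also be invoked directly. A sketch of direct use inside a suite extending `QueryTest` with `SharedSQLContext` (the query itself is just an example):

    test("plans carry no dangling attribute references") {
      val df = sqlContext.range(10).selectExpr("id", "id * 2 AS doubled")
      // Fails and prints the offending plan if any of the three plans
      // references an attribute that nothing provides.
      assertEmptyMissingInput(df)
    }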
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
index 806d2b9b0b..8141136de5 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScan.scala
@@ -51,6 +51,9 @@ case class HiveTableScan(
require(partitionPruningPred.isEmpty || relation.hiveQlTable.isPartitioned,
"Partition pruning predicates only supported for partitioned tables.")
+ override def producedAttributes: AttributeSet = outputSet ++
+ AttributeSet(partitionPruningPred.flatMap(_.references))
+
// Retrieve the original attributes based on expression ID so that capitalization matches.
val attributes = requestedAttributes.map(relation.attributeMap)
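The partition pruning predicates stay attached to the scan but are evaluated against Hive partition metadata rather than rows from a child, so their references (typically partition columns that are not in `requestedAttributes`) must be counted as produced, or a pruned scan would always report them as missing. A hedged illustration, assuming a `HiveContext` named `hiveContext` and a table partitioned by `ds`:

    // Plans a HiveTableScan with requestedAttributes = [value] and a
    // partitionPruningPred referencing the partition column `ds`.
    val pruned = hiveContext.sql(
      "SELECT value FROM partitioned_table WHERE ds = '2015-01-01'")
    assert(pruned.queryExecution.executedPlan.missingInput.isEmpty)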
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
index d9b9ba4bfd..a61e162f48 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/ScriptTransformation.scala
@@ -60,6 +60,8 @@ case class ScriptTransformation(
override protected def otherCopyArgs: Seq[HiveContext] = sc :: Nil
+ override def producedAttributes: AttributeSet = outputSet -- inputSet
+
private val serializedHiveConf = new SerializableConfiguration(sc.hiveconf)
protected override def doExecute(): RDD[InternalRow] = {
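A script transform declares fresh output columns for whatever the external command writes back; the `-- inputSet` subtraction presumably keeps any columns that simply pass through from the child out of the produced set. A small sketch, assuming Hive support and a table `src(key, value)`:

    // TRANSFORM pipes (key, value) through `cat` and declares two new
    // output columns, both produced by the ScriptTransformation operator.
    val transformed = hiveContext.sql(
      "SELECT TRANSFORM (key, value) USING 'cat' AS (tKey, tValue) FROM src")
    assert(transformed.queryExecution.executedPlan.missingInput.isEmpty)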