author     Wenchen Fan <wenchen@databricks.com>    2017-04-10 13:36:08 +0800
committer  Wenchen Fan <wenchen@databricks.com>    2017-04-10 13:36:08 +0800
commit     3d7f201f2adc2d33be6f564fa76435c18552f4ba (patch)
tree       2c34606cf5cf36da43cf4d9b7056bf2b0c33cd44 /sql
parent     1a0bc41659eef317dcac18df35c26857216a4314 (diff)
[SPARK-20229][SQL] add semanticHash to QueryPlan
## What changes were proposed in this pull request?

Like `Expression`, `QueryPlan` should also have a `semanticHash` method, so that plans can be put into a hash map and looked up quickly. This PR refactors `QueryPlan` to follow `Expression` and puts all the normalization logic in `QueryPlan.canonicalized`, which makes implementing `semanticHash` natural.

Follow-up: improve `CacheManager` to leverage `semanticHash` and speed up plan lookup, instead of iterating over all cached plans.

## How was this patch tested?

Existing tests. Note that we don't need to test the `semanticHash` method directly: once the existing tests prove `sameResult` is correct, we are good.

Author: Wenchen Fan <wenchen@databricks.com>

Closes #17541 from cloud-fan/plan-semantic.
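The follow-up idea is to use `semanticHash` as a hash-map key so that cache lookup becomes a bucket probe plus a `sameResult` check, rather than a linear scan over all cached plans. A minimal sketch of that idea (not part of this patch; `CachedPlan` and `SemanticPlanCache` are hypothetical names):

```scala
import scala.collection.mutable
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical cache entry: a plan plus whatever materialized data is cached for it.
case class CachedPlan(plan: LogicalPlan, data: AnyRef)

class SemanticPlanCache {
  // Bucket entries by semanticHash; sameResult resolves any hash collisions.
  private val buckets = mutable.HashMap.empty[Int, mutable.Buffer[CachedPlan]]

  def add(entry: CachedPlan): Unit = {
    buckets.getOrElseUpdate(entry.plan.semanticHash(), mutable.Buffer.empty) += entry
  }

  def lookup(plan: LogicalPlan): Option[CachedPlan] = {
    buckets.get(plan.semanticHash()).flatMap(_.find(_.plan.sameResult(plan)))
  }
}
```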
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala | 102
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala | 8
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala | 2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala | 9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala | 37
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala | 14
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala | 2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala | 13
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala | 6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala | 11
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala | 18
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala | 45
17 files changed, 135 insertions, 163 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c698ca6a83..b0cdef7029 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -617,7 +617,7 @@ class Analyzer(
def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
- lookupTableFromCatalog(u).canonicalized match {
+ EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
case v: View =>
u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.")
case other => i.copy(table = other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 360e55d922..cc0cbba275 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -423,8 +423,15 @@ case class CatalogRelation(
Objects.hashCode(tableMeta.identifier, output)
}
- /** Only compare table identifier. */
- override lazy val cleanArgs: Seq[Any] = Seq(tableMeta.identifier)
+ override def preCanonicalized: LogicalPlan = copy(tableMeta = CatalogTable(
+ identifier = tableMeta.identifier,
+ tableType = tableMeta.tableType,
+ storage = CatalogStorageFormat.empty,
+ schema = tableMeta.schema,
+ partitionColumnNames = tableMeta.partitionColumnNames,
+ bucketSpec = tableMeta.bucketSpec,
+ createTime = -1
+ ))
override def computeStats(conf: SQLConf): Statistics = {
// For data source tables, we will create a `LogicalRelation` and won't call this method, for
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 2d8ec2053a..3008e8cb84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -359,9 +359,59 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
override protected def innerChildren: Seq[QueryPlan[_]] = subqueries
/**
- * Canonicalized copy of this query plan.
+ * Returns a plan where a best effort attempt has been made to transform `this` in a way
+ * that preserves the result but removes cosmetic variations (case sensitivity, ordering for
+ * commutative operations, expression id, etc.)
+ *
+ * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same
+ * result.
+ *
+ * Some nodes should override this to provide proper canonicalization logic.
+ */
+ lazy val canonicalized: PlanType = {
+ val canonicalizedChildren = children.map(_.canonicalized)
+ var id = -1
+ preCanonicalized.mapExpressions {
+ case a: Alias =>
+ id += 1
+ // As the root of an expression, an Alias always takes an arbitrary exprId; we need to
+ // normalize it for equality testing by assigning exprIds incrementally from 0. The
+ // alias name doesn't matter and should be erased.
+ Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated)
+
+ case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 =>
+ // A top-level `AttributeReference` may also be used for output, like `Alias`; we should
+ // normalize its exprId too.
+ id += 1
+ ar.withExprId(ExprId(id))
+
+ case other => normalizeExprId(other)
+ }.withNewChildren(canonicalizedChildren)
+ }
+
+ /**
+ * Does some simple transformations on this plan before canonicalizing. Implementations can
+ * override this method to provide customized canonicalization logic without rewriting the whole thing.
*/
- protected lazy val canonicalized: PlanType = this
+ protected def preCanonicalized: PlanType = this
+
+ /**
+ * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
+ * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
+ * do not use `BindReferences` here because the plan may take the expression as a parameter of type
+ * `Attribute`, and replacing it with a `BoundReference` would cause an error.
+ */
+ protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = {
+ e.transformUp {
+ case ar: AttributeReference =>
+ val ordinal = input.indexOf(ar.exprId)
+ if (ordinal == -1) {
+ ar
+ } else {
+ ar.withExprId(ExprId(ordinal))
+ }
+ }.canonicalized.asInstanceOf[T]
+ }
/**
* Returns true when the given query plan will return the same results as this query plan.
@@ -372,49 +422,19 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
* enhancements like caching. However, it is not acceptable to return true if the results could
* possibly be different.
*
- * By default this function performs a modified version of equality that is tolerant of cosmetic
- * differences like attribute naming and or expression id differences. Operators that
- * can do better should override this function.
+ * This function performs a modified version of equality that is tolerant of cosmetic
+ * differences like attribute naming and/or expression id differences.
*/
- def sameResult(plan: PlanType): Boolean = {
- val left = this.canonicalized
- val right = plan.canonicalized
- left.getClass == right.getClass &&
- left.children.size == right.children.size &&
- left.cleanArgs == right.cleanArgs &&
- (left.children, right.children).zipped.forall(_ sameResult _)
- }
+ final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized
+
+ /**
+ * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard
+ * `hashCode`, an attempt has been made to eliminate cosmetic differences.
+ */
+ final def semanticHash(): Int = canonicalized.hashCode()
/**
* All the attributes that are used for this plan.
*/
lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
-
- protected def cleanExpression(e: Expression): Expression = e match {
- case a: Alias =>
- // As the root of the expression, Alias will always take an arbitrary exprId, we need
- // to erase that for equality testing.
- val cleanedExprId =
- Alias(a.child, a.name)(ExprId(-1), a.qualifier, isGenerated = a.isGenerated)
- BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true)
- case other =>
- BindReferences.bindReference(other, allAttributes, allowFailures = true)
- }
-
- /** Args that have cleaned such that differences in expression id should not affect equality */
- protected lazy val cleanArgs: Seq[Any] = {
- def cleanArg(arg: Any): Any = arg match {
- // Children are checked using sameResult above.
- case tn: TreeNode[_] if containsChild(tn) => null
- case e: Expression => cleanExpression(e).canonicalized
- case other => other
- }
-
- mapProductIterator {
- case s: Option[_] => s.map(cleanArg)
- case s: Seq[_] => s.map(cleanArg)
- case m: Map[_, _] => m.mapValues(cleanArg)
- case other => cleanArg(other)
- }.toSeq
- }
}
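With the refactored `QueryPlan` above, `sameResult` reduces to equality of canonicalized plans and `semanticHash` to their `hashCode`. An illustrative sketch of the contract (not part of the patch; it assumes a `SparkSession` named `spark`), where two analyzed plans differ only in alias names and expression ids:

```scala
val p1 = spark.range(10).selectExpr("id AS a").queryExecution.analyzed
val p2 = spark.range(10).selectExpr("id AS b").queryExecution.analyzed

// Canonicalization erases alias names and normalizes expression ids, so the two
// analyzed plans are semantically equal even though they are not `==` as written.
assert(p1.sameResult(p2))
assert(p1.semanticHash() == p2.semanticHash())
```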
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index b7177c4a2c..9cd5dfd21b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -67,14 +67,6 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
}
}
- override def sameResult(plan: LogicalPlan): Boolean = {
- plan.canonicalized match {
- case LocalRelation(otherOutput, otherData) =>
- otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data
- case _ => false
- }
- }
-
override def computeStats(conf: SQLConf): Statistics =
Statistics(sizeInBytes =
output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 036b625668..6bdcf490ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -143,8 +143,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
*/
def childrenResolved: Boolean = children.forall(_.resolved)
- override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this)
-
/**
* Resolves a given schema to concrete [[Attribute]] references in this query plan. This function
* should only be called on analyzed plans since it will throw [[AnalysisException]] for
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c91de08ca5..3ad757ebba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -803,6 +803,8 @@ case class SubqueryAlias(
child: LogicalPlan)
extends UnaryNode {
+ override lazy val canonicalized: LogicalPlan = child.canonicalized
+
override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias)))
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
index 9dfdf4da78..2ab46dc833 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
@@ -26,10 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
trait BroadcastMode {
def transform(rows: Array[InternalRow]): Any
- /**
- * Returns true iff this [[BroadcastMode]] generates the same result as `other`.
- */
- def compatibleWith(other: BroadcastMode): Boolean
+ def canonicalized: BroadcastMode
}
/**
@@ -39,7 +36,5 @@ case object IdentityBroadcastMode extends BroadcastMode {
// TODO: pack the UnsafeRows into single bytes array.
override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows
- override def compatibleWith(other: BroadcastMode): Boolean = {
- this eq other
- }
+ override def canonicalized: BroadcastMode = this
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 2fa660c4d5..3a9132d74a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -119,7 +119,7 @@ case class RowDataSourceScanExec(
val input = ctx.freshName("input")
ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
val exprRows = output.zipWithIndex.map{ case (a, i) =>
- new BoundReference(i, a.dataType, a.nullable)
+ BoundReference(i, a.dataType, a.nullable)
}
val row = ctx.freshName("row")
ctx.INPUT_ROW = row
@@ -136,19 +136,17 @@ case class RowDataSourceScanExec(
""".stripMargin
}
- // Ignore rdd when checking results
- override def sameResult(plan: SparkPlan): Boolean = plan match {
- case other: RowDataSourceScanExec => relation == other.relation && metadata == other.metadata
- case _ => false
- }
+ // Only care about `relation` and `metadata` when canonicalizing.
+ override def preCanonicalized: SparkPlan =
+ copy(rdd = null, outputPartitioning = null, metastoreTableIdentifier = None)
}
/**
* Physical plan node for scanning data from HadoopFsRelations.
*
* @param relation The file-based relation to scan.
- * @param output Output attributes of the scan.
- * @param outputSchema Output schema of the scan.
+ * @param output Output attributes of the scan, including data attributes and partition attributes.
+ * @param requiredSchema Required schema of the underlying relation, excluding partition columns.
* @param partitionFilters Predicates to use for partition pruning.
* @param dataFilters Filters on non-partition columns.
* @param metastoreTableIdentifier identifier for the table in the metastore.
@@ -156,7 +154,7 @@ case class RowDataSourceScanExec(
case class FileSourceScanExec(
@transient relation: HadoopFsRelation,
output: Seq[Attribute],
- outputSchema: StructType,
+ requiredSchema: StructType,
partitionFilters: Seq[Expression],
dataFilters: Seq[Expression],
override val metastoreTableIdentifier: Option[TableIdentifier])
@@ -267,7 +265,7 @@ case class FileSourceScanExec(
val metadata =
Map(
"Format" -> relation.fileFormat.toString,
- "ReadSchema" -> outputSchema.catalogString,
+ "ReadSchema" -> requiredSchema.catalogString,
"Batched" -> supportsBatch.toString,
"PartitionFilters" -> seqToString(partitionFilters),
"PushedFilters" -> seqToString(pushedDownFilters),
@@ -287,7 +285,7 @@ case class FileSourceScanExec(
sparkSession = relation.sparkSession,
dataSchema = relation.dataSchema,
partitionSchema = relation.partitionSchema,
- requiredSchema = outputSchema,
+ requiredSchema = requiredSchema,
filters = pushedDownFilters,
options = relation.options,
hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))
@@ -515,14 +513,13 @@ case class FileSourceScanExec(
}
}
- override def sameResult(plan: SparkPlan): Boolean = plan match {
- case other: FileSourceScanExec =>
- val thisPredicates = partitionFilters.map(cleanExpression)
- val otherPredicates = other.partitionFilters.map(cleanExpression)
- val result = relation == other.relation && metadata == other.metadata &&
- thisPredicates.length == otherPredicates.length &&
- thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2))
- result
- case _ => false
+ override lazy val canonicalized: FileSourceScanExec = {
+ FileSourceScanExec(
+ relation,
+ output.map(normalizeExprId(_, output)),
+ requiredSchema,
+ partitionFilters.map(normalizeExprId(_, output)),
+ dataFilters.map(normalizeExprId(_, output)),
+ None)
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 2827b8ac00..3d1b481a53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -87,13 +87,6 @@ case class ExternalRDD[T](
override def newInstance(): ExternalRDD.this.type =
ExternalRDD(outputObjAttr.newInstance(), rdd)(session).asInstanceOf[this.type]
- override def sameResult(plan: LogicalPlan): Boolean = {
- plan.canonicalized match {
- case ExternalRDD(_, otherRDD) => rdd.id == otherRDD.id
- case _ => false
- }
- }
-
override protected def stringArgs: Iterator[Any] = Iterator(output)
@transient override def computeStats(conf: SQLConf): Statistics = Statistics(
@@ -162,13 +155,6 @@ case class LogicalRDD(
)(session).asInstanceOf[this.type]
}
- override def sameResult(plan: LogicalPlan): Boolean = {
- plan.canonicalized match {
- case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id
- case _ => false
- }
- }
-
override protected def stringArgs: Iterator[Any] = Iterator(output)
@transient override def computeStats(conf: SQLConf): Statistics = Statistics(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index e366b9af35..19c68c1326 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -33,7 +33,7 @@ case class LocalTableScanExec(
override lazy val metrics = Map(
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))
- private val unsafeRows: Array[InternalRow] = {
+ private lazy val unsafeRows: Array[InternalRow] = {
if (rows.isEmpty) {
Array.empty
} else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 66a8e044ab..44278e37c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -342,8 +342,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
"numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
"numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows"))
- // output attributes should not affect the results
- override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
+ override lazy val canonicalized: SparkPlan = {
+ RangeExec(range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range])
+ }
override def inputRDDs(): Seq[RDD[InternalRow]] = {
sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
@@ -607,11 +608,6 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
- override def sameResult(o: SparkPlan): Boolean = o match {
- case s: SubqueryExec => child.sameResult(s.child)
- case _ => false
- }
-
@transient
private lazy val relationFuture: Future[Array[InternalRow]] = {
// relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 4215203960..3813f953e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -43,17 +43,8 @@ case class LogicalRelation(
com.google.common.base.Objects.hashCode(relation, output)
}
- override def sameResult(otherPlan: LogicalPlan): Boolean = {
- otherPlan.canonicalized match {
- case LogicalRelation(otherRelation, _, _) => relation == otherRelation
- case _ => false
- }
- }
-
- // When comparing two LogicalRelations from within LogicalPlan.sameResult, we only need
- // LogicalRelation.cleanArgs to return Seq(relation), since expectedOutputAttribute's
- // expId can be different but the relation is still the same.
- override lazy val cleanArgs: Seq[Any] = Seq(relation)
+ // Only care about relation when canonicalizing.
+ override def preCanonicalized: LogicalPlan = copy(catalogTable = None)
@transient override def computeStats(conf: SQLConf): Statistics = {
catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index efcaca9338..9c859e41f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -48,10 +48,8 @@ case class BroadcastExchangeExec(
override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)
- override def sameResult(plan: SparkPlan): Boolean = plan match {
- case p: BroadcastExchangeExec =>
- mode.compatibleWith(p.mode) && child.sameResult(p.child)
- case _ => false
+ override lazy val canonicalized: SparkPlan = {
+ BroadcastExchangeExec(mode.canonicalized, child.canonicalized)
}
@transient
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index 9a9597d373..d993ea6c6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -48,10 +48,8 @@ abstract class Exchange extends UnaryExecNode {
case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange)
extends LeafExecNode {
- override def sameResult(plan: SparkPlan): Boolean = {
- // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here.
- plan.sameResult(child)
- }
+ // Ignore this wrapper for canonicalizing.
+ override lazy val canonicalized: SparkPlan = child.canonicalized
def doExecute(): RDD[InternalRow] = {
child.execute()
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index b9f6601ea8..2dd1dc3da9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -829,15 +829,10 @@ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression])
extends BroadcastMode {
override def transform(rows: Array[InternalRow]): HashedRelation = {
- HashedRelation(rows.iterator, canonicalizedKey, rows.length)
+ HashedRelation(rows.iterator, canonicalized.key, rows.length)
}
- private lazy val canonicalizedKey: Seq[Expression] = {
- key.map { e => e.canonicalized }
- }
-
- override def compatibleWith(other: BroadcastMode): Boolean = other match {
- case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey
- case _ => false
+ override lazy val canonicalized: HashedRelationBroadcastMode = {
+ this.copy(key = key.map(_.canonicalized))
}
}
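With `compatibleWith` removed, broadcast modes are now compared by their canonicalized form. A minimal sketch (not part of the patch; `HashedRelationBroadcastMode` is `private[execution]`, so this would live in test code, and the attribute `a` is built by hand for illustration):

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, AttributeReference, Literal}
import org.apache.spark.sql.types.LongType

val a = AttributeReference("a", LongType)()
val m1 = HashedRelationBroadcastMode(Add(a, Literal(1L)) :: Nil)
val m2 = HashedRelationBroadcastMode(Add(Literal(1L), a) :: Nil)

// The keys differ only in operand order; expression canonicalization reorders
// commutative operations, so the two modes canonicalize to the same value.
assert(m1.canonicalized == m2.canonicalized)
```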
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 36cde3233d..59eaf4d1c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -36,17 +36,17 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
)
}
- test("compatible BroadcastMode") {
+ test("BroadcastMode.canonicalized") {
val mode1 = IdentityBroadcastMode
val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)
- assert(mode1.compatibleWith(mode1))
- assert(!mode1.compatibleWith(mode2))
- assert(!mode2.compatibleWith(mode1))
- assert(mode2.compatibleWith(mode2))
- assert(!mode2.compatibleWith(mode3))
- assert(mode3.compatibleWith(mode3))
+ assert(mode1.canonicalized == mode1.canonicalized)
+ assert(mode1.canonicalized != mode2.canonicalized)
+ assert(mode2.canonicalized != mode1.canonicalized)
+ assert(mode2.canonicalized == mode2.canonicalized)
+ assert(mode2.canonicalized != mode3.canonicalized)
+ assert(mode3.canonicalized == mode3.canonicalized)
}
test("BroadcastExchange same result") {
@@ -70,7 +70,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
assert(!exchange1.sameResult(exchange2))
assert(!exchange2.sameResult(exchange3))
- assert(!exchange3.sameResult(exchange4))
+ assert(exchange3.sameResult(exchange4))
assert(exchange4 sameResult exchange3)
}
@@ -98,7 +98,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
assert(exchange1 sameResult exchange2)
assert(!exchange2.sameResult(exchange3))
assert(!exchange3.sameResult(exchange4))
- assert(!exchange4.sameResult(exchange5))
+ assert(exchange4.sameResult(exchange5))
assert(exchange5 sameResult exchange4)
}
}
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 28f074849c..fab0d7fa84 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -72,7 +72,7 @@ case class HiveTableScanExec(
// Bind all partition key attribute references in the partition pruning predicate for later
// evaluation.
- private val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
+ private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
require(
pred.dataType == BooleanType,
s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
@@ -80,20 +80,22 @@ case class HiveTableScanExec(
BindReferences.bindReference(pred, relation.partitionCols)
}
- // Create a local copy of hadoopConf,so that scan specific modifications should not impact
- // other queries
- @transient private val hadoopConf = sparkSession.sessionState.newHadoopConf()
-
- @transient private val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta)
- @transient private val tableDesc = new TableDesc(
+ @transient private lazy val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta)
+ @transient private lazy val tableDesc = new TableDesc(
hiveQlTable.getInputFormatClass,
hiveQlTable.getOutputFormatClass,
hiveQlTable.getMetadata)
- // append columns ids and names before broadcast
- addColumnMetadataToConf(hadoopConf)
+ // Create a local copy of hadoopConf, so that scan-specific modifications do not impact
+ // other queries.
+ @transient private lazy val hadoopConf = {
+ val c = sparkSession.sessionState.newHadoopConf()
+ // append column ids and names before broadcast
+ addColumnMetadataToConf(c)
+ c
+ }
- @transient private val hadoopReader = new HadoopTableReader(
+ @transient private lazy val hadoopReader = new HadoopTableReader(
output,
relation.partitionCols,
tableDesc,
@@ -104,7 +106,7 @@ case class HiveTableScanExec(
Cast(Literal(value), dataType).eval(null)
}
- private def addColumnMetadataToConf(hiveConf: Configuration) {
+ private def addColumnMetadataToConf(hiveConf: Configuration): Unit = {
// Specifies needed column IDs for those non-partitioning columns.
val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex)
val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer)
@@ -198,18 +200,13 @@ case class HiveTableScanExec(
}
}
- override def sameResult(plan: SparkPlan): Boolean = plan match {
- case other: HiveTableScanExec =>
- val thisPredicates = partitionPruningPred.map(cleanExpression)
- val otherPredicates = other.partitionPruningPred.map(cleanExpression)
-
- val result = relation.sameResult(other.relation) &&
- output.length == other.output.length &&
- output.zip(other.output)
- .forall(p => p._1.name == p._2.name && p._1.dataType == p._2.dataType) &&
- thisPredicates.length == otherPredicates.length &&
- thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2))
- result
- case _ => false
+ override lazy val canonicalized: HiveTableScanExec = {
+ val input: AttributeSeq = relation.output
+ HiveTableScanExec(
+ requestedAttributes.map(normalizeExprId(_, input)),
+ relation.canonicalized.asInstanceOf[CatalogRelation],
+ partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession)
}
+
+ override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
}