From 3d7f201f2adc2d33be6f564fa76435c18552f4ba Mon Sep 17 00:00:00 2001
From: Wenchen Fan
Date: Mon, 10 Apr 2017 13:36:08 +0800
Subject: [SPARK-20229][SQL] add semanticHash to QueryPlan

## What changes were proposed in this pull request?

Like `Expression`, `QueryPlan` should also have a `semanticHash` method, so that we can put plans into a hash map and look them up quickly. This PR refactors `QueryPlan` to follow `Expression` and puts all the normalization logic in `QueryPlan.canonicalized`, so that it's very natural to implement `semanticHash`.

Follow-up: improve `CacheManager` to leverage this `semanticHash` and speed up plan lookup, instead of iterating over all cached plans.

## How was this patch tested?

Existing tests. Note that we don't need to test the `semanticHash` method itself: once the existing tests prove `sameResult` is correct, we are good.

Author: Wenchen Fan

Closes #17541 from cloud-fan/plan-semantic.
---
 .../spark/sql/catalyst/analysis/Analyzer.scala      |   2 +-
 .../spark/sql/catalyst/catalog/interface.scala      |  11 ++-
 .../spark/sql/catalyst/plans/QueryPlan.scala        | 102 ++++++++++++---
 .../sql/catalyst/plans/logical/LocalRelation.scala  |   8 --
 .../sql/catalyst/plans/logical/LogicalPlan.scala    |   2 -
 .../plans/logical/basicLogicalOperators.scala       |   2 +
 .../catalyst/plans/physical/broadcastMode.scala     |   9 +-
 .../spark/sql/execution/DataSourceScanExec.scala    |  37 ++++----
 .../apache/spark/sql/execution/ExistingRDD.scala    |  14 ---
 .../spark/sql/execution/LocalTableScanExec.scala    |   2 +-
 .../sql/execution/basicPhysicalOperators.scala      |  10 +-
 .../execution/datasources/LogicalRelation.scala     |  13 +--
 .../execution/exchange/BroadcastExchangeExec.scala  |   6 +-
 .../spark/sql/execution/exchange/Exchange.scala     |   6 +-
 .../spark/sql/execution/joins/HashedRelation.scala  |  11 +--
 .../apache/spark/sql/execution/ExchangeSuite.scala  |  18 ++--
 .../sql/hive/execution/HiveTableScanExec.scala      |  45 +++++----
 17 files changed, 135 insertions(+), 163 deletions(-)
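A minimal sketch of the follow-up idea mentioned above (not part of this patch; the class and method names here are hypothetical): once every plan exposes `semanticHash()`, a `CacheManager`-style registry can bucket cached plans by that hash and only fall back to `sameResult` within a bucket.

```scala
import scala.collection.mutable
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan

// Hypothetical index, shown only to illustrate how semanticHash() could speed up lookup.
class SemanticPlanIndex {
  private val buckets = mutable.HashMap.empty[Int, mutable.ArrayBuffer[LogicalPlan]]

  def add(plan: LogicalPlan): Unit =
    buckets.getOrElseUpdate(plan.semanticHash(), mutable.ArrayBuffer.empty) += plan

  // The hash narrows the candidates; sameResult() gives the definitive answer, since
  // plans with different results may still collide on a hash value.
  def lookup(plan: LogicalPlan): Option[LogicalPlan] =
    buckets.get(plan.semanticHash()).flatMap(_.find(_.sameResult(plan)))
}
```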
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
index c698ca6a83..b0cdef7029 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala
@@ -617,7 +617,7 @@ class Analyzer(

     def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
       case i @ InsertIntoTable(u: UnresolvedRelation, parts, child, _, _) if child.resolved =>
-        lookupTableFromCatalog(u).canonicalized match {
+        EliminateSubqueryAliases(lookupTableFromCatalog(u)) match {
           case v: View =>
             u.failAnalysis(s"Inserting into a view is not allowed. View: ${v.desc.identifier}.")
           case other => i.copy(table = other)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
index 360e55d922..cc0cbba275 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/catalog/interface.scala
@@ -423,8 +423,15 @@ case class CatalogRelation(
     Objects.hashCode(tableMeta.identifier, output)
   }

-  /** Only compare table identifier. */
-  override lazy val cleanArgs: Seq[Any] = Seq(tableMeta.identifier)
+  override def preCanonicalized: LogicalPlan = copy(tableMeta = CatalogTable(
+    identifier = tableMeta.identifier,
+    tableType = tableMeta.tableType,
+    storage = CatalogStorageFormat.empty,
+    schema = tableMeta.schema,
+    partitionColumnNames = tableMeta.partitionColumnNames,
+    bucketSpec = tableMeta.bucketSpec,
+    createTime = -1
+  ))

   override def computeStats(conf: SQLConf): Statistics = {
     // For data source tables, we will create a `LogicalRelation` and won't call this method, for
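The `preCanonicalized` hook used above replaces the old `cleanArgs`: a node copies itself with non-semantic fields blanked out, and the generic exprId normalization then runs on that copy. A hedged sketch of the same pattern on a made-up node (none of these names come from the patch):

```scala
import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.plans.logical.{LeafNode, LogicalPlan}

// Hypothetical relation that carries a path whose trailing slash is purely cosmetic.
case class MyFileRelation(output: Seq[Attribute], sourcePath: String) extends LeafNode {
  // Strip the cosmetic part so sameResult/semanticHash are not affected by it.
  override def preCanonicalized: LogicalPlan = copy(sourcePath = sourcePath.stripSuffix("/"))
}
```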
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 2d8ec2053a..3008e8cb84 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -359,9 +359,59 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT
   override protected def innerChildren: Seq[QueryPlan[_]] = subqueries

   /**
-   * Canonicalized copy of this query plan.
+   * Returns a plan where a best effort attempt has been made to transform `this` in a way
+   * that preserves the result but removes cosmetic variations (case sensitivity, ordering for
+   * commutative operations, expression ids, etc.)
+   *
+   * Plans where `this.canonicalized == other.canonicalized` will always evaluate to the same
+   * result.
+   *
+   * Some nodes should override this to provide proper canonicalization logic.
+   */
+  lazy val canonicalized: PlanType = {
+    val canonicalizedChildren = children.map(_.canonicalized)
+    var id = -1
+    preCanonicalized.mapExpressions {
+      case a: Alias =>
+        id += 1
+        // As the root of the expression, Alias will always take an arbitrary exprId, we need to
+        // normalize that for equality testing, by assigning expr ids from 0 incrementally. The
+        // alias name doesn't matter and should be erased.
+        Alias(normalizeExprId(a.child), "")(ExprId(id), a.qualifier, isGenerated = a.isGenerated)
+
+      case ar: AttributeReference if allAttributes.indexOf(ar.exprId) == -1 =>
+        // Top level `AttributeReference` may also be used for output like `Alias`, we should
+        // normalize the exprId too.
+        id += 1
+        ar.withExprId(ExprId(id))
+
+      case other => normalizeExprId(other)
+    }.withNewChildren(canonicalizedChildren)
+  }
+
+  /**
+   * Do some simple transformation on this plan before canonicalizing. Implementations can override
+   * this method to provide customized canonicalization logic without rewriting the whole logic.
    */
-  protected lazy val canonicalized: PlanType = this
+  protected def preCanonicalized: PlanType = this
+
+  /**
+   * Normalize the exprIds in the given expression, by updating the exprId in `AttributeReference`
+   * with its referenced ordinal from input attributes. It's similar to `BindReferences` but we
+   * do not use `BindReferences` here as the plan may take the expression as a parameter with type
+   * `Attribute`, and replacing it with `BoundReference` would cause an error.
+   */
+  protected def normalizeExprId[T <: Expression](e: T, input: AttributeSeq = allAttributes): T = {
+    e.transformUp {
+      case ar: AttributeReference =>
+        val ordinal = input.indexOf(ar.exprId)
+        if (ordinal == -1) {
+          ar
+        } else {
+          ar.withExprId(ExprId(ordinal))
+        }
+    }.canonicalized.asInstanceOf[T]
+  }

   /**
    * Returns true when the given query plan will return the same results as this query plan.
@@ -372,49 +422,19 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanT
    * enhancements like caching. However, it is not acceptable to return true if the results could
    * possibly be different.
    *
-   * By default this function performs a modified version of equality that is tolerant of cosmetic
-   * differences like attribute naming and or expression id differences. Operators that
-   * can do better should override this function.
+   * This function performs a modified version of equality that is tolerant of cosmetic
+   * differences like attribute naming and/or expression id differences.
    */
-  def sameResult(plan: PlanType): Boolean = {
-    val left = this.canonicalized
-    val right = plan.canonicalized
-    left.getClass == right.getClass &&
-      left.children.size == right.children.size &&
-      left.cleanArgs == right.cleanArgs &&
-      (left.children, right.children).zipped.forall(_ sameResult _)
-  }
+  final def sameResult(other: PlanType): Boolean = this.canonicalized == other.canonicalized
+
+  /**
+   * Returns a `hashCode` for the calculation performed by this plan. Unlike the standard
+   * `hashCode`, an attempt has been made to eliminate cosmetic differences.
+   */
+  final def semanticHash(): Int = canonicalized.hashCode()

   /**
    * All the attributes that are used for this plan.
    */
   lazy val allAttributes: AttributeSeq = children.flatMap(_.output)
-
-  protected def cleanExpression(e: Expression): Expression = e match {
-    case a: Alias =>
-      // As the root of the expression, Alias will always take an arbitrary exprId, we need
-      // to erase that for equality testing.
-      val cleanedExprId =
-        Alias(a.child, a.name)(ExprId(-1), a.qualifier, isGenerated = a.isGenerated)
-      BindReferences.bindReference(cleanedExprId, allAttributes, allowFailures = true)
-    case other =>
-      BindReferences.bindReference(other, allAttributes, allowFailures = true)
-  }
-
-  /** Args that have cleaned such that differences in expression id should not affect equality */
-  protected lazy val cleanArgs: Seq[Any] = {
-    def cleanArg(arg: Any): Any = arg match {
-      // Children are checked using sameResult above.
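To make the new contract concrete, here is a hedged example of the behaviour the patch relies on (it assumes a `SparkSession` named `spark`): alias names and expression ids are cosmetic, so the two plans below canonicalize to the same tree.

```scala
import org.apache.spark.sql.functions.col

val p1 = spark.range(10).select(col("id").as("x")).queryExecution.analyzed
val p2 = spark.range(10).select(col("id").as("y")).queryExecution.analyzed

// Alias names are erased and exprIds are normalized during canonicalization...
assert(p1.sameResult(p2))
// ...so equal-result plans also land in the same hash bucket.
assert(p1.semanticHash() == p2.semanticHash())
```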
-      case tn: TreeNode[_] if containsChild(tn) => null
-      case e: Expression => cleanExpression(e).canonicalized
-      case other => other
-    }
-
-    mapProductIterator {
-      case s: Option[_] => s.map(cleanArg)
-      case s: Seq[_] => s.map(cleanArg)
-      case m: Map[_, _] => m.mapValues(cleanArg)
-      case other => cleanArg(other)
-    }.toSeq
-  }
 }
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
index b7177c4a2c..9cd5dfd21b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LocalRelation.scala
@@ -67,14 +67,6 @@ case class LocalRelation(output: Seq[Attribute], data: Seq[InternalRow] = Nil)
     }
   }

-  override def sameResult(plan: LogicalPlan): Boolean = {
-    plan.canonicalized match {
-      case LocalRelation(otherOutput, otherData) =>
-        otherOutput.map(_.dataType) == output.map(_.dataType) && otherData == data
-      case _ => false
-    }
-  }
-
   override def computeStats(conf: SQLConf): Statistics =
     Statistics(sizeInBytes = output.map(n => BigInt(n.dataType.defaultSize)).sum * data.length)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
index 036b625668..6bdcf490ca 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/LogicalPlan.scala
@@ -143,8 +143,6 @@ abstract class LogicalPlan extends QueryPlan[LogicalPlan] with Logging {
    */
   def childrenResolved: Boolean = children.forall(_.resolved)

-  override lazy val canonicalized: LogicalPlan = EliminateSubqueryAliases(this)
-
   /**
    * Resolves a given schema to concrete [[Attribute]] references in this query plan. This function
    * should only be called on analyzed plans since it will throw [[AnalysisException]] for
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
index c91de08ca5..3ad757ebba 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/logical/basicLogicalOperators.scala
@@ -803,6 +803,8 @@ case class SubqueryAlias(
     child: LogicalPlan)
   extends UnaryNode {

+  override lazy val canonicalized: LogicalPlan = child.canonicalized
+
   override def output: Seq[Attribute] = child.output.map(_.withQualifier(Some(alias)))
 }

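Since `SubqueryAlias` now canonicalizes straight to its child (replacing the blanket `EliminateSubqueryAliases` call that used to live in `LogicalPlan.canonicalized`), alias wrappers stay transparent to result comparison. A small hedged check, assuming a `SparkSession` named `spark`:

```scala
val df = spark.range(5)
val t1 = df.as("t1").queryExecution.analyzed
val t2 = df.as("t2").queryExecution.analyzed

// The alias wrapper is ignored during canonicalization, so both aliased forms compare
// equal, and either of them also compares equal to the un-aliased plan.
assert(t1.sameResult(t2))
assert(t1.sameResult(df.queryExecution.analyzed))
```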
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
index 9dfdf4da78..2ab46dc833 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/physical/broadcastMode.scala
@@ -26,10 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 trait BroadcastMode {
   def transform(rows: Array[InternalRow]): Any

-  /**
-   * Returns true iff this [[BroadcastMode]] generates the same result as `other`.
-   */
-  def compatibleWith(other: BroadcastMode): Boolean
+  def canonicalized: BroadcastMode
 }

 /**
@@ -39,7 +36,5 @@ case object IdentityBroadcastMode extends BroadcastMode {
   // TODO: pack the UnsafeRows into single bytes array.
   override def transform(rows: Array[InternalRow]): Array[InternalRow] = rows

-  override def compatibleWith(other: BroadcastMode): Boolean = {
-    this eq other
-  }
+  override def canonicalized: BroadcastMode = this
 }
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
index 2fa660c4d5..3a9132d74a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/DataSourceScanExec.scala
@@ -119,7 +119,7 @@ case class RowDataSourceScanExec(
     val input = ctx.freshName("input")
     ctx.addMutableState("scala.collection.Iterator", input, s"$input = inputs[0];")
     val exprRows = output.zipWithIndex.map{ case (a, i) =>
-      new BoundReference(i, a.dataType, a.nullable)
+      BoundReference(i, a.dataType, a.nullable)
     }
     val row = ctx.freshName("row")
     ctx.INPUT_ROW = row
@@ -136,19 +136,17 @@ case class RowDataSourceScanExec(
       """.stripMargin
   }

-  // Ignore rdd when checking results
-  override def sameResult(plan: SparkPlan): Boolean = plan match {
-    case other: RowDataSourceScanExec => relation == other.relation && metadata == other.metadata
-    case _ => false
-  }
+  // Only care about `relation` and `metadata` when canonicalizing.
+  override def preCanonicalized: SparkPlan =
+    copy(rdd = null, outputPartitioning = null, metastoreTableIdentifier = None)
 }

 /**
  * Physical plan node for scanning data from HadoopFsRelations.
  *
  * @param relation The file-based relation to scan.
- * @param output Output attributes of the scan.
- * @param outputSchema Output schema of the scan.
+ * @param output Output attributes of the scan, including data attributes and partition attributes.
+ * @param requiredSchema Required schema of the underlying relation, excluding partition columns.
  * @param partitionFilters Predicates to use for partition pruning.
  * @param dataFilters Filters on non-partition columns.
  * @param metastoreTableIdentifier identifier for the table in the metastore.
@@ -156,7 +154,7 @@ case class RowDataSourceScanExec(
 case class FileSourceScanExec(
     @transient relation: HadoopFsRelation,
     output: Seq[Attribute],
-    outputSchema: StructType,
+    requiredSchema: StructType,
     partitionFilters: Seq[Expression],
     dataFilters: Seq[Expression],
     override val metastoreTableIdentifier: Option[TableIdentifier])
@@ -267,7 +265,7 @@ case class FileSourceScanExec(

   val metadata = Map(
     "Format" -> relation.fileFormat.toString,
-    "ReadSchema" -> outputSchema.catalogString,
+    "ReadSchema" -> requiredSchema.catalogString,
     "Batched" -> supportsBatch.toString,
     "PartitionFilters" -> seqToString(partitionFilters),
     "PushedFilters" -> seqToString(pushedDownFilters),
@@ -287,7 +285,7 @@ case class FileSourceScanExec(
       sparkSession = relation.sparkSession,
       dataSchema = relation.dataSchema,
       partitionSchema = relation.partitionSchema,
-      requiredSchema = outputSchema,
+      requiredSchema = requiredSchema,
       filters = pushedDownFilters,
       options = relation.options,
       hadoopConf = relation.sparkSession.sessionState.newHadoopConfWithOptions(relation.options))
@@ -515,14 +513,13 @@ case class FileSourceScanExec(
     }
   }

-  override def sameResult(plan: SparkPlan): Boolean = plan match {
-    case other: FileSourceScanExec =>
-      val thisPredicates = partitionFilters.map(cleanExpression)
-      val otherPredicates = other.partitionFilters.map(cleanExpression)
-      val result = relation == other.relation && metadata == other.metadata &&
-        thisPredicates.length == otherPredicates.length &&
-        thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2))
-      result
-    case _ => false
+  override lazy val canonicalized: FileSourceScanExec = {
+    FileSourceScanExec(
+      relation,
+      output.map(normalizeExprId(_, output)),
+      requiredSchema,
+      partitionFilters.map(normalizeExprId(_, output)),
+      dataFilters.map(normalizeExprId(_, output)),
+      None)
   }
 }
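The `FileSourceScanExec.canonicalized` override above normalizes every expression against the scan's own `output`, so an attribute's exprId is replaced by its ordinal. A hedged sketch of what that normalization amounts to (the attributes are hypothetical; the real helper is the `protected normalizeExprId` added to `QueryPlan`, re-implemented inline here because it is not public):

```scala
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.types.LongType

// Stand-ins for a scan whose output is [a#12, b#30].
val a = AttributeReference("a", LongType)(exprId = ExprId(12))
val b = AttributeReference("b", LongType)(exprId = ExprId(30))
val input: AttributeSeq = Seq(a, b)

// Equivalent of normalizeExprId: rewrite each reference to its ordinal in `input`.
val normalized = Add(a, b).transformUp {
  case ar: AttributeReference => ar.withExprId(ExprId(input.indexOf(ar.exprId)))
}
// `normalized` now references exprIds 0 and 1, so two scans of the same relation that
// differ only in the exprIds of their output or filters canonicalize to equal plans.
```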
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
index 2827b8ac00..3d1b481a53 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/ExistingRDD.scala
@@ -87,13 +87,6 @@ case class ExternalRDD[T](
   override def newInstance(): ExternalRDD.this.type =
     ExternalRDD(outputObjAttr.newInstance(), rdd)(session).asInstanceOf[this.type]

-  override def sameResult(plan: LogicalPlan): Boolean = {
-    plan.canonicalized match {
-      case ExternalRDD(_, otherRDD) => rdd.id == otherRDD.id
-      case _ => false
-    }
-  }
-
   override protected def stringArgs: Iterator[Any] = Iterator(output)

   @transient override def computeStats(conf: SQLConf): Statistics = Statistics(
@@ -162,13 +155,6 @@ case class LogicalRDD(
     )(session).asInstanceOf[this.type]
   }

-  override def sameResult(plan: LogicalPlan): Boolean = {
-    plan.canonicalized match {
-      case LogicalRDD(_, otherRDD, _, _) => rdd.id == otherRDD.id
-      case _ => false
-    }
-  }
-
   override protected def stringArgs: Iterator[Any] = Iterator(output)

   @transient override def computeStats(conf: SQLConf): Statistics = Statistics(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
index e366b9af35..19c68c1326 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/LocalTableScanExec.scala
@@ -33,7 +33,7 @@ case class LocalTableScanExec(
   override lazy val metrics = Map(
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"))

-  private val unsafeRows: Array[InternalRow] = {
+  private lazy val unsafeRows: Array[InternalRow] = {
     if (rows.isEmpty) {
       Array.empty
     } else {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
index 66a8e044ab..44278e37c5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicPhysicalOperators.scala
@@ -342,8 +342,9 @@ case class RangeExec(range: org.apache.spark.sql.catalyst.plans.logical.Range)
     "numOutputRows" -> SQLMetrics.createMetric(sparkContext, "number of output rows"),
     "numGeneratedRows" -> SQLMetrics.createMetric(sparkContext, "number of generated rows"))

-  // output attributes should not affect the results
-  override lazy val cleanArgs: Seq[Any] = Seq(start, step, numSlices, numElements)
+  override lazy val canonicalized: SparkPlan = {
+    RangeExec(range.canonicalized.asInstanceOf[org.apache.spark.sql.catalyst.plans.logical.Range])
+  }

   override def inputRDDs(): Seq[RDD[InternalRow]] = {
     sqlContext.sparkContext.parallelize(0 until numSlices, numSlices)
@@ -607,11 +608,6 @@ case class SubqueryExec(name: String, child: SparkPlan) extends UnaryExecNode {

   override def outputOrdering: Seq[SortOrder] = child.outputOrdering

-  override def sameResult(o: SparkPlan): Boolean = o match {
-    case s: SubqueryExec => child.sameResult(s.child)
-    case _ => false
-  }
-
   @transient
   private lazy val relationFuture: Future[Array[InternalRow]] = {
     // relationFuture is used in "doExecute". Therefore we can get the execution id correctly here.
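With `RangeExec.canonicalized` delegating to the canonicalized logical `Range`, two independently built physical plans over the same range should compare equal even though their output attributes carry different exprIds. A hedged check (assuming a `SparkSession` named `spark`, and that the planner produces a `RangeExec` for a plain range scan):

```scala
// Both plans scan Range(0, 100, step = 1, numSlices = 4) but get fresh exprIds.
val plan1 = spark.range(0, 100, 1, 4).queryExecution.sparkPlan
val plan2 = spark.range(0, 100, 1, 4).queryExecution.sparkPlan

// exprIds are normalized away during canonicalization.
assert(plan1.sameResult(plan2))
```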
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
index 4215203960..3813f953e0 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/LogicalRelation.scala
@@ -43,17 +43,8 @@ case class LogicalRelation(
     com.google.common.base.Objects.hashCode(relation, output)
   }

-  override def sameResult(otherPlan: LogicalPlan): Boolean = {
-    otherPlan.canonicalized match {
-      case LogicalRelation(otherRelation, _, _) => relation == otherRelation
-      case _ => false
-    }
-  }
-
-  // When comparing two LogicalRelations from within LogicalPlan.sameResult, we only need
-  // LogicalRelation.cleanArgs to return Seq(relation), since expectedOutputAttribute's
-  // expId can be different but the relation is still the same.
-  override lazy val cleanArgs: Seq[Any] = Seq(relation)
+  // Only care about relation when canonicalizing.
+  override def preCanonicalized: LogicalPlan = copy(catalogTable = None)

   @transient override def computeStats(conf: SQLConf): Statistics = {
     catalogTable.flatMap(_.stats.map(_.toPlanStats(output))).getOrElse(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
index efcaca9338..9c859e41f8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/BroadcastExchangeExec.scala
@@ -48,10 +48,8 @@ case class BroadcastExchangeExec(

   override def outputPartitioning: Partitioning = BroadcastPartitioning(mode)

-  override def sameResult(plan: SparkPlan): Boolean = plan match {
-    case p: BroadcastExchangeExec =>
-      mode.compatibleWith(p.mode) && child.sameResult(p.child)
-    case _ => false
+  override lazy val canonicalized: SparkPlan = {
+    BroadcastExchangeExec(mode.canonicalized, child.canonicalized)
   }

   @transient
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
index 9a9597d373..d993ea6c6c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/exchange/Exchange.scala
@@ -48,10 +48,8 @@ abstract class Exchange extends UnaryExecNode {
 case class ReusedExchangeExec(override val output: Seq[Attribute], child: Exchange)
   extends LeafExecNode {

-  override def sameResult(plan: SparkPlan): Boolean = {
-    // Ignore this wrapper. `plan` could also be a ReusedExchange, so we reverse the order here.
-    plan.sameResult(child)
-  }
+  // Ignore this wrapper for canonicalizing.
+  override lazy val canonicalized: SparkPlan = child.canonicalized

   def doExecute(): RDD[InternalRow] = {
     child.execute()
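`ReusedExchangeExec` previously had to reverse the comparison order to stay symmetric; with canonicalization the wrapper simply disappears. A hedged check, built the same way the updated `ExchangeSuite` builds its exchanges (assumes a `SparkSession` named `spark`):

```scala
import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode
import org.apache.spark.sql.execution.exchange.{BroadcastExchangeExec, ReusedExchangeExec}

val child = spark.range(10).queryExecution.executedPlan
val exchange = BroadcastExchangeExec(IdentityBroadcastMode, child)
val reused = ReusedExchangeExec(exchange.output, exchange)

// The wrapper is transparent for result equality in both directions.
assert(reused.sameResult(exchange) && exchange.sameResult(reused))
```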
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
index b9f6601ea8..2dd1dc3da9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -829,15 +829,10 @@ private[execution] case class HashedRelationBroadcastMode(key: Seq[Expression])
   extends BroadcastMode {

   override def transform(rows: Array[InternalRow]): HashedRelation = {
-    HashedRelation(rows.iterator, canonicalizedKey, rows.length)
+    HashedRelation(rows.iterator, canonicalized.key, rows.length)
   }

-  private lazy val canonicalizedKey: Seq[Expression] = {
-    key.map { e => e.canonicalized }
-  }
-
-  override def compatibleWith(other: BroadcastMode): Boolean = other match {
-    case m: HashedRelationBroadcastMode => canonicalizedKey == m.canonicalizedKey
-    case _ => false
+  override lazy val canonicalized: HashedRelationBroadcastMode = {
+    this.copy(key = key.map(_.canonicalized))
   }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
index 36cde3233d..59eaf4d1c2 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/ExchangeSuite.scala
@@ -36,17 +36,17 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
     )
   }

-  test("compatible BroadcastMode") {
+  test("BroadcastMode.canonicalized") {
     val mode1 = IdentityBroadcastMode
     val mode2 = HashedRelationBroadcastMode(Literal(1L) :: Nil)
     val mode3 = HashedRelationBroadcastMode(Literal("s") :: Nil)

-    assert(mode1.compatibleWith(mode1))
-    assert(!mode1.compatibleWith(mode2))
-    assert(!mode2.compatibleWith(mode1))
-    assert(mode2.compatibleWith(mode2))
-    assert(!mode2.compatibleWith(mode3))
-    assert(mode3.compatibleWith(mode3))
+    assert(mode1.canonicalized == mode1.canonicalized)
+    assert(mode1.canonicalized != mode2.canonicalized)
+    assert(mode2.canonicalized != mode1.canonicalized)
+    assert(mode2.canonicalized == mode2.canonicalized)
+    assert(mode2.canonicalized != mode3.canonicalized)
+    assert(mode3.canonicalized == mode3.canonicalized)
   }

   test("BroadcastExchange same result") {
@@ -70,7 +70,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {

     assert(!exchange1.sameResult(exchange2))
     assert(!exchange2.sameResult(exchange3))
-    assert(!exchange3.sameResult(exchange4))
+    assert(exchange3.sameResult(exchange4))
     assert(exchange4 sameResult exchange3)
   }

@@ -98,7 +98,7 @@ class ExchangeSuite extends SparkPlanTest with SharedSQLContext {
     assert(exchange1 sameResult exchange2)
     assert(!exchange2.sameResult(exchange3))
     assert(!exchange3.sameResult(exchange4))
-    assert(!exchange4.sameResult(exchange5))
+    assert(exchange4.sameResult(exchange5))
     assert(exchange5 sameResult exchange4)
   }
 }
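`HashedRelationBroadcastMode.canonicalized` reuses expression-level canonicalization for its keys, which is what the renamed test above exercises. A hedged illustration of why this is more than reference equality: expression canonicalization should also normalize cosmetic variations such as the operand order of commutative operations. (Note that `HashedRelationBroadcastMode` is `private[execution]`, so like the suite this only compiles inside that package.)

```scala
import org.apache.spark.sql.catalyst.expressions.{Add, Literal}
import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode

val m1 = HashedRelationBroadcastMode(Add(Literal(1L), Literal(2L)) :: Nil)
val m2 = HashedRelationBroadcastMode(Add(Literal(2L), Literal(1L)) :: Nil)

// The operand order of the commutative Add is cosmetic, so the canonicalized
// modes should compare equal even though m1 != m2.
assert(m1.canonicalized == m2.canonicalized)
```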
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
index 28f074849c..fab0d7fa84 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/execution/HiveTableScanExec.scala
@@ -72,7 +72,7 @@ case class HiveTableScanExec(

   // Bind all partition key attribute references in the partition pruning predicate for later
   // evaluation.
-  private val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
+  private lazy val boundPruningPred = partitionPruningPred.reduceLeftOption(And).map { pred =>
     require(
       pred.dataType == BooleanType,
       s"Data type of predicate $pred must be BooleanType rather than ${pred.dataType}.")
@@ -80,20 +80,22 @@ case class HiveTableScanExec(
     BindReferences.bindReference(pred, relation.partitionCols)
   }

-  // Create a local copy of hadoopConf,so that scan specific modifications should not impact
-  // other queries
-  @transient private val hadoopConf = sparkSession.sessionState.newHadoopConf()
-
-  @transient private val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta)
-  @transient private val tableDesc = new TableDesc(
+  @transient private lazy val hiveQlTable = HiveClientImpl.toHiveTable(relation.tableMeta)
+  @transient private lazy val tableDesc = new TableDesc(
     hiveQlTable.getInputFormatClass,
     hiveQlTable.getOutputFormatClass,
     hiveQlTable.getMetadata)

-  // append columns ids and names before broadcast
-  addColumnMetadataToConf(hadoopConf)
+  // Create a local copy of hadoopConf, so that scan-specific modifications do not impact
+  // other queries.
+  @transient private lazy val hadoopConf = {
+    val c = sparkSession.sessionState.newHadoopConf()
+    // append columns ids and names before broadcast
+    addColumnMetadataToConf(c)
+    c
+  }

-  @transient private val hadoopReader = new HadoopTableReader(
+  @transient private lazy val hadoopReader = new HadoopTableReader(
     output,
     relation.partitionCols,
     tableDesc,
@@ -104,7 +106,7 @@ case class HiveTableScanExec(
     Cast(Literal(value), dataType).eval(null)
   }

-  private def addColumnMetadataToConf(hiveConf: Configuration) {
+  private def addColumnMetadataToConf(hiveConf: Configuration): Unit = {
     // Specifies needed column IDs for those non-partitioning columns.
     val columnOrdinals = AttributeMap(relation.dataCols.zipWithIndex)
     val neededColumnIDs = output.flatMap(columnOrdinals.get).map(o => o: Integer)
@@ -198,18 +200,13 @@ case class HiveTableScanExec(
     }
   }

-  override def sameResult(plan: SparkPlan): Boolean = plan match {
-    case other: HiveTableScanExec =>
-      val thisPredicates = partitionPruningPred.map(cleanExpression)
-      val otherPredicates = other.partitionPruningPred.map(cleanExpression)
-
-      val result = relation.sameResult(other.relation) &&
-        output.length == other.output.length &&
-        output.zip(other.output)
-          .forall(p => p._1.name == p._2.name && p._1.dataType == p._2.dataType) &&
-        thisPredicates.length == otherPredicates.length &&
-        thisPredicates.zip(otherPredicates).forall(p => p._1.semanticEquals(p._2))
-      result
-    case _ => false
+  override lazy val canonicalized: HiveTableScanExec = {
+    val input: AttributeSeq = relation.output
+    HiveTableScanExec(
+      requestedAttributes.map(normalizeExprId(_, input)),
+      relation.canonicalized.asInstanceOf[CatalogRelation],
+      partitionPruningPred.map(normalizeExprId(_, input)))(sparkSession)
   }
+
+  override def otherCopyArgs: Seq[AnyRef] = Seq(sparkSession)
 }
--
cgit v1.2.3