From 0b8d694999b43ada4833388cad6c285c7757cbf7 Mon Sep 17 00:00:00 2001
From: Josh Rosen
Date: Mon, 6 Jun 2016 11:44:51 -0700
Subject: [SPARK-15764][SQL] Replace N^2 loop in BindReferences

BindReferences contains an n^2 loop which causes performance issues when
operating over large schemas: to determine the ordinal of an attribute
reference, we perform a linear scan over the `input` array. Because input
can sometimes be a `List`, the call to `input(ordinal).nullable` can also
be O(n).

Instead of performing a linear scan, we can convert the input into an array
and build a hash map from expression ids to ordinals. The greater up-front
cost of the map construction is offset by the fact that an expression can
contain multiple attribute references, so the construction cost is
amortized across a number of lookups.

Performance benchmarks to follow.

/cc ericl

Author: Josh Rosen

Closes #13505 from JoshRosen/bind-references-improvement.
---
 .../sql/catalyst/expressions/AttributeMap.scala    |  7 -----
 .../sql/catalyst/expressions/BoundAttribute.scala  |  6 ++--
 .../spark/sql/catalyst/expressions/package.scala   | 34 +++++++++++++++++++++-
 .../spark/sql/catalyst/plans/QueryPlan.scala       |  2 +-
 .../execution/aggregate/HashAggregateExec.scala    |  2 +-
 .../execution/columnar/InMemoryTableScanExec.scala |  4 +--
 6 files changed, 40 insertions(+), 15 deletions(-)

diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
index ef3cc554b7..96a11e352e 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/AttributeMap.scala
@@ -26,13 +26,6 @@ object AttributeMap {
   def apply[A](kvs: Seq[(Attribute, A)]): AttributeMap[A] = {
     new AttributeMap(kvs.map(kv => (kv._1.exprId, kv)).toMap)
   }
-
-  /** Given a schema, constructs an [[AttributeMap]] from [[Attribute]] to ordinal */
-  def byIndex(schema: Seq[Attribute]): AttributeMap[Int] = apply(schema.zipWithIndex)
-
-  /** Given a schema, constructs a map from ordinal to Attribute. */
-  def toIndex(schema: Seq[Attribute]): Map[Int, Attribute] =
-    schema.zipWithIndex.map { case (a, i) => i -> a }.toMap
 }
 
 class AttributeMap[A](baseMap: Map[ExprId, (Attribute, A)])
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
index a38f1ec091..7d16118c9d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/BoundAttribute.scala
@@ -82,16 +82,16 @@ object BindReferences extends Logging {
 
   def bindReference[A <: Expression](
       expression: A,
-      input: Seq[Attribute],
+      input: AttributeSeq,
       allowFailures: Boolean = false): A = {
     expression.transform { case a: AttributeReference =>
       attachTree(a, "Binding attribute") {
-        val ordinal = input.indexWhere(_.exprId == a.exprId)
+        val ordinal = input.indexOf(a.exprId)
         if (ordinal == -1) {
           if (allowFailures) {
             a
           } else {
-            sys.error(s"Couldn't find $a in ${input.mkString("[", ",", "]")}")
+            sys.error(s"Couldn't find $a in ${input.attrs.mkString("[", ",", "]")}")
           }
         } else {
           BoundReference(ordinal, a.dataType, input(ordinal).nullable)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
index 23baa6f783..81f5bb4a65 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/package.scala
@@ -17,6 +17,8 @@
 
 package org.apache.spark.sql.catalyst
 
+import com.google.common.collect.Maps
+
 import org.apache.spark.sql.catalyst.expressions._
 import org.apache.spark.sql.types.{StructField, StructType}
 
@@ -86,11 +88,41 @@ package object expressions {
   /**
    * Helper functions for working with `Seq[Attribute]`.
    */
-  implicit class AttributeSeq(attrs: Seq[Attribute]) {
+  implicit class AttributeSeq(val attrs: Seq[Attribute]) extends Serializable {
     /** Creates a StructType with a schema matching this `Seq[Attribute]`. */
     def toStructType: StructType = {
       StructType(attrs.map(a => StructField(a.name, a.dataType, a.nullable)))
     }
+
+    // It's possible that `attrs` is a linked list, which can lead to bad O(n^2) loops when
+    // accessing attributes by their ordinals. To avoid this performance penalty, convert the input
+    // to an array.
+    @transient private lazy val attrsArray = attrs.toArray
+
+    @transient private lazy val exprIdToOrdinal = {
+      val arr = attrsArray
+      val map = Maps.newHashMapWithExpectedSize[ExprId, Int](arr.length)
+      // Iterate over the array in reverse order so that the final map value is the first attribute
+      // with a given expression id.
+      var index = arr.length - 1
+      while (index >= 0) {
+        map.put(arr(index).exprId, index)
+        index -= 1
+      }
+      map
+    }
+
+    /**
+     * Returns the attribute at the given index.
+     */
+    def apply(ordinal: Int): Attribute = attrsArray(ordinal)
+
+    /**
+     * Returns the index of first attribute with a matching expression id, or -1 if no match exists.
+ */ + def indexOf(exprId: ExprId): Int = { + Option(exprIdToOrdinal.get(exprId)).getOrElse(-1) + } } /** diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala index 3de15a9a3f..19a66cff4f 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala @@ -296,7 +296,7 @@ abstract class QueryPlan[PlanType <: QueryPlan[PlanType]] extends TreeNode[PlanT /** * All the attributes that are used for this plan. */ - lazy val allAttributes: Seq[Attribute] = children.flatMap(_.output) + lazy val allAttributes: AttributeSeq = children.flatMap(_.output) private def cleanExpression(e: Expression): Expression = e match { case a: Alias => diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala index f5bc0628b6..f270ca0755 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/HashAggregateExec.scala @@ -49,7 +49,7 @@ case class HashAggregateExec( require(HashAggregateExec.supportsAggregate(aggregateBufferAttributes)) - override lazy val allAttributes: Seq[Attribute] = + override lazy val allAttributes: AttributeSeq = child.output ++ aggregateBufferAttributes ++ aggregateAttributes ++ aggregateExpressions.flatMap(_.aggregateFunction.inputAggBufferAttributes) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala index bd55e1a875..a1c2f0a8fb 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryTableScanExec.scala @@ -310,7 +310,7 @@ private[sql] case class InMemoryTableScanExec( // within the map Partitions closure. val schema = relation.partitionStatistics.schema val schemaIndex = schema.zipWithIndex - val relOutput = relation.output + val relOutput: AttributeSeq = relation.output val buffers = relation.cachedColumnBuffers buffers.mapPartitionsInternal { cachedBatchIterator => @@ -321,7 +321,7 @@ private[sql] case class InMemoryTableScanExec( // Find the ordinals and data types of the requested columns. val (requestedColumnIndices, requestedColumnDataTypes) = attributes.map { a => - relOutput.indexWhere(_.exprId == a.exprId) -> a.dataType + relOutput.indexOf(a.exprId) -> a.dataType }.unzip // Do partition batch pruning if enabled -- cgit v1.2.3
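
For readers who want to experiment with the core idea outside of Catalyst, the
following is a minimal, self-contained Scala sketch of the lookup strategy the
patch introduces: materialize the attributes into an array so positional access
is O(1) even for a List, and precompute a hash map from expression id to
ordinal so indexOf no longer needs a linear scan. The ExprId and Attribute case
classes and the AttributeSeqDemo object below are simplified stand-ins, not the
Catalyst types, and a plain java.util.HashMap with boxed Integer values stands
in for the Guava Maps.newHashMapWithExpectedSize call used in the real patch.

// Standalone sketch of the exprId -> ordinal lookup strategy from this patch.
case class ExprId(id: Long)
case class Attribute(name: String, exprId: ExprId, nullable: Boolean)

class AttributeSeq(val attrs: Seq[Attribute]) extends Serializable {

  // Materialize the attributes into an array so that positional access is
  // O(1) even when `attrs` is a linked list such as scala.List.
  @transient private lazy val attrsArray: Array[Attribute] = attrs.toArray

  // Build the hash map once. Iterating in reverse means that, when several
  // attributes share an expression id, the surviving value is the ordinal of
  // the *first* matching attribute, mirroring indexWhere semantics.
  @transient private lazy val exprIdToOrdinal: java.util.HashMap[ExprId, Integer] = {
    val map = new java.util.HashMap[ExprId, Integer](attrsArray.length)
    var index = attrsArray.length - 1
    while (index >= 0) {
      map.put(attrsArray(index).exprId, Integer.valueOf(index))
      index -= 1
    }
    map
  }

  /** Returns the attribute at the given ordinal in O(1). */
  def apply(ordinal: Int): Attribute = attrsArray(ordinal)

  /**
   * Returns the ordinal of the first attribute with a matching expression id,
   * or -1 if no match exists, in expected O(1) time.
   */
  def indexOf(exprId: ExprId): Int = {
    val ordinal = exprIdToOrdinal.get(exprId) // null when the key is absent
    if (ordinal eq null) -1 else ordinal.intValue()
  }
}

object AttributeSeqDemo extends App {
  val attrs = Seq(
    Attribute("a", ExprId(1), nullable = true),
    Attribute("b", ExprId(2), nullable = false),
    Attribute("b2", ExprId(2), nullable = true)) // duplicate expression id

  val resolved = new AttributeSeq(attrs)
  assert(resolved.indexOf(ExprId(2)) == 1)  // first match wins
  assert(resolved.indexOf(ExprId(9)) == -1) // missing id reports -1
  assert(resolved(0).name == "a")
}

Using boxed Integer values keeps the "missing key" case an explicit null
check; with a primitive Int value type, a missing key would silently unbox
null to 0, which is a valid ordinal.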
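
The commit message defers performance numbers. As a rough way to observe the
effect locally, here is a hypothetical micro-benchmark harness (LookupBench is
an invented name, and this is not the benchmark referenced above) that reuses
the AttributeSeq sketch: with a 10,000-column schema held in a List, the
indexWhere loop performs one linear scan per lookup, while the map-based
version pays the array/map build cost once. Timings are indicative only.

object LookupBench extends App {
  val n = 10000
  val attrs: List[Attribute] =
    List.tabulate(n)(i => Attribute("c" + i, ExprId(i), nullable = true))
  val resolved = new AttributeSeq(attrs)

  def time(label: String)(body: => Unit): Unit = {
    val start = System.nanoTime()
    body
    println(label + ": " + (System.nanoTime() - start) / 1e6 + " ms")
  }

  // O(n^2) overall: one linear scan over the List per lookup.
  time("indexWhere (linear scan per lookup)") {
    var i = 0
    while (i < n) {
      attrs.indexWhere(_.exprId == ExprId(i))
      i += 1
    }
  }

  // O(n) overall: the array and map are built lazily on first use,
  // then each lookup is an expected O(1) hash probe.
  time("hash map (build once, constant-time lookups)") {
    var i = 0
    while (i < n) {
      resolved.indexOf(ExprId(i))
      i += 1
    }
  }
}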