author     Cheng Hao <hao.cheng@intel.com>            2014-08-29 15:32:26 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-08-29 15:32:26 -0700
commit     dc4d577c6549df58f42c0e22cac354554d169896 (patch)
tree       ca5a282f18607538ea6dd880e472b2dcdc33c93b /sql
parent     287c0ac7722dd4bc51b921ccc6f0e3c1625b5ff4 (diff)
[SPARK-3198] [SQL] Remove the TreeNode.id
The `id` property of the TreeNode API provides a fast way to compare two TreeNodes, but it is a performance bottleneck during expression object creation in a multi-threaded environment, because every instantiation goes through the memory barrier of the shared atomic counter. Fortunately, tree node comparison only happens once in master, so even with `id` removed, overall performance is unaffected.

Author: Cheng Hao <hao.cheng@intel.com>

Closes #2155 from chenghao-intel/treenode and squashes the following commits:

7cf2cd2 [Cheng Hao] Remove the implicit keyword for TreeNodeRef and some other small issues
5873415 [Cheng Hao] Remove the TreeNode.id
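To illustrate the pattern this patch adopts, here is a minimal, self-contained Scala sketch (the `Node` case class and the demo object are hypothetical stand-ins, not Catalyst classes). Keying a map by structural equality conflates distinct but identical subtrees; wrapping each node in a reference-identity key, in the spirit of the new TreeNodeRef, keeps distinct instances distinct without touching any shared counter:

    // Hypothetical stand-in for a Catalyst TreeNode: as a case class it gets
    // structural equals/hashCode, so two separately built Node(5) compare equal.
    case class Node(value: Int)

    // Reference-identity wrapper mirroring the TreeNodeRef added by this patch:
    // equality is `eq` (same JVM object), so hash-based collections keyed by
    // NodeRef distinguish instances that are structurally identical.
    class NodeRef(val obj: Node) {
      override def equals(o: Any): Boolean = o match {
        case that: NodeRef => that.obj.eq(obj)
        case _             => false
      }
      // Delegating to the node's own hashCode is consistent with `eq`-based
      // equality: the same instance always hashes to the same value.
      override def hashCode: Int = if (obj == null) 0 else obj.hashCode
    }

    object NodeRefDemo extends App {
      val a = Node(5)
      val b = Node(5) // structurally equal to `a`, but a distinct instance

      val byValue = Map(a -> "partial result for a")
      println(byValue.contains(b)) // true: structural keys conflate a and b

      val byRef = Map(new NodeRef(a) -> "partial result for a")
      println(byRef.contains(new NodeRef(b))) // false: distinct instance
      println(byRef.contains(new NodeRef(a))) // true: same instance
    }

This is why call sites such as PartialAggregation and GeneratedAggregate below switch their map keys from `Long` ids to `new TreeNodeRef(...)`: lookups remain per-instance, but constructing an expression no longer serializes on a global AtomicLong.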
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala   | 11
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala     | 12
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala      | 24
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala       | 11
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala     | 10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala          |  7
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala             |  2
8 files changed, 40 insertions(+), 42 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
index 90923fe31a..f0fd9a8b9a 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/planning/patterns.scala
@@ -19,8 +19,9 @@ package org.apache.spark.sql.catalyst.planning
import scala.annotation.tailrec
-import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.Logging
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.logical._
@@ -134,8 +135,8 @@ object PartialAggregation {
// Only do partial aggregation if supported by all aggregate expressions.
if (allAggregates.size == partialAggregates.size) {
// Create a map of expressions to their partial evaluations for all aggregate expressions.
- val partialEvaluations: Map[Long, SplitEvaluation] =
- partialAggregates.map(a => (a.id, a.asPartial)).toMap
+ val partialEvaluations: Map[TreeNodeRef, SplitEvaluation] =
+ partialAggregates.map(a => (new TreeNodeRef(a), a.asPartial)).toMap
// We need to pass all grouping expressions through so the grouping can happen a second
// time. However some of them might be unnamed so we alias them allowing them to be
@@ -148,8 +149,8 @@ object PartialAggregation {
// Replace aggregations with a new expression that computes the result from the already
// computed partial evaluations and grouping values.
val rewrittenAggregateExpressions = aggregateExpressions.map(_.transformUp {
- case e: Expression if partialEvaluations.contains(e.id) =>
- partialEvaluations(e.id).finalEvaluation
+ case e: Expression if partialEvaluations.contains(new TreeNodeRef(e)) =>
+ partialEvaluations(new TreeNodeRef(e)).finalEvaluation
case e: Expression if namedGroupingExpressions.contains(e) =>
namedGroupingExpressions(e).toAttribute
}).asInstanceOf[Seq[NamedExpression]]
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
index 1e177e28f8..af9e4d86e9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/plans/QueryPlan.scala
@@ -50,11 +50,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
@inline def transformExpressionDown(e: Expression) = {
val newE = e.transformDown(rule)
- if (newE.id != e.id && newE != e) {
+ if (newE.fastEquals(e)) {
+ e
+ } else {
changed = true
newE
- } else {
- e
}
}
@@ -82,11 +82,11 @@ abstract class QueryPlan[PlanType <: TreeNode[PlanType]] extends TreeNode[PlanTy
@inline def transformExpressionUp(e: Expression) = {
val newE = e.transformUp(rule)
- if (newE.id != e.id && newE != e) {
+ if (newE.fastEquals(e)) {
+ e
+ } else {
changed = true
newE
- } else {
- e
}
}
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
index 96ce35939e..2013ae4f7b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/TreeNode.scala
@@ -19,11 +19,6 @@ package org.apache.spark.sql.catalyst.trees
import org.apache.spark.sql.catalyst.errors._
-object TreeNode {
- private val currentId = new java.util.concurrent.atomic.AtomicLong
- protected def nextId() = currentId.getAndIncrement()
-}
-
/** Used by [[TreeNode.getNodeNumbered]] when traversing the tree for a given number */
private class MutableInt(var i: Int)
@@ -34,28 +29,12 @@ abstract class TreeNode[BaseType <: TreeNode[BaseType]] {
def children: Seq[BaseType]
/**
- * A globally unique id for this specific instance. Not preserved across copies.
- * Unlike `equals`, `id` can be used to differentiate distinct but structurally
- * identical branches of a tree.
- */
- val id = TreeNode.nextId()
-
- /**
- * Returns true if other is the same [[catalyst.trees.TreeNode TreeNode]] instance. Unlike
- * `equals` this function will return false for different instances of structurally identical
- * trees.
- */
- def sameInstance(other: TreeNode[_]): Boolean = {
- this.id == other.id
- }
-
- /**
* Faster version of equality which short-circuits when two treeNodes are the same instance.
* We don't just override Object.Equals, as doing so prevents the scala compiler from
* generating case class `equals` methods
*/
def fastEquals(other: TreeNode[_]): Boolean = {
- sameInstance(other) || this == other
+ this.eq(other) || this == other
}
/**
@@ -393,3 +372,4 @@ trait UnaryNode[BaseType <: TreeNode[BaseType]] {
def child: BaseType
def children = child :: Nil
}
+
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
index d725a92c06..79a8e06d4b 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/trees/package.scala
@@ -37,4 +37,15 @@ package object trees extends Logging {
// Since we want tree nodes to be lightweight, we create one logger for all treenode instances.
protected override def logName = "catalyst.trees"
+ /**
+ * A wrapper around [[TreeNode]] that provides reference equality for use in hash-based collections.
+ */
+ class TreeNodeRef(val obj: TreeNode[_]) {
+ override def equals(o: Any) = o match {
+ case that: TreeNodeRef => that.obj.eq(obj)
+ case _ => false
+ }
+
+ override def hashCode = if (obj == null) 0 else obj.hashCode
+ }
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
index 296202543e..036fd3fa1d 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/trees/TreeNodeSuite.scala
@@ -51,7 +51,10 @@ class TreeNodeSuite extends FunSuite {
val after = before transform { case Literal(5, _) => Literal(1)}
assert(before === after)
- assert(before.map(_.id) === after.map(_.id))
+ // Ensure that the objects after the transformation are the same objects as before it.
+ before.map(identity[Expression]).zip(after.map(identity[Expression])).foreach {
+ case (b, a) => assert(b eq a)
+ }
}
test("collect") {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
index 31ad5e8aab..b3edd5020f 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/GeneratedAggregate.scala
@@ -19,6 +19,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.sql.SQLContext
+import org.apache.spark.sql.catalyst.trees._
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.catalyst.types._
@@ -141,9 +142,10 @@ case class GeneratedAggregate(
val computationSchema = computeFunctions.flatMap(_.schema)
- val resultMap: Map[Long, Expression] = aggregatesToCompute.zip(computeFunctions).map {
- case (agg, func) => agg.id -> func.result
- }.toMap
+ val resultMap: Map[TreeNodeRef, Expression] =
+ aggregatesToCompute.zip(computeFunctions).map {
+ case (agg, func) => new TreeNodeRef(agg) -> func.result
+ }.toMap
val namedGroups = groupingExpressions.zipWithIndex.map {
case (ne: NamedExpression, _) => (ne, ne)
@@ -156,7 +158,7 @@ case class GeneratedAggregate(
// The set of expressions that produce the final output given the aggregation buffer and the
// grouping expressions.
val resultExpressions = aggregateExpressions.map(_.transform {
- case e: Expression if resultMap.contains(e.id) => resultMap(e.id)
+ case e: Expression if resultMap.contains(new TreeNodeRef(e)) => resultMap(new TreeNodeRef(e))
case e: Expression if groupMap.contains(e) => groupMap(e)
})
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
index 5b896c55b7..8ff757bbe3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/debug/package.scala
@@ -23,6 +23,7 @@ import org.apache.spark.{AccumulatorParam, Accumulator, SparkContext}
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.SparkContext._
import org.apache.spark.sql.{SchemaRDD, Row}
+import org.apache.spark.sql.catalyst.trees.TreeNodeRef
/**
* :: DeveloperApi ::
@@ -43,10 +44,10 @@ package object debug {
implicit class DebugQuery(query: SchemaRDD) {
def debug(): Unit = {
val plan = query.queryExecution.executedPlan
- val visited = new collection.mutable.HashSet[Long]()
+ val visited = new collection.mutable.HashSet[TreeNodeRef]()
val debugPlan = plan transform {
- case s: SparkPlan if !visited.contains(s.id) =>
- visited += s.id
+ case s: SparkPlan if !visited.contains(new TreeNodeRef(s)) =>
+ visited += new TreeNodeRef(s)
DebugNode(s)
}
println(s"Results returned: ${debugPlan.execute().count()}")
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
index aef6ebf86b..3dc8be2456 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/pythonUdfs.scala
@@ -98,7 +98,7 @@ private[spark] object ExtractPythonUdfs extends Rule[LogicalPlan] {
logical.Project(
l.output,
l.transformExpressions {
- case p: PythonUDF if p.id == udf.id => evaluation.resultAttribute
+ case p: PythonUDF if p.fastEquals(udf) => evaluation.resultAttribute
}.withNewChildren(newChildren))
}
}