author     Yin Huai <yhuai@databricks.com>    2015-10-08 11:56:44 -0700
committer  Andrew Or <andrew@databricks.com>  2015-10-08 11:56:44 -0700
commit     82d275f27c3e9211ce69c5c8685a0fe90c0be26f (patch)
tree       a51a68a14568e2022ada5efd9a568687b4be3c73 /sql
parent     2a6f614cd6ffb0cc32460018cb13dad2fd94520f (diff)
[SPARK-10887] [SQL] Build HashedRelation outside of HashJoinNode.
This PR refactors `HashJoinNode` to take an existing `HashedRelation`, so the node can be reused for both `ShuffledHashJoin` and `BroadcastHashJoin`. https://issues.apache.org/jira/browse/SPARK-10887 Author: Yin Huai <yhuai@databricks.com> Closes #8953 from yhuai/SPARK-10887.
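To make the new division of labor concrete, here is a minimal sketch of how the two concrete nodes introduced in this patch are constructed (the class names and parameter lists come from the diff below; surrounding values such as `conf`, `leftNode`, `rightNode`, and `broadcastRelation` are hypothetical):

  // Shuffled case: BinaryHashJoinNode builds the HashedRelation itself in doOpen().
  val shuffledJoin = BinaryHashJoinNode(
    conf, Seq('id1), Seq('id2), BuildRight, leftNode, rightNode)

  // Broadcast case: the HashedRelation is built elsewhere, broadcast, and passed in.
  val broadcastJoin = BroadcastHashJoinNode(
    conf, streamedKeys, streamedNode, BuildRight, buildOutput, broadcastRelation)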
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala |  2
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala | 76
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala | 59
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala | 67
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala | 85
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala | 20
7 files changed, 262 insertions(+), 51 deletions(-)
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
index d82d19185b..e8ee64756d 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala
@@ -27,6 +27,8 @@ abstract class BaseMutableProjection extends MutableProjection
/**
* Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
* input [[InternalRow]] for a fixed set of [[Expression Expressions]].
+ * It exposes a `target` method, which is used to set the row that will be updated.
+ * The [[MutableRow]] object created internally is used only when `target` is not used.
*/
object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] {
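As a rough illustration of the `target` contract documented above (a sketch only, assuming the MutableProjection API of this codebase; `expressions` and `inputRow` are hypothetical and must already be resolved and bound):

  val projection: MutableProjection = GenerateMutableProjection.generate(expressions)()
  val reusedRow = new SpecificMutableRow(expressions.map(_.dataType))
  projection.target(reusedRow)    // from now on, apply() updates reusedRow in place
  val out = projection(inputRow)  // out is the same object as reusedRow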
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
index ea09e029da..9873630937 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateSafeProjection.scala
@@ -23,8 +23,8 @@ import org.apache.spark.sql.types._
/**
- * Generates byte code that produces a [[MutableRow]] object that can update itself based on a new
- * input [[InternalRow]] for a fixed set of [[Expression Expressions]].
+ * Generates byte code that produces a [[MutableRow]] object (not an [[UnsafeRow]]) that can update
+ * itself based on a new input [[InternalRow]] for a fixed set of [[Expression Expressions]].
*/
object GenerateSafeProjection extends CodeGenerator[Seq[Expression], Projection] {
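A small sketch of the distinction the revised comment draws (hypothetical inputs): the projection produced here always yields a generic, heap-backed row, never an [[UnsafeRow]].

  val toSafe = GenerateSafeProjection.generate(expressions)
  val safeRow = toSafe(someUnsafeInputRow)
  assert(!safeRow.isInstanceOf[UnsafeRow])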
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala
new file mode 100644
index 0000000000..52dcb9e43c
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BinaryHashJoinNode.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}
+
+/**
+ * A [[HashJoinNode]] that builds the [[HashedRelation]] according to the value of
+ * `buildSide`. The actual work of this node is defined in [[HashJoinNode]].
+ */
+case class BinaryHashJoinNode(
+ conf: SQLConf,
+ leftKeys: Seq[Expression],
+ rightKeys: Seq[Expression],
+ buildSide: BuildSide,
+ left: LocalNode,
+ right: LocalNode)
+ extends BinaryLocalNode(conf) with HashJoinNode {
+
+ protected override val (streamedNode, streamedKeys) = buildSide match {
+ case BuildLeft => (right, rightKeys)
+ case BuildRight => (left, leftKeys)
+ }
+
+ private val (buildNode, buildKeys) = buildSide match {
+ case BuildLeft => (left, leftKeys)
+ case BuildRight => (right, rightKeys)
+ }
+
+ override def output: Seq[Attribute] = left.output ++ right.output
+
+ private def buildSideKeyGenerator: Projection = {
+ // We expect the data types of buildKeys and streamedKeys to be the same.
+ assert(buildKeys.map(_.dataType) == streamedKeys.map(_.dataType))
+ if (isUnsafeMode) {
+ UnsafeProjection.create(buildKeys, buildNode.output)
+ } else {
+ newMutableProjection(buildKeys, buildNode.output)()
+ }
+ }
+
+ protected override def doOpen(): Unit = {
+ buildNode.open()
+ val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
+ // The HashedRelation has been built, so buildNode can be closed now.
+ buildNode.close()
+
+ streamedNode.open()
+ // Set the HashedRelation used by the HashJoinNode.
+ withHashedRelation(hashedRelation)
+ }
+
+ override def close(): Unit = {
+ // Note that we do not need to close buildNode here because it was already
+ // closed in doOpen(), right after the HashedRelation was built.
+ streamedNode.close()
+ }
+}
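The resulting open/close lifecycle, sketched against the LocalNode iterator API used by this package (the driver loop is hypothetical):

  val join = BinaryHashJoinNode(conf, Seq('id1), Seq('id2), BuildRight, leftNode, rightNode)
  join.prepare()
  join.open()             // doOpen(): opens buildNode, builds the HashedRelation,
                          // closes buildNode, then opens streamedNode
  while (join.next()) {
    val row = join.fetch()  // joined rows, probed against the HashedRelation
  }
  join.close()            // closes streamedNode only; buildNode is already closed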
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala
new file mode 100644
index 0000000000..cd1c86516e
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/BroadcastHashJoinNode.scala
@@ -0,0 +1,59 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.local
+
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.SQLConf
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide, HashedRelation}
+
+/**
+ * A [[HashJoinNode]] for broadcast join. It takes a streamedNode and a broadcast
+ * [[HashedRelation]]. The actual work of this node is defined in [[HashJoinNode]].
+ */
+case class BroadcastHashJoinNode(
+ conf: SQLConf,
+ streamedKeys: Seq[Expression],
+ streamedNode: LocalNode,
+ buildSide: BuildSide,
+ buildOutput: Seq[Attribute],
+ hashedRelation: Broadcast[HashedRelation])
+ extends UnaryLocalNode(conf) with HashJoinNode {
+
+ override val child = streamedNode
+
+ // Because the buildNode itself is not passed in, we use buildOutput (the output
+ // of the buildNode) to create the inputSet properly.
+ override def inputSet: AttributeSet = AttributeSet(child.output ++ buildOutput)
+
+ override def output: Seq[Attribute] = buildSide match {
+ case BuildRight => streamedNode.output ++ buildOutput
+ case BuildLeft => buildOutput ++ streamedNode.output
+ }
+
+ protected override def doOpen(): Unit = {
+ streamedNode.open()
+ // Set the HashedRelation used by the HashJoinNode.
+ withHashedRelation(hashedRelation.value)
+ }
+
+ override def close(): Unit = {
+ streamedNode.close()
+ }
+}
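Hypothetical wiring for the broadcast case, showing why this node needs only the build side's output attributes plus a Broadcast[HashedRelation] (the relation is built once, e.g. on the driver, mirroring how doOpen builds it in BinaryHashJoinNode):

  val relation: HashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
  val broadcastRelation = sparkContext.broadcast(relation)
  val join = BroadcastHashJoinNode(
    conf, streamedKeys, streamedNode, BuildRight, buildNode.output, broadcastRelation)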
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
index e7b24e3fca..b1dc719ca8 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/local/HashJoinNode.scala
@@ -17,27 +17,23 @@
package org.apache.spark.sql.execution.local
-import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.joins._
-import org.apache.spark.sql.execution.metric.SQLMetrics
/**
+ * An abstract node for sharing common functionality among different implementations of
+ * inner hash equi-join, notably [[BinaryHashJoinNode]] and [[BroadcastHashJoinNode]].
+ *
* Much of this code is similar to [[org.apache.spark.sql.execution.joins.HashJoin]].
*/
-case class HashJoinNode(
- conf: SQLConf,
- leftKeys: Seq[Expression],
- rightKeys: Seq[Expression],
- buildSide: BuildSide,
- left: LocalNode,
- right: LocalNode) extends BinaryLocalNode(conf) {
-
- private[this] lazy val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
- case BuildLeft => (left, leftKeys, right, rightKeys)
- case BuildRight => (right, rightKeys, left, leftKeys)
- }
+trait HashJoinNode {
+
+ self: LocalNode =>
+
+ protected def streamedKeys: Seq[Expression]
+ protected def streamedNode: LocalNode
+ protected def buildSide: BuildSide
private[this] var currentStreamedRow: InternalRow = _
private[this] var currentHashMatches: Seq[InternalRow] = _
@@ -49,23 +45,14 @@ case class HashJoinNode(
private[this] var hashed: HashedRelation = _
private[this] var joinKeys: Projection = _
- override def output: Seq[Attribute] = left.output ++ right.output
-
- private[this] def isUnsafeMode: Boolean = {
- (codegenEnabled && unsafeEnabled
- && UnsafeProjection.canSupport(buildKeys)
- && UnsafeProjection.canSupport(schema))
- }
-
- private[this] def buildSideKeyGenerator: Projection = {
- if (isUnsafeMode) {
- UnsafeProjection.create(buildKeys, buildNode.output)
- } else {
- newMutableProjection(buildKeys, buildNode.output)()
- }
+ protected def isUnsafeMode: Boolean = {
+ (codegenEnabled &&
+ unsafeEnabled &&
+ UnsafeProjection.canSupport(schema) &&
+ UnsafeProjection.canSupport(streamedKeys))
}
- private[this] def streamSideKeyGenerator: Projection = {
+ private def streamSideKeyGenerator: Projection = {
if (isUnsafeMode) {
UnsafeProjection.create(streamedKeys, streamedNode.output)
} else {
@@ -73,10 +60,21 @@ case class HashJoinNode(
}
}
+ /**
+ * Sets the HashedRelation used by this node. This method needs to be called
+ * before the first `next` gets called.
+ */
+ protected def withHashedRelation(hashedRelation: HashedRelation): Unit = {
+ hashed = hashedRelation
+ }
+
+ /**
+ * Custom open logic that concrete subclasses must provide; called by `open()`.
+ */
+ protected def doOpen(): Unit
+
override def open(): Unit = {
- buildNode.open()
- hashed = HashedRelation(buildNode, buildSideKeyGenerator)
- streamedNode.open()
+ doOpen()
joinRow = new JoinedRow
resultProjection = {
if (isUnsafeMode) {
@@ -128,9 +126,4 @@ case class HashJoinNode(
}
resultProjection(ret)
}
-
- override def close(): Unit = {
- left.close()
- right.close()
- }
}
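The contract the trait now leaves to implementations can be summarized with a hypothetical minimal subclass (essentially BroadcastHashJoinNode with the broadcast stripped away; the class name is illustrative, not part of this patch):

  case class PrebuiltHashJoinNode(
      conf: SQLConf,
      streamedKeys: Seq[Expression],
      streamedNode: LocalNode,
      buildSide: BuildSide,
      buildOutput: Seq[Attribute],
      prebuilt: HashedRelation)
    extends UnaryLocalNode(conf) with HashJoinNode {

    override val child = streamedNode

    // Build-side attributes are not produced by any child, so widen inputSet.
    override def inputSet: AttributeSet = AttributeSet(child.output ++ buildOutput)

    override def output: Seq[Attribute] = buildSide match {
      case BuildRight => streamedNode.output ++ buildOutput
      case BuildLeft => buildOutput ++ streamedNode.output
    }

    protected override def doOpen(): Unit = {
      streamedNode.open()
      withHashedRelation(prebuilt)  // must run before the first next()
    }

    override def close(): Unit = streamedNode.close()
  }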
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
index 5c1bdb088e..8c2e78b2a9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/HashJoinNodeSuite.scala
@@ -17,10 +17,13 @@
package org.apache.spark.sql.execution.local
+import org.mockito.Mockito.{mock, when}
+
+import org.apache.spark.broadcast.TorrentBroadcast
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.dsl.expressions._
-import org.apache.spark.sql.execution.joins.{BuildLeft, BuildRight, BuildSide}
-
+import org.apache.spark.sql.catalyst.expressions.{InterpretedMutableProjection, UnsafeProjection, Expression}
+import org.apache.spark.sql.execution.joins.{HashedRelation, BuildLeft, BuildRight, BuildSide}
class HashJoinNodeSuite extends LocalNodeTest {
@@ -34,6 +37,35 @@ class HashJoinNodeSuite extends LocalNodeTest {
}
/**
+ * Builds a [[HashedRelation]] based on resolved `buildKeys`
+ * and a resolved `buildNode`.
+ */
+ private def buildHashedRelation(
+ conf: SQLConf,
+ buildKeys: Seq[Expression],
+ buildNode: LocalNode): HashedRelation = {
+
+ val isUnsafeMode =
+ conf.codegenEnabled &&
+ conf.unsafeEnabled &&
+ UnsafeProjection.canSupport(buildKeys)
+
+ val buildSideKeyGenerator =
+ if (isUnsafeMode) {
+ UnsafeProjection.create(buildKeys, buildNode.output)
+ } else {
+ new InterpretedMutableProjection(buildKeys, buildNode.output)
+ }
+
+ buildNode.prepare()
+ buildNode.open()
+ val hashedRelation = HashedRelation(buildNode, buildSideKeyGenerator)
+ buildNode.close()
+
+ hashedRelation
+ }
+
+ /**
* Test inner hash join with varying degrees of matches.
*/
private def testJoin(
@@ -51,20 +83,51 @@ class HashJoinNodeSuite extends LocalNodeTest {
val rightInputMap = rightInput.toMap
val leftNode = new DummyNode(joinNameAttributes, leftInput)
val rightNode = new DummyNode(joinNicknameAttributes, rightInput)
- val makeNode = (node1: LocalNode, node2: LocalNode) => {
- resolveExpressions(new HashJoinNode(
- conf, Seq('id1), Seq('id2), buildSide, node1, node2))
+ val makeBinaryHashJoinNode = (node1: LocalNode, node2: LocalNode) => {
+ val binaryHashJoinNode =
+ BinaryHashJoinNode(conf, Seq('id1), Seq('id2), buildSide, node1, node2)
+ resolveExpressions(binaryHashJoinNode)
+ }
+ val makeBroadcastJoinNode = (node1: LocalNode, node2: LocalNode) => {
+ val leftKeys = Seq('id1.attr)
+ val rightKeys = Seq('id2.attr)
+ // Figure out the build side and stream side.
+ val (buildNode, buildKeys, streamedNode, streamedKeys) = buildSide match {
+ case BuildLeft => (node1, leftKeys, node2, rightKeys)
+ case BuildRight => (node2, rightKeys, node1, leftKeys)
+ }
+ // Resolve the expressions of the build side and then create a HashedRelation.
+ val resolvedBuildNode = resolveExpressions(buildNode)
+ val resolvedBuildKeys = resolveExpressions(buildKeys, resolvedBuildNode)
+ val hashedRelation = buildHashedRelation(conf, resolvedBuildKeys, resolvedBuildNode)
+ val broadcastHashedRelation = mock(classOf[TorrentBroadcast[HashedRelation]])
+ when(broadcastHashedRelation.value).thenReturn(hashedRelation)
+
+ val hashJoinNode =
+ BroadcastHashJoinNode(
+ conf,
+ streamedKeys,
+ streamedNode,
+ buildSide,
+ resolvedBuildNode.output,
+ broadcastHashedRelation)
+ resolveExpressions(hashJoinNode)
}
- val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
- val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+
val expectedOutput = leftInput
.filter { case (k, _) => rightInputMap.contains(k) }
.map { case (k, v) => (k, v, k, rightInputMap(k)) }
- val actualOutput = hashJoinNode.collect().map { row =>
- // (id, name, id, nickname)
- (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+
+ Seq(makeBinaryHashJoinNode, makeBroadcastJoinNode).foreach { makeNode =>
+ val makeUnsafeNode = if (unsafeAndCodegen) wrapForUnsafe(makeNode) else makeNode
+ val hashJoinNode = makeUnsafeNode(leftNode, rightNode)
+
+ val actualOutput = hashJoinNode.collect().map { row =>
+ // (id, name, id, nickname)
+ (row.getInt(0), row.getString(1), row.getInt(2), row.getString(3))
+ }
+ assert(actualOutput === expectedOutput)
}
- assert(actualOutput === expectedOutput)
}
test(s"$testNamePrefix: empty") {
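Aside: the mock in the new test code above sidesteps Spark's broadcast machinery entirely. Condensed, the pattern is just (variable names hypothetical):

  import org.mockito.Mockito.{mock, when}
  val fakeBroadcast = mock(classOf[TorrentBroadcast[HashedRelation]])
  when(fakeBroadcast.value).thenReturn(hashedRelation)
  // Any consumer that only reads `.value` now sees the prebuilt relation.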
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
index 098050bcd2..615c417093 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/local/LocalNodeTest.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.execution.local
import org.apache.spark.SparkFunSuite
import org.apache.spark.sql.SQLConf
import org.apache.spark.sql.catalyst.analysis.UnresolvedAttribute
-import org.apache.spark.sql.catalyst.expressions.AttributeReference
+import org.apache.spark.sql.catalyst.expressions.{Expression, AttributeReference}
import org.apache.spark.sql.types.{IntegerType, StringType}
@@ -67,4 +67,22 @@ class LocalNodeTest extends SparkFunSuite {
}
}
+ /**
+ * Resolve all expressions in `expressions` based on the `output` of `localNode`.
+ * It assumes that all expressions in `localNode` are already resolved.
+ */
+ protected def resolveExpressions(
+ expressions: Seq[Expression],
+ localNode: LocalNode): Seq[Expression] = {
+ require(localNode.expressions.forall(_.resolved))
+ val inputMap = localNode.output.map { a => (a.name, a) }.toMap
+ expressions.map { expression =>
+ expression.transformUp {
+ case UnresolvedAttribute(Seq(u)) =>
+ inputMap.getOrElse(u,
+ sys.error(s"Invalid Test: Cannot resolve $u given input $inputMap"))
+ }
+ }
+ }
+
}
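For reference, the new helper maps each UnresolvedAttribute to the matching AttributeReference in `localNode.output` by name, failing fast on a miss. A one-line usage sketch (values hypothetical, mirroring HashJoinNodeSuite above):

  val resolvedKeys = resolveExpressions(Seq('id2.attr), resolvedBuildNode)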