author     Reynold Xin <rxin@apache.org>              2014-10-13 11:50:42 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-10-13 11:50:42 -0700
commit     39ccabacf11abdd9afc8f9895084c6707ff35c85 (patch)
tree       d4d7b047111879f874605e04994f8c7a4196c892 /sql
parent     942847fd94c920f7954ddf01f97263926e512b0e (diff)
[SPARK-3861][SQL] Avoid rebuilding hash tables for broadcast joins on each partition
Author: Reynold Xin <rxin@apache.org>

Closes #2727 from rxin/SPARK-3861-broadcast-hash-2 and squashes the following commits:

9c7b1a2 [Reynold Xin] Revert "Reuse CompactBuffer in UniqueKeyHashedRelation."
97626a1 [Reynold Xin] Reuse CompactBuffer in UniqueKeyHashedRelation.
7fcffb5 [Reynold Xin] Make UniqueKeyHashedRelation private[joins].
18eb214 [Reynold Xin] Merge branch 'SPARK-3861-broadcast-hash' into SPARK-3861-broadcast-hash-1
4b9d0c9 [Reynold Xin] UniqueKeyHashedRelation.get should return null if the value is null.
e0ebdd1 [Reynold Xin] Added a test case.
90b58c0 [Reynold Xin] [SPARK-3861] Avoid rebuilding hash tables on each partition
0c0082b [Reynold Xin] Fix line length.
cbc664c [Reynold Xin] Rename join -> joins package.
a070d44 [Reynold Xin] Fix line length in HashJoin
a39be8c [Reynold Xin] [SPARK-3857] Create a join package for various join operators.
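The gist of the change: before this patch, every task of a broadcast hash join received the raw build-side rows and rebuilt the hash table from scratch; now the table is built once on the driver and the prebuilt HashedRelation itself is broadcast. A minimal sketch of the idea in plain Scala (buildTable, buildRows, and partitions are illustrative stand-ins, not Spark API):

    // Group build-side rows by key; stands in for HashedRelation.apply.
    def buildTable(rows: Seq[Int], keyOf: Int => Int): Map[Int, Seq[Int]] =
      rows.groupBy(keyOf)

    val buildRows  = Seq(1, 2, 3, 4)
    val partitions = Seq(Seq(1, 2), Seq(3, 4))   // streamed-side partitions

    // Before SPARK-3861: each partition rebuilt the table from the broadcast rows.
    val perPartition = partitions.map { p =>
      val table = buildTable(buildRows, _ % 2)   // repeated work in every partition
      p.filter(r => table.contains(r % 2))
    }

    // After: build once, broadcast the finished table, only probe in each partition.
    val shared = buildTable(buildRows, _ % 2)
    val once = partitions.map(p => p.filter(r => shared.contains(r % 2)))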
Diffstat (limited to 'sql')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala   |   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala            |  34
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala      | 109
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala    |   5
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala |  63
5 files changed, 187 insertions(+), 32 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
index d88ab6367a..8fd35880ee 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala
@@ -22,7 +22,7 @@ import scala.concurrent.duration._
import scala.concurrent.ExecutionContext.Implicits.global
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.catalyst.expressions.Expression
+import org.apache.spark.sql.catalyst.expressions.{Row, Expression}
import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution}
import org.apache.spark.sql.execution.{BinaryNode, SparkPlan}
@@ -49,14 +49,16 @@ case class BroadcastHashJoin(
@transient
private val broadcastFuture = future {
- sparkContext.broadcast(buildPlan.executeCollect())
+ val input: Array[Row] = buildPlan.executeCollect()
+ val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length)
+ sparkContext.broadcast(hashed)
}
override def execute() = {
val broadcastRelation = Await.result(broadcastFuture, 5.minute)
streamedPlan.execute().mapPartitions { streamedIter =>
- joinIterators(broadcastRelation.value.iterator, streamedIter)
+ hashJoin(streamedIter, broadcastRelation.value)
}
}
}
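The broadcastFuture above also makes the broadcast asynchronous: the build side is collected, hashed, and broadcast eagerly, and execute() only blocks when the result is needed. A simplified sketch of that pattern (the stand-in buildHashedSide replaces the real collect-and-hash-and-broadcast step):

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration._
    import scala.concurrent.ExecutionContext.Implicits.global

    // Stand-in for "collect the build side and hash it".
    def buildHashedSide(): Map[Int, Seq[Int]] = Map(1 -> Seq(1), 2 -> Seq(2))

    // Kick off the expensive work as soon as the operator is constructed...
    val broadcastFuture: Future[Map[Int, Seq[Int]]] = Future(buildHashedSide())

    // ...and block only when the join actually executes, as execute() does above.
    val hashed = Await.result(broadcastFuture, 5.minutes)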
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
index 472b2e6ca6..4012d757d5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution.joins
-import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row}
+import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.execution.SparkPlan
import org.apache.spark.util.collection.CompactBuffer
@@ -43,34 +43,14 @@ trait HashJoin {
override def output = left.output ++ right.output
- @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output)
- @transient protected lazy val streamSideKeyGenerator =
+ @transient protected lazy val buildSideKeyGenerator: Projection =
+ newProjection(buildKeys, buildPlan.output)
+
+ @transient protected lazy val streamSideKeyGenerator: () => MutableProjection =
newMutableProjection(streamedKeys, streamedPlan.output)
- protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] =
+ protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] =
{
- // TODO: Use Spark's HashMap implementation.
-
- val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]()
- var currentRow: Row = null
-
- // Create a mapping of buildKeys -> rows
- while (buildIter.hasNext) {
- currentRow = buildIter.next()
- val rowKey = buildSideKeyGenerator(currentRow)
- if (!rowKey.anyNull) {
- val existingMatchList = hashTable.get(rowKey)
- val matchList = if (existingMatchList == null) {
- val newMatchList = new CompactBuffer[Row]()
- hashTable.put(rowKey, newMatchList)
- newMatchList
- } else {
- existingMatchList
- }
- matchList += currentRow.copy()
- }
- }
-
new Iterator[Row] {
private[this] var currentStreamedRow: Row = _
private[this] var currentHashMatches: CompactBuffer[Row] = _
@@ -107,7 +87,7 @@ trait HashJoin {
while (currentHashMatches == null && streamIter.hasNext) {
currentStreamedRow = streamIter.next()
if (!joinKeys(currentStreamedRow).anyNull) {
- currentHashMatches = hashTable.get(joinKeys.currentValue)
+ currentHashMatches = hashedRelation.get(joinKeys.currentValue)
}
}
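With table construction moved out of the trait, hashJoin is reduced to a pure probe loop: for each streamed row with a non-null key, look up the matches and emit joined rows. Stripped of the mutable iterator plumbing, the logic is roughly this (hypothetical types, plain Scala):

    // keyOf stands in for streamSideKeyGenerator; the real code also skips
    // rows whose join key contains nulls and assembles JoinedRow2 objects.
    def probe[K, V](stream: Iterator[V], keyOf: V => K,
                    hashed: Map[K, Seq[V]]): Iterator[(V, V)] =
      stream.flatMap { row =>
        hashed.getOrElse(keyOf(row), Nil).iterator.map(m => (row, m))
      }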
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
new file mode 100644
index 0000000000..38b8993b03
--- /dev/null
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import java.util.{HashMap => JavaHashMap}
+
+import org.apache.spark.sql.catalyst.expressions.{Projection, Row}
+import org.apache.spark.util.collection.CompactBuffer
+
+
+/**
+ * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete
+ * object.
+ */
+private[joins] sealed trait HashedRelation {
+ def get(key: Row): CompactBuffer[Row]
+}
+
+
+/**
+ * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values.
+ */
+private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]])
+ extends HashedRelation with Serializable {
+
+ override def get(key: Row) = hashTable.get(key)
+}
+
+
+/**
+ * A specialized [[HashedRelation]] that maps key into a single value. This implementation
+ * assumes the key is unique.
+ */
+private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row])
+ extends HashedRelation with Serializable {
+
+ override def get(key: Row) = {
+ val v = hashTable.get(key)
+ if (v eq null) null else CompactBuffer(v)
+ }
+
+ def getValue(key: Row): Row = hashTable.get(key)
+}
+
+
+// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys.
+
+
+private[joins] object HashedRelation {
+
+ def apply(
+ input: Iterator[Row],
+ keyGenerator: Projection,
+ sizeEstimate: Int = 64): HashedRelation = {
+
+ // TODO: Use Spark's HashMap implementation.
+ val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate)
+ var currentRow: Row = null
+
+ // Whether the join key is unique. If the key is unique, we can convert the underlying
+ // hash map into one specialized for this.
+ var keyIsUnique = true
+
+ // Create a mapping of buildKeys -> rows
+ while (input.hasNext) {
+ currentRow = input.next()
+ val rowKey = keyGenerator(currentRow)
+ if (!rowKey.anyNull) {
+ val existingMatchList = hashTable.get(rowKey)
+ val matchList = if (existingMatchList == null) {
+ val newMatchList = new CompactBuffer[Row]()
+ hashTable.put(rowKey, newMatchList)
+ newMatchList
+ } else {
+ keyIsUnique = false
+ existingMatchList
+ }
+ matchList += currentRow.copy()
+ }
+ }
+
+ if (keyIsUnique) {
+ val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size)
+ val iter = hashTable.entrySet().iterator()
+ while (iter.hasNext) {
+ val entry = iter.next()
+ uniqHashTable.put(entry.getKey, entry.getValue()(0))
+ }
+ new UniqueKeyHashedRelation(uniqHashTable)
+ } else {
+ new GeneralHashedRelation(hashTable)
+ }
+ }
+}
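The interesting bit is the second pass: if every key turned out to be unique, the builder copies the map into a flat Row -> Row table so lookups skip the CompactBuffer indirection. The same decision in miniature, using plain Scala collections instead of java.util.HashMap (names are illustrative):

    // Returns a flat map when keys are unique, else the grouped map;
    // mirrors the UniqueKeyHashedRelation / GeneralHashedRelation split.
    def specialize[K, V](rows: Seq[V], keyOf: V => K): Either[Map[K, V], Map[K, Seq[V]]] = {
      val grouped = rows.groupBy(keyOf)
      if (grouped.valuesIterator.forall(_.size == 1))
        Left(grouped.map { case (k, vs) => k -> vs.head })  // unique keys
      else
        Right(grouped)                                      // duplicate keys
    }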
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
index 8247304c1d..418c1c23e5 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala
@@ -42,8 +42,9 @@ case class ShuffledHashJoin(
ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil
override def execute() = {
- buildPlan.execute().zipPartitions(streamedPlan.execute()) {
- (buildIter, streamIter) => joinIterators(buildIter, streamIter)
+ buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) =>
+ val hashed = HashedRelation(buildIter, buildSideKeyGenerator)
+ hashJoin(streamIter, hashed)
}
}
}
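Note that ShuffledHashJoin still builds a HashedRelation inside every partition: after the shuffle each partition holds a different slice of the build side, so there is nothing to reuse across partitions. Only the broadcast join, where all partitions probe the same table, benefits from building once. Schematically (plain Scala in place of zipPartitions; data is illustrative):

    val buildParts  = Seq(Seq(1, 3), Seq(2, 4))   // shuffled build side
    val streamParts = Seq(Seq(3, 5), Seq(4, 6))   // co-partitioned stream side

    buildParts.zip(streamParts).map { case (build, stream) =>
      val hashed = build.groupBy(identity)        // built fresh per partition
      stream.filter(hashed.contains)              // probe this partition's slice
    }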
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
new file mode 100644
index 0000000000..2aad01ded1
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala
@@ -0,0 +1,63 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.joins
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.sql.catalyst.expressions.{Projection, Row}
+import org.apache.spark.util.collection.CompactBuffer
+
+
+class HashedRelationSuite extends FunSuite {
+
+ // Key is simply the record itself
+ private val keyProjection = new Projection {
+ override def apply(row: Row): Row = row
+ }
+
+ test("GeneralHashedRelation") {
+ val data = Array(Row(0), Row(1), Row(2), Row(2))
+ val hashed = HashedRelation(data.iterator, keyProjection)
+ assert(hashed.isInstanceOf[GeneralHashedRelation])
+
+ assert(hashed.get(data(0)) == CompactBuffer[Row](data(0)))
+ assert(hashed.get(data(1)) == CompactBuffer[Row](data(1)))
+ assert(hashed.get(Row(10)) === null)
+
+ val data2 = CompactBuffer[Row](data(2))
+ data2 += data(2)
+ assert(hashed.get(data(2)) == data2)
+ }
+
+ test("UniqueKeyHashedRelation") {
+ val data = Array(Row(0), Row(1), Row(2))
+ val hashed = HashedRelation(data.iterator, keyProjection)
+ assert(hashed.isInstanceOf[UniqueKeyHashedRelation])
+
+ assert(hashed.get(data(0)) == CompactBuffer[Row](data(0)))
+ assert(hashed.get(data(1)) == CompactBuffer[Row](data(1)))
+ assert(hashed.get(data(2)) == CompactBuffer[Row](data(2)))
+ assert(hashed.get(Row(10)) === null)
+
+ val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation]
+ assert(uniqHashed.getValue(data(0)) == data(0))
+ assert(uniqHashed.getValue(data(1)) == data(1))
+ assert(uniqHashed.getValue(data(2)) == data(2))
+ assert(uniqHashed.getValue(Row(10)) == null)
+ }
+}