From 39ccabacf11abdd9afc8f9895084c6707ff35c85 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 13 Oct 2014 11:50:42 -0700 Subject: [SPARK-3861][SQL] Avoid rebuilding hash tables for broadcast joins on each partition Author: Reynold Xin Closes #2727 from rxin/SPARK-3861-broadcast-hash-2 and squashes the following commits: 9c7b1a2 [Reynold Xin] Revert "Reuse CompactBuffer in UniqueKeyHashedRelation." 97626a1 [Reynold Xin] Reuse CompactBuffer in UniqueKeyHashedRelation. 7fcffb5 [Reynold Xin] Make UniqueKeyHashedRelation private[joins]. 18eb214 [Reynold Xin] Merge branch 'SPARK-3861-broadcast-hash' into SPARK-3861-broadcast-hash-1 4b9d0c9 [Reynold Xin] UniqueKeyHashedRelation.get should return null if the value is null. e0ebdd1 [Reynold Xin] Added a test case. 90b58c0 [Reynold Xin] [SPARK-3861] Avoid rebuilding hash tables on each partition 0c0082b [Reynold Xin] Fix line length. cbc664c [Reynold Xin] Rename join -> joins package. a070d44 [Reynold Xin] Fix line length in HashJoin a39be8c [Reynold Xin] [SPARK-3857] Create a join package for various join operators. --- .../sql/execution/joins/BroadcastHashJoin.scala | 8 +- .../spark/sql/execution/joins/HashJoin.scala | 34 ++----- .../spark/sql/execution/joins/HashedRelation.scala | 109 +++++++++++++++++++++ .../sql/execution/joins/ShuffledHashJoin.scala | 5 +- .../sql/execution/joins/HashedRelationSuite.scala | 63 ++++++++++++ 5 files changed, 187 insertions(+), 32 deletions(-) create mode 100644 sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala create mode 100644 sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala (limited to 'sql') diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala index d88ab6367a..8fd35880ee 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/BroadcastHashJoin.scala @@ -22,7 +22,7 @@ import scala.concurrent.duration._ import scala.concurrent.ExecutionContext.Implicits.global import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.{Row, Expression} import org.apache.spark.sql.catalyst.plans.physical.{Partitioning, UnspecifiedDistribution} import org.apache.spark.sql.execution.{BinaryNode, SparkPlan} @@ -49,14 +49,16 @@ case class BroadcastHashJoin( @transient private val broadcastFuture = future { - sparkContext.broadcast(buildPlan.executeCollect()) + val input: Array[Row] = buildPlan.executeCollect() + val hashed = HashedRelation(input.iterator, buildSideKeyGenerator, input.length) + sparkContext.broadcast(hashed) } override def execute() = { val broadcastRelation = Await.result(broadcastFuture, 5.minute) streamedPlan.execute().mapPartitions { streamedIter => - joinIterators(broadcastRelation.value.iterator, streamedIter) + hashJoin(streamedIter, broadcastRelation.value) } } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala index 472b2e6ca6..4012d757d5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashJoin.scala @@ -17,7 +17,7 @@ package org.apache.spark.sql.execution.joins -import org.apache.spark.sql.catalyst.expressions.{Expression, JoinedRow2, Row} +import org.apache.spark.sql.catalyst.expressions._ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.util.collection.CompactBuffer @@ -43,34 +43,14 @@ trait HashJoin { override def output = left.output ++ right.output - @transient protected lazy val buildSideKeyGenerator = newProjection(buildKeys, buildPlan.output) - @transient protected lazy val streamSideKeyGenerator = + @transient protected lazy val buildSideKeyGenerator: Projection = + newProjection(buildKeys, buildPlan.output) + + @transient protected lazy val streamSideKeyGenerator: () => MutableProjection = newMutableProjection(streamedKeys, streamedPlan.output) - protected def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = + protected def hashJoin(streamIter: Iterator[Row], hashedRelation: HashedRelation): Iterator[Row] = { - // TODO: Use Spark's HashMap implementation. - - val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]() - var currentRow: Row = null - - // Create a mapping of buildKeys -> rows - while (buildIter.hasNext) { - currentRow = buildIter.next() - val rowKey = buildSideKeyGenerator(currentRow) - if (!rowKey.anyNull) { - val existingMatchList = hashTable.get(rowKey) - val matchList = if (existingMatchList == null) { - val newMatchList = new CompactBuffer[Row]() - hashTable.put(rowKey, newMatchList) - newMatchList - } else { - existingMatchList - } - matchList += currentRow.copy() - } - } - new Iterator[Row] { private[this] var currentStreamedRow: Row = _ private[this] var currentHashMatches: CompactBuffer[Row] = _ @@ -107,7 +87,7 @@ trait HashJoin { while (currentHashMatches == null && streamIter.hasNext) { currentStreamedRow = streamIter.next() if (!joinKeys(currentStreamedRow).anyNull) { - currentHashMatches = hashTable.get(joinKeys.currentValue) + currentHashMatches = hashedRelation.get(joinKeys.currentValue) } } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala new file mode 100644 index 0000000000..38b8993b03 --- /dev/null +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashedRelation.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import java.util.{HashMap => JavaHashMap} + +import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.util.collection.CompactBuffer + + +/** + * Interface for a hashed relation by some key. Use [[HashedRelation.apply]] to create a concrete + * object. + */ +private[joins] sealed trait HashedRelation { + def get(key: Row): CompactBuffer[Row] +} + + +/** + * A general [[HashedRelation]] backed by a hash map that maps the key into a sequence of values. + */ +private[joins] final class GeneralHashedRelation(hashTable: JavaHashMap[Row, CompactBuffer[Row]]) + extends HashedRelation with Serializable { + + override def get(key: Row) = hashTable.get(key) +} + + +/** + * A specialized [[HashedRelation]] that maps key into a single value. This implementation + * assumes the key is unique. + */ +private[joins] final class UniqueKeyHashedRelation(hashTable: JavaHashMap[Row, Row]) + extends HashedRelation with Serializable { + + override def get(key: Row) = { + val v = hashTable.get(key) + if (v eq null) null else CompactBuffer(v) + } + + def getValue(key: Row): Row = hashTable.get(key) +} + + +// TODO(rxin): a version of [[HashedRelation]] backed by arrays for consecutive integer keys. + + +private[joins] object HashedRelation { + + def apply( + input: Iterator[Row], + keyGenerator: Projection, + sizeEstimate: Int = 64): HashedRelation = { + + // TODO: Use Spark's HashMap implementation. + val hashTable = new JavaHashMap[Row, CompactBuffer[Row]](sizeEstimate) + var currentRow: Row = null + + // Whether the join key is unique. If the key is unique, we can convert the underlying + // hash map into one specialized for this. + var keyIsUnique = true + + // Create a mapping of buildKeys -> rows + while (input.hasNext) { + currentRow = input.next() + val rowKey = keyGenerator(currentRow) + if (!rowKey.anyNull) { + val existingMatchList = hashTable.get(rowKey) + val matchList = if (existingMatchList == null) { + val newMatchList = new CompactBuffer[Row]() + hashTable.put(rowKey, newMatchList) + newMatchList + } else { + keyIsUnique = false + existingMatchList + } + matchList += currentRow.copy() + } + } + + if (keyIsUnique) { + val uniqHashTable = new JavaHashMap[Row, Row](hashTable.size) + val iter = hashTable.entrySet().iterator() + while (iter.hasNext) { + val entry = iter.next() + uniqHashTable.put(entry.getKey, entry.getValue()(0)) + } + new UniqueKeyHashedRelation(uniqHashTable) + } else { + new GeneralHashedRelation(hashTable) + } + } +} diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala index 8247304c1d..418c1c23e5 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/ShuffledHashJoin.scala @@ -42,8 +42,9 @@ case class ShuffledHashJoin( ClusteredDistribution(leftKeys) :: ClusteredDistribution(rightKeys) :: Nil override def execute() = { - buildPlan.execute().zipPartitions(streamedPlan.execute()) { - (buildIter, streamIter) => joinIterators(buildIter, streamIter) + buildPlan.execute().zipPartitions(streamedPlan.execute()) { (buildIter, streamIter) => + val hashed = HashedRelation(buildIter, buildSideKeyGenerator) + hashJoin(streamIter, hashed) } } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala new file mode 100644 index 0000000000..2aad01ded1 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/joins/HashedRelationSuite.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.joins + +import org.scalatest.FunSuite + +import org.apache.spark.sql.catalyst.expressions.{Projection, Row} +import org.apache.spark.util.collection.CompactBuffer + + +class HashedRelationSuite extends FunSuite { + + // Key is simply the record itself + private val keyProjection = new Projection { + override def apply(row: Row): Row = row + } + + test("GeneralHashedRelation") { + val data = Array(Row(0), Row(1), Row(2), Row(2)) + val hashed = HashedRelation(data.iterator, keyProjection) + assert(hashed.isInstanceOf[GeneralHashedRelation]) + + assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + assert(hashed.get(Row(10)) === null) + + val data2 = CompactBuffer[Row](data(2)) + data2 += data(2) + assert(hashed.get(data(2)) == data2) + } + + test("UniqueKeyHashedRelation") { + val data = Array(Row(0), Row(1), Row(2)) + val hashed = HashedRelation(data.iterator, keyProjection) + assert(hashed.isInstanceOf[UniqueKeyHashedRelation]) + + assert(hashed.get(data(0)) == CompactBuffer[Row](data(0))) + assert(hashed.get(data(1)) == CompactBuffer[Row](data(1))) + assert(hashed.get(data(2)) == CompactBuffer[Row](data(2))) + assert(hashed.get(Row(10)) === null) + + val uniqHashed = hashed.asInstanceOf[UniqueKeyHashedRelation] + assert(uniqHashed.getValue(data(0)) == data(0)) + assert(uniqHashed.getValue(data(1)) == data(1)) + assert(uniqHashed.getValue(data(2)) == data(2)) + assert(uniqHashed.getValue(Row(10)) == null) + } +} -- cgit v1.2.3