author     Matei Zaharia <matei@databricks.com>  2014-08-18 10:45:24 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-08-18 10:45:24 -0700
commit     4bf3de71074053af94f077c99e9c65a1962739e1 (patch)
tree       a7bc1e6385a6b4ff28da44dfa8e414b537d5b626
parent     6a13dca12fac06f3af892ffcc8922cc84f91b786 (diff)
[SPARK-3085] [SQL] Use compact data structures in SQL joins
This reuses the CompactBuffer from Spark Core to save memory and pointer
dereferences. I also tried AppendOnlyMap instead of java.util.HashMap, but
unfortunately that slows things down, because it seems to make more equals()
calls, and equals() on GenericRow, and especially on JoinedRow, is pretty
expensive.

Author: Matei Zaharia <matei@databricks.com>

Closes #1993 from mateiz/spark-3085 and squashes the following commits:

188221e [Matei Zaharia] Remove unneeded import
5f903ee [Matei Zaharia] [SPARK-3085] [SQL] Use compact data structures in SQL joins
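For context: CompactBuffer lives in org.apache.spark.util.collection and was private[spark] at the time, so it is only usable inside Spark itself. It behaves like an ArrayBuffer but keeps its first two elements in plain fields, allocating a backing array only from the third append on; join hash tables usually hold one or two rows per key, which is where the memory and pointer-dereference savings come from. A minimal sketch of the idea (illustrative only, not Spark's actual class):

// Minimal sketch of the CompactBuffer idea, not Spark's implementation:
// keep the first two elements in fields so that the common one-or-two-rows
// case allocates no array at all.
class TinyCompactBuffer[T] {
  private var element0: T = _
  private var element1: T = _
  private var otherElements: Array[AnyRef] = null
  private var curSize = 0

  def +=(value: T): this.type = {
    curSize match {
      case 0 => element0 = value
      case 1 => element1 = value
      case n =>
        if (otherElements == null) {
          otherElements = new Array[AnyRef](8)
        } else if (n - 2 == otherElements.length) {
          // Grow geometrically, as ArrayBuffer does.
          otherElements = java.util.Arrays.copyOf(otherElements, otherElements.length * 2)
        }
        otherElements(n - 2) = value.asInstanceOf[AnyRef]
    }
    curSize += 1
    this
  }

  def apply(position: Int): T = {
    require(position >= 0 && position < curSize)
    if (position == 0) element0
    else if (position == 1) element1
    else otherElements(position - 2).asInstanceOf[T]
  }

  def length: Int = curSize
}

An ArrayBuffer, by contrast, eagerly allocates a 16-slot backing array behind a wrapper object even for a single element. One reading note on the diff below: several -/+ pairs look identical because those hunks only strip trailing whitespace, which is invisible here.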
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala | 67
1 file changed, 33 insertions(+), 34 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
index 481bb8c05e..b08f9aacc1 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins.scala
@@ -19,16 +19,15 @@ package org.apache.spark.sql.execution
import java.util.{HashMap => JavaHashMap}
-import scala.collection.mutable.{ArrayBuffer, BitSet}
import scala.concurrent.ExecutionContext.Implicits.global
import scala.concurrent._
import scala.concurrent.duration._
import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.sql.SQLContext
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.plans._
import org.apache.spark.sql.catalyst.plans.physical._
+import org.apache.spark.util.collection.CompactBuffer
@DeveloperApi
sealed abstract class BuildSide
@@ -67,7 +66,7 @@ trait HashJoin {
def joinIterators(buildIter: Iterator[Row], streamIter: Iterator[Row]): Iterator[Row] = {
// TODO: Use Spark's HashMap implementation.
- val hashTable = new java.util.HashMap[Row, ArrayBuffer[Row]]()
+ val hashTable = new java.util.HashMap[Row, CompactBuffer[Row]]()
var currentRow: Row = null
// Create a mapping of buildKeys -> rows
@@ -77,7 +76,7 @@ trait HashJoin {
if (!rowKey.anyNull) {
val existingMatchList = hashTable.get(rowKey)
val matchList = if (existingMatchList == null) {
- val newMatchList = new ArrayBuffer[Row]()
+ val newMatchList = new CompactBuffer[Row]()
hashTable.put(rowKey, newMatchList)
newMatchList
} else {
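The hunk above (and buildHashTable further down in this file) is the standard get-or-create grouping pattern. A self-contained version with generic keys and values, hypothetical names throughout; since CompactBuffer is private[spark] and takes a ClassTag, this sketch assumes it is compiled inside the org.apache.spark namespace:

package org.apache.spark.sketch // hypothetical package: CompactBuffer is private[spark]

import java.util.{HashMap => JavaHashMap}
import scala.reflect.ClassTag
import org.apache.spark.util.collection.CompactBuffer

object GroupRows {
  // Same get-or-create pattern as joinIterators above: one get() per row,
  // allocating a buffer only for keys seen for the first time.
  def groupByKey[K, V: ClassTag](iter: Iterator[(K, V)]): JavaHashMap[K, CompactBuffer[V]] = {
    val hashTable = new JavaHashMap[K, CompactBuffer[V]]()
    while (iter.hasNext) {
      val (key, value) = iter.next()
      var matchList = hashTable.get(key) // null means the key is absent
      if (matchList == null) {
        matchList = new CompactBuffer[V]()
        hashTable.put(key, matchList)
      }
      matchList += value
    }
    hashTable
  }
}

The single get() plus conditional put() is what the TODO about Spark's own HashMap is weighing against: per the commit message, AppendOnlyMap ended up making more equals() calls on the expensive-to-compare row keys.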
@@ -89,7 +88,7 @@ trait HashJoin {
new Iterator[Row] {
private[this] var currentStreamedRow: Row = _
- private[this] var currentHashMatches: ArrayBuffer[Row] = _
+ private[this] var currentHashMatches: CompactBuffer[Row] = _
private[this] var currentMatchPosition: Int = -1
// Mutable per row objects.
@@ -140,7 +139,7 @@ trait HashJoin {
/**
* :: DeveloperApi ::
- * Performs a hash based outer join for two child relations by shuffling the data using
+ * Performs a hash based outer join for two child relations by shuffling the data using
* the join keys. This operator requires loading the associated partition in both side into memory.
*/
@DeveloperApi
@@ -179,26 +178,26 @@ case class HashOuterJoin(
@transient private[this] lazy val EMPTY_LIST = Seq.empty[Row]
// TODO we need to rewrite all of the iterators with our own implementation instead of the Scala
- // iterator for performance purpose.
+ // iterator for performance purpose.
private[this] def leftOuterIterator(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val rightNullRow = new GenericRow(right.output.length)
- val boundCondition =
+ val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
- leftIter.iterator.flatMap { l =>
+ leftIter.iterator.flatMap { l =>
joinedRow.withLeft(l)
var matched = false
- (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
+ (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) =>
matched = true
joinedRow.copy
} else {
Nil
}) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
- // as we don't know whether we need to append it until finish iterating all of the
+ // as we don't know whether we need to append it until finish iterating all of the
// records in right side.
// If we didn't get any proper row, then append a single row with empty right
joinedRow.withRight(rightNullRow).copy
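The DUMMY_LIST trick above is worth unpacking: elsewhere in this file DUMMY_LIST is a one-element Seq[Row](null), and the flag works because collect on an Iterable is strict, so matched has its final value before the filter runs. A stand-alone illustration with made-up values:

// Stand-alone illustration of the DUMMY_LIST trick (made-up values):
var matched = false
val dummyList = Seq(())                  // stands in for DUMMY_LIST = Seq[Row](null)
val rightRows = Seq(1, 2, 3)
val hits = rightRows.collect {           // strict, so `matched` is settled afterwards
  case r if r > 10 => matched = true; r  // the "join condition"; never fires here
}
// Zero or one padding elements: none if anything matched, one stand-in null row otherwise.
val padded = hits ++ dummyList.filter(_ => !matched).map(_ => -1)
assert(padded == Seq(-1))                // no match, so exactly one padded row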
@@ -210,20 +209,20 @@ case class HashOuterJoin(
key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = {
val joinedRow = new JoinedRow()
val leftNullRow = new GenericRow(left.output.length)
- val boundCondition =
+ val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
- rightIter.iterator.flatMap { r =>
+ rightIter.iterator.flatMap { r =>
joinedRow.withRight(r)
var matched = false
- (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
+ (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) =>
matched = true
joinedRow.copy
} else {
Nil
}) ++ DUMMY_LIST.filter(_ => !matched).map( _ => {
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
- // as we don't know whether we need to append it until finish iterating all of the
+ // as we don't know whether we need to append it until finish iterating all of the
// records in left side.
// If we didn't get any proper row, then append a single row with empty left.
joinedRow.withLeft(leftNullRow).copy
@@ -236,7 +235,7 @@ case class HashOuterJoin(
val joinedRow = new JoinedRow()
val leftNullRow = new GenericRow(left.output.length)
val rightNullRow = new GenericRow(right.output.length)
- val boundCondition =
+ val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
if (!key.anyNull) {
@@ -246,8 +245,8 @@ case class HashOuterJoin(
leftIter.iterator.flatMap[Row] { l =>
joinedRow.withLeft(l)
var matched = false
- rightIter.zipWithIndex.collect {
- // 1. For those matched (satisfy the join condition) records with both sides filled,
+ rightIter.zipWithIndex.collect {
+ // 1. For those matched (satisfy the join condition) records with both sides filled,
// append them directly
case (r, idx) if (boundCondition(joinedRow.withRight(r)))=> {
@@ -260,7 +259,7 @@ case class HashOuterJoin(
// 2. For those unmatched records in left, append additional records with empty right.
// DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row,
- // as we don't know whether we need to append it until finish iterating all
+ // as we don't know whether we need to append it until finish iterating all
// of the records in right side.
// If we didn't get any proper row, then append a single row with empty right.
joinedRow.withRight(rightNullRow).copy
@@ -268,8 +267,8 @@ case class HashOuterJoin(
} ++ rightIter.zipWithIndex.collect {
// 3. For those unmatched records in right, append additional records with empty left.
- // Re-visiting the records in right, and append additional row with empty left, if its not
- // in the matched set.
+ // Re-visiting the records in right, and append additional row with empty left, if its not
+ // in the matched set.
case (r, idx) if (!rightMatchedSet.contains(idx)) => {
joinedRow(leftNullRow, r).copy
}
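Step 3 above is why rightMatchedSet exists: while streaming the left side, every right index that ever satisfied the condition is recorded, and a second pass emits the leftovers with a null-padded left. A self-contained sketch of the same three-step strategy over plain collections (hypothetical names; Option stands in for the null-padded rows):

import scala.collection.mutable.BitSet

def fullOuterGroup[L, R](
    leftRows: Seq[L],
    rightRows: Seq[R],
    condition: (L, R) => Boolean): Seq[(Option[L], Option[R])] = {
  val rightMatchedSet = new BitSet(rightRows.size)
  val fromLeft = leftRows.flatMap { l =>
    // 1. Matched pairs, recording which right indices were hit.
    val hits = rightRows.zipWithIndex.collect {
      case (r, idx) if condition(l, r) =>
        rightMatchedSet += idx
        (Some(l), Some(r))
    }
    // 2. An unmatched left row is emitted once with an empty right side.
    if (hits.nonEmpty) hits else Seq((Some(l), None))
  }
  // 3. Revisit the right side and pad whatever never matched.
  val fromRight = rightRows.zipWithIndex.collect {
    case (r, idx) if !rightMatchedSet.contains(idx) => (None, Some(r))
  }
  fromLeft ++ fromRight
}

For example, fullOuterGroup(Seq("a"), Seq("A", "B"), (l: String, r: String) => l.equalsIgnoreCase(r)) yields Seq((Some("a"), Some("A")), (None, Some("B"))).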
@@ -284,15 +283,15 @@ case class HashOuterJoin(
}
private[this] def buildHashTable(
- iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, ArrayBuffer[Row]] = {
- val hashTable = new JavaHashMap[Row, ArrayBuffer[Row]]()
+ iter: Iterator[Row], keyGenerator: Projection): JavaHashMap[Row, CompactBuffer[Row]] = {
+ val hashTable = new JavaHashMap[Row, CompactBuffer[Row]]()
while (iter.hasNext) {
val currentRow = iter.next()
val rowKey = keyGenerator(currentRow)
var existingMatchList = hashTable.get(rowKey)
if (existingMatchList == null) {
- existingMatchList = new ArrayBuffer[Row]()
+ existingMatchList = new CompactBuffer[Row]()
hashTable.put(rowKey, existingMatchList)
}
@@ -311,20 +310,20 @@ case class HashOuterJoin(
val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output))
import scala.collection.JavaConversions._
- val boundCondition =
+ val boundCondition =
condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true)
joinType match {
case LeftOuter => leftHashTable.keysIterator.flatMap { key =>
- leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
+ leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST))
}
case RightOuter => rightHashTable.keysIterator.flatMap { key =>
- rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
+ rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST))
}
case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key =>
- fullOuterIterator(key,
- leftHashTable.getOrElse(key, EMPTY_LIST),
+ fullOuterIterator(key,
+ leftHashTable.getOrElse(key, EMPTY_LIST),
rightHashTable.getOrElse(key, EMPTY_LIST))
}
case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType")
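A side note on this hunk: getOrElse and keysIterator are Scala Map methods that JavaHashMap does not have; they work here because of the import scala.collection.JavaConversions._ above, which implicitly wraps the Java map in a Scala one (later Scala versions deprecate this in favor of JavaConverters and explicit .asScala). A small demonstration:

import java.util.{HashMap => JavaHashMap}
import scala.collection.JavaConversions._ // implicit Java -> Scala collection wrappers

val table = new JavaHashMap[String, Int]()
table.put("a", 1)
// Scala's Map API on a Java map, courtesy of the implicit wrapper:
assert(table.getOrElse("missing", 0) == 0)
assert(table.keysIterator.toList == List("a"))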
@@ -550,7 +549,7 @@ case class BroadcastNestedLoopJoin(
/** All rows that either match both-way, or rows from streamed joined with nulls. */
val matchesOrStreamedRowsWithNulls = streamed.execute().mapPartitions { streamedIter =>
- val matchedRows = new ArrayBuffer[Row]
+ val matchedRows = new CompactBuffer[Row]
// TODO: Use Spark's BitSet.
val includedBroadcastTuples =
new scala.collection.mutable.BitSet(broadcastedRelation.value.size)
@@ -602,20 +601,20 @@ case class BroadcastNestedLoopJoin(
val rightNulls = new GenericMutableRow(right.output.size)
/** Rows from broadcasted joined with nulls. */
val broadcastRowsWithNulls: Seq[Row] = {
- val arrBuf: collection.mutable.ArrayBuffer[Row] = collection.mutable.ArrayBuffer()
+ val buf: CompactBuffer[Row] = new CompactBuffer()
var i = 0
val rel = broadcastedRelation.value
while (i < rel.length) {
if (!allIncludedBroadcastTuples.contains(i)) {
(joinType, buildSide) match {
- case (RightOuter | FullOuter, BuildRight) => arrBuf += new JoinedRow(leftNulls, rel(i))
- case (LeftOuter | FullOuter, BuildLeft) => arrBuf += new JoinedRow(rel(i), rightNulls)
+ case (RightOuter | FullOuter, BuildRight) => buf += new JoinedRow(leftNulls, rel(i))
+ case (LeftOuter | FullOuter, BuildLeft) => buf += new JoinedRow(rel(i), rightNulls)
case _ =>
}
}
i += 1
}
- arrBuf.toSeq
+ buf.toSeq
}
// TODO: Breaks lineage.
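A closing note on this last hunk: CompactBuffer extends Seq, so the buf.toSeq at the end costs nothing extra; in the Scala versions Spark used at the time, toSeq on something that is already a Seq returns the object itself. The loop shape, a counted while over the broadcast relation with a BitSet membership test, looks like this in isolation (made-up data; ArrayBuffer stands in because CompactBuffer is private[spark]):

import scala.collection.mutable.{ArrayBuffer, BitSet}

// Pad every broadcast tuple whose index never entered the matched set.
val relation = Array("r0", "r1", "r2")
val allIncluded = BitSet(0, 2)     // indices that matched on some partition
val buf = ArrayBuffer[String]()
var i = 0
while (i < relation.length) {
  if (!allIncluded.contains(i)) buf += s"(null, ${relation(i)})"
  i += 1
}
assert(buf.toSeq == Seq("(null, r1)"))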