diff options
author | tianyi <tianyi@asiainfo-linkage.com> | 2014-12-16 15:22:29 -0800 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-12-16 15:22:29 -0800 |
commit | 30f6b85c816d1ef611a7be071af0053d64b6fe9e (patch) | |
tree | 31407834a106ccfe88bf6be7e2fdb6b2836b15e9 | |
parent | ea1315e3e26507c8e1cab877cec5fe69c2899ae8 (diff) | |
download | spark-30f6b85c816d1ef611a7be071af0053d64b6fe9e.tar.gz spark-30f6b85c816d1ef611a7be071af0053d64b6fe9e.tar.bz2 spark-30f6b85c816d1ef611a7be071af0053d64b6fe9e.zip |
[SPARK-4483][SQL]Optimization about reduce memory costs during the HashOuterJoin
In `HashOuterJoin.scala`, spark read data from both side of join operation before zip them together. It is a waste for memory. We are trying to read data from only one side, put them into a hashmap, and then generate the `JoinedRow` with data from other side one by one.
Currently, we could only do this optimization for `left outer join` and `right outer join`. For `full outer join`, we will do something in another issue.
for
table test_csv contains 1 million records
table dim_csv contains 10 thousand records
SQL:
`select * from test_csv a left outer join dim_csv b on a.key = b.key`
the result is:
master:
```
CSV: 12671 ms
CSV: 9021 ms
CSV: 9200 ms
Current Mem Usage:787788984
```
after patch:
```
CSV: 10382 ms
CSV: 7543 ms
CSV: 7469 ms
Current Mem Usage:208145728
```
Author: tianyi <tianyi@asiainfo-linkage.com>
Author: tianyi <tianyi.asiainfo@gmail.com>
Closes #3375 from tianyi/SPARK-4483 and squashes the following commits:
72a8aec [tianyi] avoid having mutable state stored inside of the task
99c5c97 [tianyi] performance optimization
d2f94d7 [tianyi] fix bug: missing output when the join-key is null.
2be45d1 [tianyi] fix spell bug
1f2c6f1 [tianyi] remove commented codes
a676de6 [tianyi] optimize some codes
9e7d5b5 [tianyi] remove commented old codes
838707d [tianyi] Optimization about reduce memory costs during the HashOuterJoin
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala | 128 |
1 files changed, 64 insertions, 64 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala index b73041d306..59ef904272 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/joins/HashOuterJoin.scala @@ -68,66 +68,56 @@ case class HashOuterJoin( @transient private[this] lazy val DUMMY_LIST = Seq[Row](null) @transient private[this] lazy val EMPTY_LIST = Seq.empty[Row] + @transient private[this] lazy val leftNullRow = new GenericRow(left.output.length) + @transient private[this] lazy val rightNullRow = new GenericRow(right.output.length) + @transient private[this] lazy val boundCondition = + condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + // TODO we need to rewrite all of the iterators with our own implementation instead of the Scala // iterator for performance purpose. private[this] def leftOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - leftIter.iterator.flatMap { l => - joinedRow.withLeft(l) - var matched = false - (if (!key.anyNull) rightIter.collect { case r if (boundCondition(joinedRow.withRight(r))) => - matched = true - joinedRow.copy + key: Row, joinedRow: JoinedRow, rightIter: Iterable[Row]): Iterator[Row] = { + val ret: Iterable[Row] = ( + if (!key.anyNull) { + val temp = rightIter.collect { + case r if (boundCondition(joinedRow.withRight(r))) => joinedRow.copy + } + if (temp.size == 0) { + joinedRow.withRight(rightNullRow).copy :: Nil + } else { + temp + } } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in right side. - // If we didn't get any proper row, then append a single row with empty right - joinedRow.withRight(rightNullRow).copy - }) - } + joinedRow.withRight(rightNullRow).copy :: Nil + } + ) + ret.iterator } private[this] def rightOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) - - rightIter.iterator.flatMap { r => - joinedRow.withRight(r) - var matched = false - (if (!key.anyNull) leftIter.collect { case l if (boundCondition(joinedRow.withLeft(l))) => - matched = true - joinedRow.copy + key: Row, leftIter: Iterable[Row], joinedRow: JoinedRow): Iterator[Row] = { + + val ret: Iterable[Row] = ( + if (!key.anyNull) { + val temp = leftIter.collect { + case l if (boundCondition(joinedRow.withLeft(l))) => joinedRow.copy + } + if (temp.size == 0) { + joinedRow.withLeft(leftNullRow).copy :: Nil + } else { + temp + } } else { - Nil - }) ++ DUMMY_LIST.filter(_ => !matched).map( _ => { - // DUMMY_LIST.filter(_ => !matched) is a tricky way to add additional row, - // as we don't know whether we need to append it until finish iterating all of the - // records in left side. - // If we didn't get any proper row, then append a single row with empty left. - joinedRow.withLeft(leftNullRow).copy - }) - } + joinedRow.withLeft(leftNullRow).copy :: Nil + } + ) + ret.iterator } private[this] def fullOuterIterator( - key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row]): Iterator[Row] = { - val joinedRow = new JoinedRow() - val leftNullRow = new GenericRow(left.output.length) - val rightNullRow = new GenericRow(right.output.length) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + key: Row, leftIter: Iterable[Row], rightIter: Iterable[Row], + joinedRow: JoinedRow): Iterator[Row] = { if (!key.anyNull) { // Store the positions of records in right, if one of its associated row satisfy @@ -193,27 +183,37 @@ case class HashOuterJoin( } override def execute() = { + val joinedRow = new JoinedRow() left.execute().zipPartitions(right.execute()) { (leftIter, rightIter) => // TODO this probably can be replaced by external sort (sort merged join?) - // Build HashMap for current partition in left relation - val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) - // Build HashMap for current partition in right relation - val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) - val boundCondition = - condition.map(newPredicate(_, left.output ++ right.output)).getOrElse((row: Row) => true) + joinType match { - case LeftOuter => leftHashTable.keysIterator.flatMap { key => - leftOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) + case LeftOuter => { + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + val keyGenerator = newProjection(leftKeys, left.output) + leftIter.flatMap( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withLeft(currentRow) + leftOuterIterator(rowKey, joinedRow, rightHashTable.getOrElse(rowKey, EMPTY_LIST)) + }) } - case RightOuter => rightHashTable.keysIterator.flatMap { key => - rightOuterIterator(key, leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) + case RightOuter => { + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val keyGenerator = newProjection(rightKeys, right.output) + rightIter.flatMap ( currentRow => { + val rowKey = keyGenerator(currentRow) + joinedRow.withRight(currentRow) + rightOuterIterator(rowKey, leftHashTable.getOrElse(rowKey, EMPTY_LIST), joinedRow) + }) } - case FullOuter => (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => - fullOuterIterator(key, - leftHashTable.getOrElse(key, EMPTY_LIST), - rightHashTable.getOrElse(key, EMPTY_LIST)) + case FullOuter => { + val leftHashTable = buildHashTable(leftIter, newProjection(leftKeys, left.output)) + val rightHashTable = buildHashTable(rightIter, newProjection(rightKeys, right.output)) + (leftHashTable.keySet ++ rightHashTable.keySet).iterator.flatMap { key => + fullOuterIterator(key, + leftHashTable.getOrElse(key, EMPTY_LIST), + rightHashTable.getOrElse(key, EMPTY_LIST), joinedRow) + } } case x => throw new Exception(s"HashOuterJoin should not take $x as the JoinType") } |