aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
author: Wenchen Fan <wenchen@databricks.com> 2015-10-30 12:17:51 +0100
committer: Michael Armbrust <michael@databricks.com> 2015-10-30 12:17:51 +0100
commit: 14d08b99085d4e609aeae0cf54d4584e860eb552 (patch)
tree: 1eed67ed7c0e7abaa9395122fd74262fb72a5768 /sql
parent: 59db9e9c382fab40aac0633f2c779bee8cf2025f (diff)
download: spark-14d08b99085d4e609aeae0cf54d4584e860eb552.tar.gz
spark-14d08b99085d4e609aeae0cf54d4584e860eb552.tar.bz2
spark-14d08b99085d4e609aeae0cf54d4584e860eb552.zip
[SPARK-11393] [SQL] CoGroupedIterator should respect the fact that GroupedIterator.hasNext is not idempotent
When we cogroup 2 `GroupedIterator`s in `CoGroupedIterator`, if the right side is smaller, we will consume the right data and keep the left data unchanged. Then we call `hasNext`, which will call `left.hasNext`. This will make `GroupedIterator` generate an extra group, as the previous one has not been consumed yet. Author: Wenchen Fan <wenchen@databricks.com> Closes #9346 from cloud-fan/cogroup and squashes the following commits: 9be67c8 [Wenchen Fan] SPARK-11393
Diffstat (limited to 'sql')
-rw-r--r-- sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala 14
-rw-r--r-- sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala 24
2 files changed, 32 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
index ce5827855e..663bc904f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
@@ -38,17 +38,19 @@ class CoGroupedIterator(
private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
- override def hasNext: Boolean = left.hasNext || right.hasNext
-
- override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
- if (currentLeftData.eq(null) && left.hasNext) {
+ override def hasNext: Boolean = {
+ if (currentLeftData == null && left.hasNext) {
currentLeftData = left.next()
}
- if (currentRightData.eq(null) && right.hasNext) {
+ if (currentRightData == null && right.hasNext) {
currentRightData = right.next()
}
- assert(currentLeftData.ne(null) || currentRightData.ne(null))
+ currentLeftData != null || currentRightData != null
+ }
+
+ override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+ assert(hasNext)
if (currentLeftData.eq(null)) {
// left is null, right is not null, consume the right data.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
index d1fe81947e..4ff96e6574 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
@@ -48,4 +48,28 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
Nil
)
}
+
+ test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") {
+ val leftInput = Seq(create_row(2, "a")).iterator
+ val rightInput = Seq(create_row(1, 2L)).iterator
+ val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
+ val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
+ val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
+
+ val result = cogrouped.map {
+ case (key, leftData, rightData) =>
+ assert(key.numFields == 1)
+ (key.getInt(0), leftData.toSeq, rightData.toSeq)
+ }.toSeq
+
+ assert(result ==
+ (1,
+ Seq.empty,
+ Seq(create_row(1, 2L))) ::
+ (2,
+ Seq(create_row(2, "a")),
+ Seq.empty) ::
+ Nil
+ )
+ }
}