aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala14
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala24
2 files changed, 32 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
index ce5827855e..663bc904f3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala
@@ -38,17 +38,19 @@ class CoGroupedIterator(
private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _
private var currentRightData: (InternalRow, Iterator[InternalRow]) = _
- override def hasNext: Boolean = left.hasNext || right.hasNext
-
- override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
- if (currentLeftData.eq(null) && left.hasNext) {
+ override def hasNext: Boolean = {
+ if (currentLeftData == null && left.hasNext) {
currentLeftData = left.next()
}
- if (currentRightData.eq(null) && right.hasNext) {
+ if (currentRightData == null && right.hasNext) {
currentRightData = right.next()
}
- assert(currentLeftData.ne(null) || currentRightData.ne(null))
+ currentLeftData != null || currentRightData != null
+ }
+
+ override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = {
+ assert(hasNext)
if (currentLeftData.eq(null)) {
// left is null, right is not null, consume the right data.
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
index d1fe81947e..4ff96e6574 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala
@@ -48,4 +48,28 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper {
Nil
)
}
+
+ test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") {
+ val leftInput = Seq(create_row(2, "a")).iterator
+ val rightInput = Seq(create_row(1, 2L)).iterator
+ val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string))
+ val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long))
+ val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int))
+
+ val result = cogrouped.map {
+ case (key, leftData, rightData) =>
+ assert(key.numFields == 1)
+ (key.getInt(0), leftData.toSeq, rightData.toSeq)
+ }.toSeq
+
+ assert(result ==
+ (1,
+ Seq.empty,
+ Seq(create_row(1, 2L))) ::
+ (2,
+ Seq(create_row(2, "a")),
+ Seq.empty) ::
+ Nil
+ )
+ }
}