diff options
Diffstat (limited to 'sql')
-rw-r--r-- | sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala | 14 | ||||
-rw-r--r-- | sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala | 24 |
2 files changed, 32 insertions, 6 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala index ce5827855e..663bc904f3 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/CoGroupedIterator.scala @@ -38,17 +38,19 @@ class CoGroupedIterator( private var currentLeftData: (InternalRow, Iterator[InternalRow]) = _ private var currentRightData: (InternalRow, Iterator[InternalRow]) = _ - override def hasNext: Boolean = left.hasNext || right.hasNext - - override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = { - if (currentLeftData.eq(null) && left.hasNext) { + override def hasNext: Boolean = { + if (currentLeftData == null && left.hasNext) { currentLeftData = left.next() } - if (currentRightData.eq(null) && right.hasNext) { + if (currentRightData == null && right.hasNext) { currentRightData = right.next() } - assert(currentLeftData.ne(null) || currentRightData.ne(null)) + currentLeftData != null || currentRightData != null + } + + override def next(): (InternalRow, Iterator[InternalRow], Iterator[InternalRow]) = { + assert(hasNext) if (currentLeftData.eq(null)) { // left is null, right is not null, consume the right data. diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala index d1fe81947e..4ff96e6574 100644 --- a/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/CoGroupedIteratorSuite.scala @@ -48,4 +48,28 @@ class CoGroupedIteratorSuite extends SparkFunSuite with ExpressionEvalHelper { Nil ) } + + test("SPARK-11393: respect the fact that GroupedIterator.hasNext is not idempotent") { + val leftInput = Seq(create_row(2, "a")).iterator + val rightInput = Seq(create_row(1, 2L)).iterator + val leftGrouped = GroupedIterator(leftInput, Seq('i.int.at(0)), Seq('i.int, 's.string)) + val rightGrouped = GroupedIterator(rightInput, Seq('i.int.at(0)), Seq('i.int, 'l.long)) + val cogrouped = new CoGroupedIterator(leftGrouped, rightGrouped, Seq('i.int)) + + val result = cogrouped.map { + case (key, leftData, rightData) => + assert(key.numFields == 1) + (key.getInt(0), leftData.toSeq, rightData.toSeq) + }.toSeq + + assert(result == + (1, + Seq.empty, + Seq(create_row(1, 2L))) :: + (2, + Seq(create_row(2, "a")), + Seq.empty) :: + Nil + ) + } } |