about | summary | refs | log | tree | commit | diff
path: root/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'core/src')
-rw-r--r-- core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala | 40
-rw-r--r-- core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala | 61
2 files changed, 87 insertions, 14 deletions
diff --git a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
index d220ab51d1..1a3bf2bb67 100644
--- a/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/memory/MemoryStore.scala
@@ -663,31 +663,43 @@ private[spark] class MemoryStore(
private[storage] class PartiallyUnrolledIterator[T](
memoryStore: MemoryStore,
unrollMemory: Long,
- unrolled: Iterator[T],
+ private[this] var unrolled: Iterator[T],
rest: Iterator[T])
extends Iterator[T] {
- private[this] var unrolledIteratorIsConsumed: Boolean = false
- private[this] var iter: Iterator[T] = {
- val completionIterator = CompletionIterator[T, Iterator[T]](unrolled, {
- unrolledIteratorIsConsumed = true
- memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
- })
- completionIterator ++ rest
+ private def releaseUnrollMemory(): Unit = {
+ memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
+ // SPARK-17503: Drop the reference to the unrolled iterator so that its data can be
+ // garbage collected before this PartiallyUnrolledIterator itself becomes unreachable.
+ unrolled = null
}
- override def hasNext: Boolean = iter.hasNext
- override def next(): T = iter.next()
+ override def hasNext: Boolean = {
+ if (unrolled == null) {
+ rest.hasNext
+ } else if (!unrolled.hasNext) {
+ releaseUnrollMemory()
+ rest.hasNext
+ } else {
+ true
+ }
+ }
+
+ override def next(): T = {
+ if (unrolled == null) {
+ rest.next()
+ } else {
+ unrolled.next()
+ }
+ }
/**
* Called to dispose of this iterator and free its memory.
*/
def close(): Unit = {
- if (!unrolledIteratorIsConsumed) {
- memoryStore.releaseUnrollMemoryForThisTask(MemoryMode.ON_HEAP, unrollMemory)
- unrolledIteratorIsConsumed = true
+ if (unrolled != null) {
+ releaseUnrollMemory()
}
- iter = null
}
}
diff --git a/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
new file mode 100644
index 0000000000..02c2331dc3
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/storage/PartiallyUnrolledIteratorSuite.scala
@@ -0,0 +1,61 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.storage
+
+import org.mockito.Matchers
+import org.mockito.Mockito._
+import org.scalatest.mock.MockitoSugar
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.memory.MemoryMode.ON_HEAP
+import org.apache.spark.storage.memory.{MemoryStore, PartiallyUnrolledIterator}
+
+class PartiallyUnrolledIteratorSuite extends SparkFunSuite with MockitoSugar {
+ test("join two iterators") {
+ val unrollSize = 1000
+ val unroll = (0 until unrollSize).iterator
+ val restSize = 500
+ val rest = (unrollSize until restSize + unrollSize).iterator
+
+ val memoryStore = mock[MemoryStore]
+ val joinIterator = new PartiallyUnrolledIterator(memoryStore, unrollSize, unroll, rest)
+
+ // First, iterate over the unrolled iterator
+ (0 until unrollSize).foreach { value =>
+ assert(joinIterator.hasNext)
+ assert(joinIterator.hasNext)
+ assert(joinIterator.next() == value)
+ }
+
+ joinIterator.hasNext
+ joinIterator.hasNext
+ verify(memoryStore, times(1))
+ .releaseUnrollMemoryForThisTask(Matchers.eq(ON_HEAP), Matchers.eq(unrollSize.toLong))
+
+ // Second, iterate over the rest iterator
+ (unrollSize until unrollSize + restSize).foreach { value =>
+ assert(joinIterator.hasNext)
+ assert(joinIterator.hasNext)
+ assert(joinIterator.next() == value)
+ }
+
+ joinIterator.close()
+ // MemoryStore.releaseUnrollMemoryForThisTask is called only once
+ verifyNoMoreInteractions(memoryStore)
+ }
+}