Diffstat (limited to 'core')
-rw-r--r--  core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java     | 43
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala                                |  3
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala                 | 49
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala                        |  2
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java | 19
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala            | 60
6 files changed, 160 insertions, 16 deletions
diff --git a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
index 8f78fc5a41..4c54ba4bce 100644
--- a/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
+++ b/core/src/main/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorter.java
@@ -138,6 +138,11 @@ public final class UnsafeExternalSorter {
this.inMemSorter = existingInMemorySorter;
}
+ // Acquire a new page as soon as we construct the sorter to ensure that we have at
+ // least one page to work with. Otherwise, other operators in the same task may starve
+ // this sorter (SPARK-9709).
+ acquireNewPage();
+
// Register a cleanup task with TaskContext to ensure that memory is guaranteed to be freed at
the end of the task. This is necessary to avoid memory leaks when the downstream operator
// does not fully consume the sorter's output (e.g. sort followed by limit).
@@ -343,22 +348,32 @@ public final class UnsafeExternalSorter {
throw new IOException("Required space " + requiredSpace + " is greater than page size (" +
pageSizeBytes + ")");
} else {
- final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquired < pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquired);
- spill();
- final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
- if (memoryAcquiredAfterSpilling != pageSizeBytes) {
- shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
- throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
- }
- }
- currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
- currentPagePosition = currentPage.getBaseOffset();
- freeSpaceInCurrentPage = pageSizeBytes;
- allocatedPages.add(currentPage);
+ acquireNewPage();
+ }
+ }
+ }
+
+ /**
+ * Acquire a new page from the {@link ShuffleMemoryManager}.
+ *
+ * If there is not enough space to allocate the new page, spill all existing ones
+ * and try again. If there is still not enough space, report an error to the caller.
+ */
+ private void acquireNewPage() throws IOException {
+ final long memoryAcquired = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquired < pageSizeBytes) {
+ shuffleMemoryManager.release(memoryAcquired);
+ spill();
+ final long memoryAcquiredAfterSpilling = shuffleMemoryManager.tryToAcquire(pageSizeBytes);
+ if (memoryAcquiredAfterSpilling != pageSizeBytes) {
+ shuffleMemoryManager.release(memoryAcquiredAfterSpilling);
+ throw new IOException("Unable to acquire " + pageSizeBytes + " bytes of memory");
}
}
+ currentPage = taskMemoryManager.allocatePage(pageSizeBytes);
+ currentPagePosition = currentPage.getBaseOffset();
+ freeSpaceInCurrentPage = pageSizeBytes;
+ allocatedPages.add(currentPage);
}
/**
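The change above both factors the page-acquisition logic into acquireNewPage() and calls it from the constructor, so a freshly built sorter already holds one page and cannot be starved by other operators in the same task (SPARK-9709). The following is a minimal, self-contained Scala sketch of that acquire-spill-retry pattern; ToyMemoryManager and Page are hypothetical stand-ins, not the real Spark classes.

import java.io.IOException

object PageAcquisitionSketch {

  final case class Page(sizeBytes: Long)

  // Hypothetical stand-in for ShuffleMemoryManager: grants at most the bytes still available.
  final class ToyMemoryManager(private var available: Long) {
    def tryToAcquire(numBytes: Long): Long = {
      val granted = math.min(numBytes, available)
      available -= granted
      granted
    }
    def release(numBytes: Long): Unit = { available += numBytes }
  }

  /** Reserve a full page; spill once and retry before giving up, mirroring acquireNewPage(). */
  def acquireNewPage(manager: ToyMemoryManager, pageSizeBytes: Long, spill: () => Unit): Page = {
    val acquired = manager.tryToAcquire(pageSizeBytes)
    if (acquired < pageSizeBytes) {
      manager.release(acquired)       // give back the partial grant before spilling
      spill()                         // free memory this sorter already holds
      val retried = manager.tryToAcquire(pageSizeBytes)
      if (retried != pageSizeBytes) {
        manager.release(retried)
        throw new IOException(s"Unable to acquire $pageSizeBytes bytes of memory")
      }
    }
    Page(pageSizeBytes)
  }
}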
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index a838aac6e8..4312d3a417 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -21,6 +21,9 @@ import scala.reflect.ClassTag
import org.apache.spark.{Partition, TaskContext}
+/**
+ * An RDD that applies the provided function to every partition of the parent RDD.
+ */
private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
new file mode 100644
index 0000000000..b475bd8d79
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Partition, Partitioner, TaskContext}
+
+/**
+ * An RDD that applies a user-provided function to every partition of the parent RDD, and
+ * additionally allows the user to prepare each partition before computing the parent partition.
+ */
+private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M: ClassTag](
+ prev: RDD[T],
+ preparePartition: () => M,
+ executePartition: (TaskContext, Int, M, Iterator[T]) => Iterator[U],
+ preservesPartitioning: Boolean = false)
+ extends RDD[U](prev) {
+
+ override val partitioner: Option[Partitioner] = {
+ if (preservesPartitioning) firstParent[T].partitioner else None
+ }
+
+ override def getPartitions: Array[Partition] = firstParent[T].partitions
+
+ /**
+ * Prepare a partition before computing it from its parent.
+ */
+ override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
+ val preparedArgument = preparePartition()
+ val parentIterator = firstParent[T].iterator(partition, context)
+ executePartition(context, partition.index, preparedArgument, parentIterator)
+ }
+}
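A hedged usage sketch of the new RDD follows; the names below are illustrative and not taken from this patch. Because MapPartitionsWithPreparationRDD is private[spark], code like this can only live inside the org.apache.spark package tree.

package org.apache.spark.example

import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.rdd.{MapPartitionsWithPreparationRDD, RDD}

object PreparationExample {
  def withPreparation(sc: SparkContext): RDD[Int] = {
    val parent: RDD[Int] = sc.parallelize(1 to 10, numSlices = 2)
    // Runs before the parent partition is computed, e.g. to reserve per-partition resources.
    val prepare: () => Long = () => System.nanoTime()
    // Receives whatever `prepare` returned, plus the parent iterator.
    val execute = (_: TaskContext, _: Int, preparedAt: Long, iter: Iterator[Int]) =>
      iter.filter(_ => preparedAt >= 0L)   // trivially uses the prepared value
    new MapPartitionsWithPreparationRDD[Int, Int, Long](parent, prepare, execute)
  }
}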
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index 00c1e078a4..e3d229cc99 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -124,7 +124,7 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
}
}
-private object ShuffleMemoryManager {
+private[spark] object ShuffleMemoryManager {
/**
* Figure out the shuffle memory limit from a SparkConf. We currently have both a fraction
* of the memory pool and a safety factor since collections can sometimes grow bigger than
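The single change in this file widens the companion object's visibility: a plain top-level `private` restricts access to the enclosing package (org.apache.spark.shuffle), while `private[spark]` makes it reachable from anywhere under org.apache.spark, presumably so other Spark-internal code can use its helpers. A minimal Scala illustration of the difference, with hypothetical names:

package org.apache.spark.shuffle {
  private[spark] object SharedLimits {
    def defaultPageSize: Long = 1L << 20
  }
}

package org.apache.spark.util {
  object Consumer {
    // Compiles because private[spark] grants access within the whole org.apache.spark tree;
    // a plain `private object` in the shuffle package would not be visible from here.
    def pageSize: Long = org.apache.spark.shuffle.SharedLimits.defaultPageSize
  }
}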
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 117745f9a9..f5300373d8 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -340,7 +340,8 @@ public class UnsafeExternalSorterSuite {
for (int i = 0; i < numRecordsPerPage * 10; i++) {
insertNumber(sorter, i);
newPeakMemory = sorter.getPeakMemoryUsedBytes();
- if (i % numRecordsPerPage == 0) {
+ // The first page is pre-allocated on instantiation
+ if (i % numRecordsPerPage == 0 && i > 0) {
// We allocated a new page for this record, so peak memory should change
assertEquals(previousPeakMemory + pageSizeBytes, newPeakMemory);
} else {
@@ -364,5 +365,21 @@ public class UnsafeExternalSorterSuite {
}
}
+ @Test
+ public void testReservePageOnInstantiation() throws Exception {
+ final UnsafeExternalSorter sorter = newSorter();
+ try {
+ assertEquals(1, sorter.getNumberOfAllocatedPages());
+ // Inserting a new record doesn't allocate more memory since we already have a page
+ long peakMemory = sorter.getPeakMemoryUsedBytes();
+ insertNumber(sorter, 100);
+ assertEquals(peakMemory, sorter.getPeakMemoryUsedBytes());
+ assertEquals(1, sorter.getNumberOfAllocatedPages());
+ } finally {
+ sorter.cleanupResources();
+ assertSpillFilesWereCleanedUp();
+ }
+ }
+
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
new file mode 100644
index 0000000000..c16930e7d6
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
@@ -0,0 +1,60 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.collection.mutable
+
+import org.apache.spark.{LocalSparkContext, SparkContext, SparkFunSuite, TaskContext}
+
+class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSparkContext {
+
+ test("prepare called before parent partition is computed") {
+ sc = new SparkContext("local", "test")
+
+ // Have the parent partition push a number to the list
+ val parent = sc.parallelize(1 to 100, 1).mapPartitions { iter =>
+ TestObject.things.append(20)
+ iter
+ }
+
+ // Push a different number during the prepare phase
+ val preparePartition = () => { TestObject.things.append(10) }
+
+ // Push yet another number during the execution phase
+ val executePartition = (
+ taskContext: TaskContext,
+ partitionIndex: Int,
+ notUsed: Unit,
+ parentIterator: Iterator[Int]) => {
+ TestObject.things.append(30)
+ TestObject.things.iterator
+ }
+
+ // Verify that the numbers are pushed in the order expected
+ val result = {
+ new MapPartitionsWithPreparationRDD[Int, Int, Unit](
+ parent, preparePartition, executePartition).collect()
+ }
+ assert(result === Array(10, 20, 30))
+ }
+
+}
+
+private object TestObject {
+  val things = new mutable.ListBuffer[Int]
+}