Diffstat (limited to 'core')
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala      | 21
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala                  | 13
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala | 14
3 files changed, 42 insertions(+), 6 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
index b475bd8d79..1f2213d0c4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDD.scala
@@ -17,6 +17,7 @@
 package org.apache.spark.rdd
 
+import scala.collection.mutable.ArrayBuffer
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Partition, Partitioner, TaskContext}
@@ -38,12 +39,28 @@ private[spark] class MapPartitionsWithPreparationRDD[U: ClassTag, T: ClassTag, M
   override def getPartitions: Array[Partition] = firstParent[T].partitions
 
+  // In certain join operations, prepare() can be called on the same partition multiple times.
+  // In that case, we need to ensure that each call to compute() gets a separate prepared argument.
+  private[this] val preparedArguments: ArrayBuffer[M] = new ArrayBuffer[M]
+
+  /**
+   * Prepare a partition for a single call to compute.
+   */
+  def prepare(): Unit = {
+    preparedArguments += preparePartition()
+  }
+
   /**
    * Prepare a partition before computing it from its parent.
    */
   override def compute(partition: Partition, context: TaskContext): Iterator[U] = {
-    val preparedArgument = preparePartition()
+    val prepared =
+      if (preparedArguments.isEmpty) {
+        preparePartition()
+      } else {
+        preparedArguments.remove(0)
+      }
     val parentIterator = firstParent[T].iterator(partition, context)
-    executePartition(context, partition.index, preparedArgument, parentIterator)
+    executePartition(context, partition.index, prepared, parentIterator)
   }
 }
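
To make the queueing contract above concrete, here is a minimal, self-contained Scala sketch of the same prepare/compute handshake. PreparedComputation and the demo values are illustrative stand-ins, not Spark classes: callers may invoke prepare() any number of times, and each compute() consumes the oldest queued argument, falling back to preparing inline when nothing was queued.

import scala.collection.mutable.ArrayBuffer

class PreparedComputation[M, U](preparePartition: () => M, execute: M => U) {
  private val preparedArguments = new ArrayBuffer[M]

  // Queue one prepared argument for a future call to compute().
  def prepare(): Unit = {
    preparedArguments += preparePartition()
  }

  // Consume the oldest queued argument (FIFO); otherwise prepare inline.
  def compute(): U = {
    val prepared =
      if (preparedArguments.isEmpty) preparePartition()
      else preparedArguments.remove(0)
    execute(prepared)
  }
}

object PreparedComputationDemo extends App {
  var counter = 0
  val c = new PreparedComputation[Int, String](
    () => { counter += 1; counter },           // prepare: allocate a resource id
    id => s"computed with preparation #$id")   // execute: consume it
  c.prepare(); c.prepare()   // two prepares up front, as when two RDDs are zipped
  println(c.compute())       // uses preparation #1
  println(c.compute())       // uses preparation #2
  println(c.compute())       // queue empty: prepares inline (#3)
}
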
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index 81f40ad33a..b3c64394ab 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -73,6 +73,16 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
     super.clearDependencies()
     rdds = null
   }
+
+  /**
+   * Call the prepare method of every parent that has one.
+   * This is needed for reserving execution memory in advance.
+   */
+  protected def tryPrepareParents(): Unit = {
+    rdds.collect {
+      case rdd: MapPartitionsWithPreparationRDD[_, _, _] => rdd.prepare()
+    }
+  }
 }
 
 private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
@@ -84,6 +94,7 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
   extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {
 
   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+    tryPrepareParents()
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     f(rdd1.iterator(partitions(0), context), rdd2.iterator(partitions(1), context))
   }
@@ -107,6 +118,7 @@ private[spark] class ZippedPartitionsRDD3
   extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {
 
   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+    tryPrepareParents()
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     f(rdd1.iterator(partitions(0), context),
       rdd2.iterator(partitions(1), context),
@@ -134,6 +146,7 @@ private[spark] class ZippedPartitionsRDD4
   extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {
 
   override def compute(s: Partition, context: TaskContext): Iterator[V] = {
+    tryPrepareParents()
     val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
     f(rdd1.iterator(partitions(0), context),
       rdd2.iterator(partitions(1), context),
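
The tryPrepareParents() helper relies on Seq.collect with a partial function, which filters the parents by runtime type and applies the side effect in a single pass. Below is a standalone sketch of that dispatch pattern; the names (Preparable, PreparedSource, PlainSource) are illustrative, not Spark types.

trait Preparable { def prepare(): Unit }

class PreparedSource(name: String) extends Preparable {
  override def prepare(): Unit = println(s"$name prepared")
}

class PlainSource   // has no prepare(), so collect skips it

object CollectDispatchDemo extends App {
  val parents: Seq[Any] =
    Seq(new PreparedSource("rdd1"), new PlainSource, new PreparedSource("rdd2"))
  // Matches only the Preparable parents; prints "rdd1 prepared" and "rdd2 prepared".
  parents.collect { case p: Preparable => p.prepare() }
}
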
diff --git a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
index c16930e7d6..e281e817e4 100644
--- a/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/MapPartitionsWithPreparationRDDSuite.scala
@@ -46,11 +46,17 @@ class MapPartitionsWithPreparationRDDSuite extends SparkFunSuite with LocalSpark
     }
 
     // Verify that the numbers are pushed in the order expected
-    val result = {
-      new MapPartitionsWithPreparationRDD[Int, Int, Unit](
-        parent, preparePartition, executePartition).collect()
-    }
+    val rdd = new MapPartitionsWithPreparationRDD[Int, Int, Unit](
+      parent, preparePartition, executePartition)
+    val result = rdd.collect()
     assert(result === Array(10, 20, 30))
+
+    TestObject.things.clear()
+    // Zip two of these RDDs; both parents should be prepared before either partition is executed
+    val rdd2 = new MapPartitionsWithPreparationRDD[Int, Int, Unit](
+      parent, preparePartition, executePartition)
+    val result2 = rdd.zipPartitions(rdd2)((a, b) => a).collect()
+    assert(result2 === Array(10, 10, 20, 30, 20, 30))
   }
 }
}
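
A note on the expected arrays: in this suite, preparePartition appends 10 to TestObject.things, and computing a partition then appends 20 and 30. The single-RDD result (10, 20, 30) confirms that prepare runs before execution; in the zipped case, the two leading 10s show that tryPrepareParents() prepared both parents before either parent partition was computed, after which each parent's computation contributed its own 20, 30 pair.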