Diffstat (limited to 'core/src')
 -rw-r--r--  core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala | 18 +++++++++++++++++-
 -rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala | 17 +++++++++++++++++
 2 files changed, 34 insertions(+), 1 deletion(-)
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index 66cf4369da..8171dcc046 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -20,6 +20,8 @@ package org.apache.spark.rdd
 import java.io.{IOException, ObjectOutputStream}
 
 import scala.collection.mutable.ArrayBuffer
+import scala.collection.parallel.ForkJoinTaskSupport
+import scala.concurrent.forkjoin.ForkJoinPool
 import scala.reflect.ClassTag
 
 import org.apache.spark.{Dependency, Partition, RangeDependency, SparkContext, TaskContext}
@@ -62,8 +64,22 @@ class UnionRDD[T: ClassTag](
     var rdds: Seq[RDD[T]])
   extends RDD[T](sc, Nil) {  // Nil since we implement getDependencies
 
+  // visible for testing
+  private[spark] val isPartitionListingParallel: Boolean =
+    rdds.length > conf.getInt("spark.rdd.parallelListingThreshold", 10)
+
+  @transient private lazy val partitionEvalTaskSupport =
+    new ForkJoinTaskSupport(new ForkJoinPool(8))
+
   override def getPartitions: Array[Partition] = {
-    val array = new Array[Partition](rdds.map(_.partitions.length).sum)
+    val parRDDs = if (isPartitionListingParallel) {
+      val parArray = rdds.par
+      parArray.tasksupport = partitionEvalTaskSupport
+      parArray
+    } else {
+      rdds
+    }
+    val array = new Array[Partition](parRDDs.map(_.partitions.length).seq.sum)
     var pos = 0
     for ((rdd, rddIndex) <- rdds.zipWithIndex; split <- rdd.partitions) {
       array(pos) = new UnionPartition(pos, rdd, rddIndex, split.index)
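
For reference, the pattern this hunk relies on is Scala's parallel collections with an explicit task support bounding the thread count. Below is a minimal standalone sketch of the same idea, outside Spark; the Item class and its slow partitionCount are hypothetical stand-ins for an RDD and its partitions array:

import scala.collection.parallel.ForkJoinTaskSupport
import scala.concurrent.forkjoin.ForkJoinPool

object ParallelListingSketch {
  // Hypothetical stand-in for an RDD whose partition listing is
  // expensive, e.g. because it touches the file system.
  final case class Item(id: Int) {
    def partitionCount: Int = { Thread.sleep(10); 2 }
  }

  def main(args: Array[String]): Unit = {
    val items = (1 to 100).map(Item(_))

    // Same shape as getPartitions above: switch to a parallel view
    // and bound its concurrency with an 8-thread ForkJoinPool.
    val par = items.par
    par.tasksupport = new ForkJoinTaskSupport(new ForkJoinPool(8))

    // .seq drops back to a sequential view before summing, mirroring
    // parRDDs.map(_.partitions.length).seq.sum in the patch.
    val total = par.map(_.partitionCount).seq.sum
    println(s"total partitions = $total")
  }
}

Without the explicit tasksupport assignment, the parallel view would run on the default, process-wide pool; the dedicated 8-thread pool here mirrors the hard-coded size of partitionEvalTaskSupport.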
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index a663dab772..979fb426c9 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -116,6 +116,23 @@ class RDDSuite extends SparkFunSuite with SharedSparkContext {
     assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
   }
 
+  test("SparkContext.union parallel partition listing") {
+    val nums1 = sc.makeRDD(Array(1, 2, 3, 4), 2)
+    val nums2 = sc.makeRDD(Array(5, 6, 7, 8), 2)
+    val serialUnion = sc.union(nums1, nums2)
+    val expected = serialUnion.collect().toList
+
+    assert(serialUnion.asInstanceOf[UnionRDD[Int]].isPartitionListingParallel === false)
+
+    sc.conf.set("spark.rdd.parallelListingThreshold", "1")
+    val parallelUnion = sc.union(nums1, nums2)
+    val actual = parallelUnion.collect().toList
+    sc.conf.remove("spark.rdd.parallelListingThreshold")
+
+    assert(parallelUnion.asInstanceOf[UnionRDD[Int]].isPartitionListingParallel === true)
+    assert(expected === actual)
+  }
+
   test("SparkContext.union creates UnionRDD if at least one RDD has no partitioner") {
     val rddWithPartitioner = sc.parallelize(Seq(1 -> true)).partitionBy(new HashPartitioner(1))
     val rddWithNoPartitioner = sc.parallelize(Seq(2 -> true))
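
Finally, a hedged usage sketch of the new threshold from the application side; it assumes a local master and uses only public SparkContext API, with the config key and its default of 10 taken from the patch above:

import org.apache.spark.{SparkConf, SparkContext}

object UnionThresholdSketch {
  def main(args: Array[String]): Unit = {
    // Lower the threshold so that even a two-RDD union takes the
    // parallel listing path (the default threshold is 10 RDDs).
    val conf = new SparkConf()
      .setMaster("local[2]")
      .setAppName("union-threshold-sketch")
      .set("spark.rdd.parallelListingThreshold", "1")
    val sc = new SparkContext(conf)

    val a = sc.parallelize(1 to 4, 2)
    val b = sc.parallelize(5 to 8, 2)

    // The collected result is identical either way; only how
    // UnionRDD lists its partitions changes.
    println(sc.union(a, b).collect().toList) // List(1, 2, ..., 8)

    sc.stop()
  }
}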