3 files changed, 94 insertions, 39 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
new file mode 100644
index 0000000000..96cf93f99e
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionerAwareUnionRDD.scala
@@ -0,0 +1,65 @@
+package org.apache.spark.rdd
+
+import org.apache.spark.{TaskContext, OneToOneDependency, SparkContext, Partition}
+
+private[spark]
+class PartitionerAwareUnionRDDPartition(val idx: Int, val partitions: Array[Partition])
+  extends Partition {
+  override val index = idx
+  override def hashCode(): Int = idx
+}
+
+private[spark]
+class PartitionerAwareUnionRDD[T: ClassManifest](
+    sc: SparkContext,
+    var rdds: Seq[RDD[T]]
+  ) extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
+  require(rdds.length > 0)
+  require(rdds.flatMap(_.partitioner).toSet.size == 1,
+    "Parent RDDs have different partitioners: " + rdds.flatMap(_.partitioner))
+
+  override val partitioner = rdds.head.partitioner
+
+  override def getPartitions: Array[Partition] = {
+    val numPartitions = rdds.head.partitions.length
+    (0 until numPartitions).map(index => {
+      val parentPartitions = rdds.map(_.partitions(index)).toArray
+      new PartitionerAwareUnionRDDPartition(index, parentPartitions)
+    }).toArray
+  }
+
+  // Get the location where most of the partitions of parent RDDs are located
+  override def getPreferredLocations(s: Partition): Seq[String] = {
+    logDebug("Getting preferred locations for " + this)
+    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].partitions
+    val locations = rdds.zip(parentPartitions).flatMap {
+      case (rdd, part) => {
+        val parentLocations = currPrefLocs(rdd, part)
+        logDebug("Location of " + rdd + " partition " + part.index + " = " + parentLocations)
+        parentLocations
+      }
+    }
+
+    if (locations.isEmpty) {
+      Seq.empty
+    } else {
+      Seq(locations.groupBy(x => x).map(x => (x._1, x._2.length)).maxBy(_._2)._1)
+    }
+  }
+
+  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
+    val parentPartitions = s.asInstanceOf[PartitionerAwareUnionRDDPartition].partitions
+    rdds.zip(parentPartitions).iterator.flatMap {
+      case (rdd, p) => rdd.iterator(p, context)
+    }
+  }
+
+  // gets the *current* preferred locations from the DAGScheduler (as opposed to the static ones)
+  private def currPrefLocs(rdd: RDD[_], part: Partition): Seq[String] = {
+    rdd.context.getPreferredLocs(rdd, part.index).map(tl => tl.host)
+  }
+}
+
+
+
+
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 354ab8ae5d..88b36a6855 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -71,6 +71,33 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(sc.union(Seq(nums, nums)).collect().toList === List(1, 2, 3, 4, 1, 2, 3, 4))
   }
 
+  test("partitioner aware union") {
+    import SparkContext._
+    def makeRDDWithPartitioner(seq: Seq[Int]) = {
+      sc.makeRDD(seq, 1)
+        .map(x => (x, null))
+        .partitionBy(new HashPartitioner(2))
+        .mapPartitions(_.map(_._1), true)
+    }
+
+    val nums1 = makeRDDWithPartitioner(1 to 4)
+    val nums2 = makeRDDWithPartitioner(5 to 8)
+    assert(nums1.partitioner == nums2.partitioner)
+    assert(new PartitionerAwareUnionRDD(sc, Seq(nums1)).collect().toSet === Set(1, 2, 3, 4))
+
+    val union = new PartitionerAwareUnionRDD(sc, Seq(nums1, nums2))
+    assert(union.collect().toSet === Set(1, 2, 3, 4, 5, 6, 7, 8))
+    val nums1Parts = nums1.collectPartitions()
+    val nums2Parts = nums2.collectPartitions()
+    val unionParts = union.collectPartitions()
+    assert(nums1Parts.length === 2)
+    assert(nums2Parts.length === 2)
+    assert(unionParts.length === 2)
+    assert((nums1Parts(0) ++ nums2Parts(0)).toList === unionParts(0).toList)
+    assert((nums1Parts(1) ++ nums2Parts(1)).toList === unionParts(1).toList)
+    assert(union.partitioner === nums1.partitioner)
+  }
+
   test("aggregate") {
     val pairs = sc.makeRDD(Array(("a", 1), ("b", 2), ("a", 2), ("c", 5), ("a", 3)))
     type StringMap = HashMap[String, Int]
diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
index 03f522e581..49f84310bc 100644
--- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
+++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
@@ -17,8 +17,7 @@
 
 package org.apache.spark.streaming.dstream
 
-import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.UnionRDD
+import org.apache.spark.rdd.{PartitionerAwareUnionRDD, RDD, UnionRDD}
 import org.apache.spark.storage.StorageLevel
 import org.apache.spark.streaming._
 import org.apache.spark._
@@ -57,7 +56,7 @@ class WindowedDStream[T: ClassManifest](
     val rddsInWindow = parent.slice(currentWindow)
     val windowRDD = if (rddsInWindow.flatMap(_.partitioner).distinct.length == 1) {
       logInfo("Using partition aware union")
-      new PartitionAwareUnionRDD(ssc.sc, rddsInWindow)
+      new PartitionerAwareUnionRDD(ssc.sc, rddsInWindow)
     } else {
       logInfo("Using normal union")
       new UnionRDD(ssc.sc,rddsInWindow)
@@ -66,39 +65,3 @@ class WindowedDStream[T: ClassManifest](
   }
 }
 
-private[streaming]
-class PartitionAwareUnionRDDPartition(val idx: Int, val partitions: Array[Partition])
-  extends Partition {
-  override val index = idx
-  override def hashCode(): Int = idx
-}
-
-private[streaming]
-class PartitionAwareUnionRDD[T: ClassManifest](
-    sc: SparkContext,
-    var rdds: Seq[RDD[T]])
-  extends RDD[T](sc, rdds.map(x => new OneToOneDependency(x))) {
-  require(rdds.length > 0)
-  require(rdds.flatMap(_.partitioner).distinct.length == 1, "Parent RDDs have different partitioners")
-
-  override val partitioner = rdds.head.partitioner
-
-  override def getPartitions: Array[Partition] = {
-    val numPartitions = rdds.head.partitions.length
-    (0 until numPartitions).map(index => {
-      val parentPartitions = rdds.map(_.partitions(index)).toArray
-      new PartitionAwareUnionRDDPartition(index, parentPartitions)
-    }).toArray
-  }
-
-  override def compute(s: Partition, context: TaskContext): Iterator[T] = {
-    val parentPartitions = s.asInstanceOf[PartitionAwareUnionRDDPartition].partitions
-    rdds.zip(parentPartitions).iterator.flatMap {
-      case (rdd, p) => rdd.iterator(p, context)
-    }
-  }
-}
-
-
-
-
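For context, here is a minimal sketch (not part of this commit) of how the new RDD would be exercised. PartitionerAwareUnionRDD is private[spark], so like the test above this would have to compile inside the org.apache.spark.rdd package; the object name and SparkContext setup are illustrative only.

package org.apache.spark.rdd

import org.apache.spark.{HashPartitioner, SparkContext}
import org.apache.spark.SparkContext._

// Hypothetical driver object, purely for illustration.
object PartitionerAwareUnionSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "union-sketch")
    val part = new HashPartitioner(2)

    // Two pair RDDs hash-partitioned the same way.
    val a = sc.makeRDD(1 to 4).map(x => (x, 1)).partitionBy(part)
    val b = sc.makeRDD(5 to 8).map(x => (x, 1)).partitionBy(part)

    // A plain UnionRDD forgets the partitioner (partitioner == None),
    // so a downstream reduceByKey has to shuffle everything again.
    assert(new UnionRDD(sc, Seq(a, b)).partitioner == None)

    // The new RDD keeps it: partition i of the union is computed from
    // partition i of every parent, so no shuffle is needed downstream.
    val union = new PartitionerAwareUnionRDD(sc, Seq(a, b))
    assert(union.partitioner == Some(part))
    println(union.reduceByKey(_ + _).collect().toSeq)

    sc.stop()
  }
}

This is exactly the case WindowedDStream now exploits: when every RDD in the window shares one partitioner, the windowed union stays partitioned and later per-key operations on it skip the shuffle.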
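The getPreferredLocations override boils down to a frequency count over the parents' current preferred hosts: the union partition is scheduled on whichever host holds the most parent partitions. Isolated from the RDD machinery, with made-up host names, the selection expression behaves like this:

// Hosts reported by the parent partitions (hypothetical values).
val locations = Seq("host1", "host2", "host1", "host3", "host1")

// Group identical hosts, count each group, and keep the host with the
// largest count -- the same expression used in getPreferredLocations.
val best = locations.groupBy(x => x).map(x => (x._1, x._2.length)).maxBy(_._2)._1

assert(best == "host1")  // "host1" occurs three times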