diff options
author | Matthew Taylor <matthew.t@tbfe.net> | 2013-11-18 06:41:21 +0000 |
---|---|---|
committer | Matthew Taylor <matthew.t@tbfe.net> | 2013-11-19 06:27:33 +0000 |
commit | 13b9bf494b0d1d0e65dc357efe832763127aefd2 (patch) | |
tree | 1f063d31c6caf09661739ef64f38785436051314 | |
parent | e2ebc3a9d8bca83bf842b134f2f056c1af0ad2be (diff) | |
download | spark-13b9bf494b0d1d0e65dc357efe832763127aefd2.tar.gz spark-13b9bf494b0d1d0e65dc357efe832763127aefd2.tar.bz2 spark-13b9bf494b0d1d0e65dc357efe832763127aefd2.zip |
PartitionPruningRDD is using index from parent
-rw-r--r-- | core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala | 6 | ||||
-rw-r--r-- | core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala | 70 |
2 files changed, 63 insertions, 13 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 165cd412fc..2738a00894 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -34,10 +34,12 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo @transient val partitions: Array[Partition] = rdd.partitions.zipWithIndex - .filter(s => partitionFilterFunc(s._2)) + .filter(s => partitionFilterFunc(s._2)).map(_._1).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - override def getParents(partitionId: Int) = List(partitions(partitionId).index) + override def getParents(partitionId: Int) = { + List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index) + } } diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala index 21f16ef2c6..28e71e835f 100644 --- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala @@ -19,27 +19,75 @@ package org.apache.spark import org.scalatest.FunSuite import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{RDD, PartitionPruningRDD} +import org.apache.spark.rdd.{PartitionPruningRDDPartition, RDD, PartitionPruningRDD} class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { + test("Pruned Partitions inherit locality prefs correctly") { - class TestPartition(i: Int) extends Partition { - def index = i - } + val rdd = new RDD[Int](sc, Nil) { override protected def getPartitions = { Array[Partition]( - new TestPartition(1), - new TestPartition(2), - new TestPartition(3)) + new TestPartition(0, 1), + new TestPartition(1, 1), + new TestPartition(2, 1)) + } + + def compute(split: Partition, context: TaskContext) = { + Iterator() } - def compute(split: Partition, context: TaskContext) = {Iterator()} } - val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) - val p = prunedRDD.partitions(0) - assert(p.index == 2) + val prunedRDD = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) assert(prunedRDD.partitions.length == 1) + val p = prunedRDD.partitions(0) + assert(p.index == 0) + assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2) } + + + test("Pruned Partitions can be merged ") { + + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(0, 4), + new TestPartition(1, 5), + new TestPartition(2, 6)) + } + + def compute(split: Partition, context: TaskContext) = { + List(split.asInstanceOf[TestPartition].testValue).iterator + } + } + val prunedRDD1 = PartitionPruningRDD.create(rdd, { + x => if (x == 0) true else false + }) + + val prunedRDD2 = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) + + val merged = prunedRDD1 ++ prunedRDD2 + + assert(merged.count() == 2) + val take = merged.take(2) + + assert(take.apply(0) == 4) + + assert(take.apply(1) == 6) + + + } + } + +class TestPartition(i: Int, value: Int) extends Partition with Serializable { + def index = i + + def testValue = this.value + +}
\ No newline at end of file |