diff options
-rw-r--r-- | core/src/main/scala/spark/rdd/PartitionPruningRDD.scala | 5 | ||||
-rw-r--r-- | core/src/test/scala/spark/PartitionPruningRDDSuite.scala | 28 |
2 files changed, 31 insertions, 2 deletions
diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 191cfde565..d8700becb0 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -33,8 +33,9 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo extends NarrowDependency[T](rdd) { @transient - val partitions: Array[Partition] = rdd.partitions.filter(s => partitionFilterFunc(s.index)) - .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } + val partitions: Array[Partition] = rdd.partitions.zipWithIndex + .filter(s => partitionFilterFunc(s._2)) + .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } override def getParents(partitionId: Int) = List(partitions(partitionId).index) } diff --git a/core/src/test/scala/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/spark/PartitionPruningRDDSuite.scala new file mode 100644 index 0000000000..88352b639f --- /dev/null +++ b/core/src/test/scala/spark/PartitionPruningRDDSuite.scala @@ -0,0 +1,28 @@ +package spark + +import org.scalatest.FunSuite +import spark.SparkContext._ +import spark.rdd.PartitionPruningRDD + + +class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { + + test("Pruned Partitions inherit locality prefs correctly") { + class TestPartition(i: Int) extends Partition { + def index = i + } + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(1), + new TestPartition(2), + new TestPartition(3)) + } + def compute(split: Partition, context: TaskContext) = {Iterator()} + } + val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) + val p = prunedRDD.partitions(0) + assert(p.index == 2) + assert(prunedRDD.partitions.length == 1) + } +} |