From 743fc4e7aa8a2ca4edbe731bbefb2127d5d1a7d4 Mon Sep 17 00:00:00 2001 From: harshars Date: Fri, 26 Jul 2013 14:35:17 -0700 Subject: Fix Bug in Partition Pruning, index of Pruned Partitions should inherit from parent --- .../main/scala/spark/rdd/PartitionPruningRDD.scala | 6 +++-- .../scala/spark/PartitionPruningRDDSuite.scala | 28 ++++++++++++++++++++++ 2 files changed, 32 insertions(+), 2 deletions(-) create mode 100644 core/src/test/scala/spark/PartitionPruningRDDSuite.scala (limited to 'core') diff --git a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala index 41ff62dd22..6fe004a009 100644 --- a/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/spark/rdd/PartitionPruningRDD.scala @@ -16,8 +16,9 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo extends NarrowDependency[T](rdd) { @transient - val partitions: Array[Partition] = rdd.partitions.filter(s => partitionFilterFunc(s.index)) - .zipWithIndex.map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } + val partitions: Array[Partition] = rdd.partitions. + zipWithIndex.filter(s => partitionFilterFunc(s._2)). + map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } override def getParents(partitionId: Int) = List(partitions(partitionId).index) } @@ -39,6 +40,7 @@ class PartitionPruningRDD[T: ClassManifest]( override protected def getPartitions: Array[Partition] = getDependencies.head.asInstanceOf[PruneDependency[T]].partitions + } diff --git a/core/src/test/scala/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/spark/PartitionPruningRDDSuite.scala new file mode 100644 index 0000000000..a0e6413160 --- /dev/null +++ b/core/src/test/scala/spark/PartitionPruningRDDSuite.scala @@ -0,0 +1,28 @@ +package spark + +import org.scalatest.FunSuite +import spark.SparkContext._ +import spark.rdd.PartitionPruningRDD + + +class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { + + test("Pruned Partitions inherit locality prefs correctly") { + class TestPartition(i: Int) extends Partition { + def index = i + } + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(1), + new TestPartition(2), + new TestPartition(3)) + } + def compute(split: Partition, context: TaskContext) = {Iterator()} + } + val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) + println(prunedRDD.partitions.length) + val p = prunedRDD.partitions(0) + assert(p.index == 2) + } +} \ No newline at end of file -- cgit v1.2.3