aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
author Matei Zaharia <matei@eecs.berkeley.edu> 2013-01-13 19:34:07 -0800
committer Matei Zaharia <matei@eecs.berkeley.edu> 2013-01-13 19:34:07 -0800
commit 72408e8dfacc24652f376d1ee4dd6f04edb54804 (patch)
tree 01bb5baa94198259cfab27e7fa6bfda29163f3f6 /core
parent 62e476739441f7bf6a22c28045561d204e1e4939 (diff)
download spark-72408e8dfacc24652f376d1ee4dd6f04edb54804.tar.gz
spark-72408e8dfacc24652f376d1ee4dd6f04edb54804.tar.bz2
spark-72408e8dfacc24652f376d1ee4dd6f04edb54804.zip
Make filter preserve partitioner info, since it can
Diffstat (limited to 'core')
-rw-r--r-- core/src/main/scala/spark/rdd/FilteredRDD.scala | 3
-rw-r--r-- core/src/test/scala/spark/PartitioningSuite.scala | 5
2 files changed, 7 insertions, 1 deletions
diff --git a/core/src/main/scala/spark/rdd/FilteredRDD.scala b/core/src/main/scala/spark/rdd/FilteredRDD.scala
index b148da28de..d46549b8b6 100644
--- a/core/src/main/scala/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/spark/rdd/FilteredRDD.scala
@@ -7,5 +7,6 @@ private[spark]
class FilteredRDD[T: ClassManifest](prev: RDD[T], f: T => Boolean) extends RDD[T](prev.context) {
override def splits = prev.splits
override val dependencies = List(new OneToOneDependency(prev))
+ override val partitioner = prev.partitioner // Since filter cannot change a partition's keys
override def compute(split: Split, context: TaskContext) = prev.iterator(split, context).filter(f)
-} \ No newline at end of file
+}
diff --git a/core/src/test/scala/spark/PartitioningSuite.scala b/core/src/test/scala/spark/PartitioningSuite.scala
index f09b602a7b..eb3c8f238f 100644
--- a/core/src/test/scala/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/spark/PartitioningSuite.scala
@@ -106,6 +106,11 @@ class PartitioningSuite extends FunSuite with BeforeAndAfter {
assert(grouped2.leftOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.rightOuterJoin(reduced2).partitioner === grouped2.partitioner)
assert(grouped2.cogroup(reduced2).partitioner === grouped2.partitioner)
+
+ assert(grouped2.map(_ => 1).partitioner === None)
+ assert(grouped2.mapValues(_ => 1).partitioner === grouped2.partitioner)
+ assert(grouped2.flatMapValues(_ => Seq(1)).partitioner === grouped2.partitioner)
+ assert(grouped2.filter(_._1 > 4).partitioner === grouped2.partitioner)
}
test("partitioning Java arrays should fail") {