From d3234f9726db3917af4688ba70933938b078b0bd Mon Sep 17 00:00:00 2001
From: Shivaram Venkataraman
Date: Thu, 19 Dec 2013 11:40:34 -0800
Subject: Make collectPartitions take an array of partitions

Change the implementation to use runJob instead of PartitionPruningRDD.
Also update the unit tests and the python take implementation to use
the new interface.
---
 .../scala/org/apache/spark/api/java/JavaRDDLike.scala |  8 ++++----
 .../src/test/scala/org/apache/spark/JavaAPISuite.java | 19 ++++++++++++-------
 python/pyspark/rdd.py                                 |  7 ++++++-
 3 files changed, 22 insertions(+), 12 deletions(-)

diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 1d71875ed1..458d9dcbc3 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -26,7 +26,7 @@ import com.google.common.base.Optional
 import org.apache.hadoop.io.compress.CompressionCodec
 
 import org.apache.spark.{SparkContext, Partition, TaskContext}
-import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
 import org.apache.spark.partial.{PartialResult, BoundedDouble}
@@ -247,10 +247,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   /**
    * Return an array that contains all of the elements in a specific partition of this RDD.
    */
-  def collectPartition(partitionId: Int): JList[T] = {
+  def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = {
     import scala.collection.JavaConversions._
-    val partition = new PartitionPruningRDD[T](rdd, _ == partitionId)
-    new java.util.ArrayList(partition.collect().toSeq)
+    val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true)
+    res.map(x => new java.util.ArrayList(x.toSeq)).toArray
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 2862ed3019..79913dc718 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -899,7 +899,7 @@ public class JavaAPISuite implements Serializable {
   }
 
   @Test
-  public void collectPartition() {
+  public void collectPartitions() {
     JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
 
     JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
@@ -909,20 +909,25 @@
       }
     });
 
-    Assert.assertEquals(Arrays.asList(1, 2), rdd1.collectPartition(0));
-    Assert.assertEquals(Arrays.asList(3, 4), rdd1.collectPartition(1));
-    Assert.assertEquals(Arrays.asList(5, 6, 7), rdd1.collectPartition(2));
+    List<Integer>[] parts = rdd1.collectPartitions(new int[] {0});
+    Assert.assertEquals(Arrays.asList(1, 2), parts[0]);
+
+    parts = rdd1.collectPartitions(new int[] {1, 2});
+    Assert.assertEquals(Arrays.asList(3, 4), parts[0]);
+    Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]);
 
     Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
                                       new Tuple2<Integer, Integer>(2, 0)),
-                        rdd2.collectPartition(0));
+                        rdd2.collectPartitions(new int[] {0})[0]);
+
+    parts = rdd2.collectPartitions(new int[] {1, 2});
     Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
                                       new Tuple2<Integer, Integer>(4, 0)),
-                        rdd2.collectPartition(1));
+                        parts[0]);
 
     Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
                                       new Tuple2<Integer, Integer>(6, 0),
                                       new Tuple2<Integer, Integer>(7, 1)),
-                        rdd2.collectPartition(2));
+                        parts[1]);
   }
 }
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d81b7c90c1..7015119551 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -576,8 +576,13 @@ class RDD(object):
         # Take only up to num elements from each partition we try
         mapped = self.mapPartitions(takeUpToNum)
         items = []
+        # TODO(shivaram): Similar to the scala implementation, update the take
+        # method to scan multiple splits based on an estimate of how many elements
+        # we have per-split.
         for partition in range(mapped._jrdd.splits().size()):
-            iterator = mapped._jrdd.collectPartition(partition).iterator()
+            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+            partitionsToTake[0] = partition
+            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
             items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
-- 
cgit v1.2.3
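
A minimal Scala sketch of the mechanism the commit message describes, not taken from the patch itself: SparkContext.runJob computes only the requested partitions in a single job, which is what lets collectPartitions fetch several partitions at once instead of building a PartitionPruningRDD and collecting one partition per call. The object name, app name, and local master below are illustrative; the four-argument runJob call (function, partition ids, allowLocal flag) mirrors the call the patch adds to JavaRDDLike and assumes a Spark release of the same vintage (later releases drop the allowLocal parameter).

    // Sketch: collect selected partitions with one runJob, as the new collectPartitions does.
    import org.apache.spark.SparkContext

    object CollectPartitionsSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext("local", "collectPartitions-sketch")  // illustrative setup
        val rdd = sc.parallelize(1 to 7, 3)  // same data and partitioning as the JavaAPISuite test
        // One job computes only partitions 1 and 2; each task returns its partition as an Array.
        val parts: Array[Array[Int]] =
          sc.runJob(rdd, (it: Iterator[Int]) => it.toArray, Seq(1, 2), true)
        parts.foreach(p => println(p.mkString(", ")))  // prints "3, 4" then "5, 6, 7"
        sc.stop()
      }
    }

Passing the partition ids straight to runJob is also what the PySpark take() change relies on: it builds a one-element Java int[] through the Py4J gateway and asks for exactly one partition per iteration, without constructing an intermediate pruned RDD on the Scala side.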