-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala    8
-rw-r--r--  core/src/test/scala/org/apache/spark/JavaAPISuite.java            19
-rw-r--r--  python/pyspark/rdd.py                                              7
3 files changed, 22 insertions, 12 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 1d71875ed1..458d9dcbc3 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -26,7 +26,7 @@ import com.google.common.base.Optional
 import org.apache.hadoop.io.compress.CompressionCodec
 
 import org.apache.spark.{SparkContext, Partition, TaskContext}
-import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
+import org.apache.spark.rdd.RDD
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
 import org.apache.spark.partial.{PartialResult, BoundedDouble}
@@ -247,10 +247,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   /**
    * Return an array that contains all of the elements in a specific partition of this RDD.
    */
-  def collectPartition(partitionId: Int): JList[T] = {
+  def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = {
     import scala.collection.JavaConversions._
-    val partition = new PartitionPruningRDD[T](rdd, _ == partitionId)
-    new java.util.ArrayList(partition.collect().toSeq)
+    val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true)
+    res.map(x => new java.util.ArrayList(x.toSeq)).toArray
   }
 
   /**
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 2862ed3019..79913dc718 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -899,7 +899,7 @@ public class JavaAPISuite implements Serializable {
   }
 
   @Test
-  public void collectPartition() {
+  public void collectPartitions() {
     JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
 
     JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
@@ -909,20 +909,25 @@
       }
     });
 
-    Assert.assertEquals(Arrays.asList(1, 2), rdd1.collectPartition(0));
-    Assert.assertEquals(Arrays.asList(3, 4), rdd1.collectPartition(1));
-    Assert.assertEquals(Arrays.asList(5, 6, 7), rdd1.collectPartition(2));
+    List[] parts = rdd1.collectPartitions(new int[] {0});
+    Assert.assertEquals(Arrays.asList(1, 2), parts[0]);
+
+    parts = rdd1.collectPartitions(new int[] {1, 2});
+    Assert.assertEquals(Arrays.asList(3, 4), parts[0]);
+    Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]);
 
     Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
                                       new Tuple2<Integer, Integer>(2, 0)),
-                        rdd2.collectPartition(0));
+                        rdd2.collectPartitions(new int[] {0})[0]);
+
+    parts = rdd2.collectPartitions(new int[] {1, 2});
     Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
                                       new Tuple2<Integer, Integer>(4, 0)),
-                        rdd2.collectPartition(1));
+                        parts[0]);
     Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
                                       new Tuple2<Integer, Integer>(6, 0),
                                       new Tuple2<Integer, Integer>(7, 1)),
-                        rdd2.collectPartition(2));
+                        parts[1]);
   }
 }
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d81b7c90c1..7015119551 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -576,8 +576,13 @@ class RDD(object):
         # Take only up to num elements from each partition we try
         mapped = self.mapPartitions(takeUpToNum)
         items = []
+        # TODO(shivaram): Similar to the scala implementation, update the take
+        # method to scan multiple splits based on an estimate of how many elements
+        # we have per-split.
         for partition in range(mapped._jrdd.splits().size()):
-            iterator = mapped._jrdd.collectPartition(partition).iterator()
+            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+            partitionsToTake[0] = partition
+            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
             items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
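
For reference, a minimal sketch of how a caller would use the new API. The class name and local-master setup are illustrative only; the RDD and the expected partition contents mirror the fixtures in JavaAPISuite above. Where the old collectPartition(int) launched one job per partition via PartitionPruningRDD, the new collectPartitions(int[]) runs a single runJob over all requested partitions and returns one List per requested id, in the same order.

import java.util.Arrays;
import java.util.List;

import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;

public class CollectPartitionsExample {
  public static void main(String[] args) {
    // Illustrative driver setup; any JavaSparkContext works the same way.
    JavaSparkContext sc = new JavaSparkContext("local", "CollectPartitionsExample");

    // Same fixture as the test suite: 7 elements in 3 partitions, which
    // parallelize() slices as [1, 2], [3, 4], [5, 6, 7].
    JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);

    // One job fetches both requested partitions; parts[i] holds the
    // elements of the i-th requested partition id.
    List[] parts = rdd.collectPartitions(new int[] {1, 2});
    System.out.println(parts[0]);  // [3, 4]
    System.out.println(parts[1]);  // [5, 6, 7]

    sc.stop();
  }
}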