 core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala |  8 ++++----
 core/src/test/scala/org/apache/spark/JavaAPISuite.java          | 19 ++++++++++++-------
 python/pyspark/rdd.py                                           |  7 ++++++-
 3 files changed, 22 insertions(+), 12 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 1d71875ed1..458d9dcbc3 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -26,7 +26,7 @@ import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.spark.{SparkContext, Partition, TaskContext}
-import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import org.apache.spark.partial.{PartialResult, BoundedDouble}
@@ -247,10 +247,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Return an array that contains all of the elements in a specific partition of this RDD.
*/
- def collectPartition(partitionId: Int): JList[T] = {
+ def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = {
import scala.collection.JavaConversions._
- val partition = new PartitionPruningRDD[T](rdd, _ == partitionId)
- new java.util.ArrayList(partition.collect().toSeq)
+ val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true)
+ res.map(x => new java.util.ArrayList(x.toSeq)).toArray
}
/**
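
For reference, a minimal usage sketch of the new method from the Java side (the context name sc and the sample data are illustrative, not part of this patch): the old single-partition call collectPartition(i) becomes collectPartitions(new int[] {i})[0], and several partitions can now be fetched in a single job.

    import java.util.Arrays;
    import java.util.List;

    import org.apache.spark.api.java.JavaRDD;
    import org.apache.spark.api.java.JavaSparkContext;

    public class CollectPartitionsExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "collectPartitionsExample");
        JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);

        // Old style (removed): List<Integer> first = rdd.collectPartition(0);
        List[] parts = rdd.collectPartitions(new int[] {0});
        List first = parts[0];                        // elements of partition 0

        // Several partitions come back from one job, one list per requested id.
        List[] rest = rdd.collectPartitions(new int[] {1, 2});
        System.out.println(first + " " + Arrays.toString(rest));

        sc.stop();
      }
    }

Routing through SparkContext.runJob also avoids building a PartitionPruningRDD and running a separate job per requested partition, which is what the removed implementation did on every call.
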
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 2862ed3019..79913dc718 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -899,7 +899,7 @@ public class JavaAPISuite implements Serializable {
}
@Test
- public void collectPartition() {
+ public void collectPartitions() {
JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
@@ -909,20 +909,25 @@ public class JavaAPISuite implements Serializable {
}
});
- Assert.assertEquals(Arrays.asList(1, 2), rdd1.collectPartition(0));
- Assert.assertEquals(Arrays.asList(3, 4), rdd1.collectPartition(1));
- Assert.assertEquals(Arrays.asList(5, 6, 7), rdd1.collectPartition(2));
+ List[] parts = rdd1.collectPartitions(new int[] {0});
+ Assert.assertEquals(Arrays.asList(1, 2), parts[0]);
+
+ parts = rdd1.collectPartitions(new int[] {1, 2});
+ Assert.assertEquals(Arrays.asList(3, 4), parts[0]);
+ Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]);
Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
new Tuple2<Integer, Integer>(2, 0)),
- rdd2.collectPartition(0));
+ rdd2.collectPartitions(new int[] {0})[0]);
+
+ parts = rdd2.collectPartitions(new int[] {1, 2});
Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
new Tuple2<Integer, Integer>(4, 0)),
- rdd2.collectPartition(1));
+ parts[0]);
Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
new Tuple2<Integer, Integer>(6, 0),
new Tuple2<Integer, Integer>(7, 1)),
- rdd2.collectPartition(2));
+ parts[1]);
}
}
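
The updated test also leans on the ordering of the returned array: runJob returns one result per requested id, indexed by position in the request array rather than by partition id. A small additional test in the same style makes that explicit (a sketch only; it assumes the suite's sc field and existing imports, and the test name is hypothetical).

    @Test
    public void collectPartitionsOrdering() {
      JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
      // Request partition 2 first, then partition 0: results follow the request order.
      List[] parts = rdd.collectPartitions(new int[] {2, 0});
      Assert.assertEquals(Arrays.asList(5, 6, 7), parts[0]);  // partition 2
      Assert.assertEquals(Arrays.asList(1, 2), parts[1]);     // partition 0
    }
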
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d81b7c90c1..7015119551 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -576,8 +576,13 @@ class RDD(object):
# Take only up to num elements from each partition we try
mapped = self.mapPartitions(takeUpToNum)
items = []
+ # TODO(shivaram): Similar to the scala implementation, update the take
+ # method to scan multiple splits based on an estimate of how many elements
+ # we have per-split.
for partition in range(mapped._jrdd.splits().size()):
- iterator = mapped._jrdd.collectPartition(partition).iterator()
+ partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+ partitionsToTake[0] = partition
+ iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break
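
For orientation, the Py4J calls above amount to roughly the following on the JVM side (a sketch; the helper name is hypothetical and JavaRDD<byte[]> stands in for the serialized _jrdd that PySpark holds).

    // Build the one-element int[] (what _gateway.new_array produces), collect that
    // single split, and hand its iterator back to Python for deserialization.
    static java.util.Iterator collectOneSplit(org.apache.spark.api.java.JavaRDD<byte[]> javaRdd,
                                              int partition) {
      int[] partitionsToTake = new int[] { partition };
      java.util.List[] collected = javaRdd.collectPartitions(partitionsToTake);
      return collected[0].iterator();   // consumed by _collect_iterator_through_file
    }

As the TODO above notes, take() still walks one split per call; scanning several splits at a time, based on an estimate of how many elements each split holds, is left for a follow-up.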