author     Shivaram Venkataraman <shivaram@eecs.berkeley.edu>   2013-12-19 11:40:34 -0800
committer  Shivaram Venkataraman <shivaram@eecs.berkeley.edu>   2013-12-19 11:40:34 -0800
commit     d3234f9726db3917af4688ba70933938b078b0bd (patch)
tree       b343f29b81bbaf8d1165f76e8e5748876b3fb008
parent     af0cd6bd27dda73b326bcb6a66addceadebf5e54 (diff)
Make collectPartitions take an array of partitions
Change the implementation to use runJob instead of PartitionPruningRDD. Also update the unit tests and the Python take implementation to use the new interface.
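For context, here is a minimal Scala sketch (not part of the patch) of the runJob call that the new collectPartitions builds on. The standalone driver, local master, and printed output are illustrative assumptions; the runJob(rdd, func, partitions, allowLocal) call shape and the expected partition contents mirror the change and the updated JavaAPISuite test below.

import org.apache.spark.SparkContext

// Illustrative only: a hypothetical standalone driver, not code from this commit.
object CollectPartitionsSketch {
  def main(args: Array[String]) {
    val sc = new SparkContext("local[2]", "collect-partitions-sketch")
    // Same data and partitioning as the JavaAPISuite test: (1,2), (3,4), (5,6,7).
    val rdd = sc.parallelize(1 to 7, 3)

    // runJob evaluates only the requested partitions and returns one result per
    // requested partition, so no intermediate PartitionPruningRDD is needed.
    val parts: Array[Array[Int]] =
      sc.runJob(rdd, (it: Iterator[Int]) => it.toArray, Seq(1, 2), true)

    parts.foreach(p => println(p.mkString(", ")))  // expected: "3, 4" then "5, 6, 7"
    sc.stop()
  }
}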
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala   8
-rw-r--r--  core/src/test/scala/org/apache/spark/JavaAPISuite.java           19
-rw-r--r--  python/pyspark/rdd.py                                             7
3 files changed, 22 insertions, 12 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 1d71875ed1..458d9dcbc3 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -26,7 +26,7 @@ import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.spark.{SparkContext, Partition, TaskContext}
-import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
+import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
import org.apache.spark.partial.{PartialResult, BoundedDouble}
@@ -247,10 +247,10 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Return an array that contains all of the elements in a specific partition of this RDD.
*/
- def collectPartition(partitionId: Int): JList[T] = {
+ def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = {
import scala.collection.JavaConversions._
- val partition = new PartitionPruningRDD[T](rdd, _ == partitionId)
- new java.util.ArrayList(partition.collect().toSeq)
+ val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true)
+ res.map(x => new java.util.ArrayList(x.toSeq)).toArray
}
/**
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 2862ed3019..79913dc718 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -899,7 +899,7 @@ public class JavaAPISuite implements Serializable {
}
@Test
- public void collectPartition() {
+ public void collectPartitions() {
JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
@@ -909,20 +909,25 @@ public class JavaAPISuite implements Serializable {
}
});
- Assert.assertEquals(Arrays.asList(1, 2), rdd1.collectPartition(0));
- Assert.assertEquals(Arrays.asList(3, 4), rdd1.collectPartition(1));
- Assert.assertEquals(Arrays.asList(5, 6, 7), rdd1.collectPartition(2));
+ List[] parts = rdd1.collectPartitions(new int[] {0});
+ Assert.assertEquals(Arrays.asList(1, 2), parts[0]);
+
+ parts = rdd1.collectPartitions(new int[] {1, 2});
+ Assert.assertEquals(Arrays.asList(3, 4), parts[0]);
+ Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]);
Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
new Tuple2<Integer, Integer>(2, 0)),
- rdd2.collectPartition(0));
+ rdd2.collectPartitions(new int[] {0})[0]);
+
+ parts = rdd2.collectPartitions(new int[] {1, 2});
Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
new Tuple2<Integer, Integer>(4, 0)),
- rdd2.collectPartition(1));
+ parts[0]);
Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
new Tuple2<Integer, Integer>(6, 0),
new Tuple2<Integer, Integer>(7, 1)),
- rdd2.collectPartition(2));
+ parts[1]);
}
}
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index d81b7c90c1..7015119551 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -576,8 +576,13 @@ class RDD(object):
# Take only up to num elements from each partition we try
mapped = self.mapPartitions(takeUpToNum)
items = []
+ # TODO(shivaram): Similar to the scala implementation, update the take
+ # method to scan multiple splits based on an estimate of how many elements
+ # we have per-split.
for partition in range(mapped._jrdd.splits().size()):
- iterator = mapped._jrdd.collectPartition(partition).iterator()
+ partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+ partitionsToTake[0] = partition
+ iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break