author    | Reynold Xin <rxin@apache.org> | 2013-12-19 13:35:09 -0800
committer | Reynold Xin <rxin@apache.org> | 2013-12-19 13:35:09 -0800
commit    | 7990c5637519ae2def30dfba19b7c83562c0ec00
tree      | 0c645fee598e3a9da68e5e8bb88bcafef651f04b
parent    | 440e531a5e7720c42f0c53ce98425b63b4194b7b
parent    | 9cc3a6d3c0a64b80af77ae358c58d4b29b18c534
Merge pull request #276 from shivaram/collectPartition
Add `collectPartitions` to the JavaRDD interface.
This method is useful for implementing `take` from other language frontends where the data is serialized. Also remove `takePartition` from PythonRDD and use `collectPartitions` in rdd.py.
Thanks @concretevitamin for the original change and tests.
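
For context, here is a minimal sketch of how the new method might be exercised from Scala through the Java API wrapper. The local master, application name, and sample data are illustrative assumptions, not part of this commit:

```scala
import org.apache.spark.api.java.JavaSparkContext

// Hypothetical driver program exercising the new collectPartitions API;
// the context settings and sample data are assumptions for illustration.
object CollectPartitionsExample {
  def main(args: Array[String]) {
    val sc = new JavaSparkContext("local", "collectPartitions-example")
    // Seven elements split across three partitions, mirroring the new test.
    val rdd = sc.parallelize(java.util.Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3)
    // Fetch partitions 0 and 1 only; each entry in the result is a
    // java.util.List holding that partition's elements.
    val parts = rdd.collectPartitions(Array(0, 1))
    parts.foreach(println)
    sc.stop()
  }
}
```

Unlike `collect()`, only the requested partitions are computed, which is what makes the method suitable for an incremental `take`.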
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala | 11
-rw-r--r-- | core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala |  4
-rw-r--r-- | core/src/test/scala/org/apache/spark/JavaAPISuite.java          | 33
-rw-r--r-- | python/pyspark/context.py                                       |  3
-rw-r--r-- | python/pyspark/rdd.py                                           |  7
5 files changed, 50 insertions(+), 8 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 9e912d3adb..f344804b4c 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -245,6 +245,17 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   }
 
   /**
+   * Return an array that contains all of the elements in a specific partition of this RDD.
+   */
+  def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = {
+    // This is useful for implementing `take` from other language frontends
+    // like Python where the data is serialized.
+    import scala.collection.JavaConversions._
+    val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true)
+    res.map(x => new java.util.ArrayList(x.toSeq)).toArray
+  }
+
+  /**
    * Reduces the elements of this RDD using the specified commutative and associative binary operator.
    */
   def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)

diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index a659cc06c2..ca42c76928 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -235,10 +235,6 @@ private[spark] object PythonRDD {
     file.close()
   }
 
-  def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
-    implicit val cm : ClassTag[T] = rdd.elementClassTag
-    rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
-  }
 }
 
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {

diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 4234f6eac7..79913dc718 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -897,4 +897,37 @@ public class JavaAPISuite implements Serializable {
         new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
   }
+
+  @Test
+  public void collectPartitions() {
+    JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
+
+    JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+      @Override
+      public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+        return new Tuple2<Integer, Integer>(i, i % 2);
+      }
+    });
+
+    List[] parts = rdd1.collectPartitions(new int[] {0});
+    Assert.assertEquals(Arrays.asList(1, 2), parts[0]);
+
+    parts = rdd1.collectPartitions(new int[] {1, 2});
+    Assert.assertEquals(Arrays.asList(3, 4), parts[0]);
+    Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]);
+
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
+                                      new Tuple2<Integer, Integer>(2, 0)),
+                        rdd2.collectPartitions(new int[] {0})[0]);
+
+    parts = rdd2.collectPartitions(new int[] {1, 2});
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
+                                      new Tuple2<Integer, Integer>(4, 0)),
+                        parts[0]);
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
+                                      new Tuple2<Integer, Integer>(6, 0),
+                                      new Tuple2<Integer, Integer>(7, 1)),
+                        parts[1]);
+  }
+
 }
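
The new wrapper is a thin layer over `SparkContext.runJob`. A hedged sketch of the equivalent direct call, before the conversion to `java.util.List`; the method and variable names here are placeholders, not identifiers from this commit, and the final `true` argument is `allowLocal`, which lets small jobs run in the driver:

```scala
import scala.reflect.ClassTag
import org.apache.spark.rdd.RDD

// Roughly what collectPartitions does under the hood: run a job over only
// the requested partition ids and materialize each partition's iterator
// into an array.  `rdd` and `ids` are placeholder names.
def collectPartitionsSketch[T: ClassTag](rdd: RDD[T], ids: Array[Int]): Array[Array[T]] =
  rdd.context.runJob(rdd, (it: Iterator[T]) => it.toArray, ids, true)
```

Because `runJob` returns Scala arrays, the diff converts each partition's result to a `java.util.ArrayList`, so Java and Py4J callers receive a standard Java collection.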
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index cbd41e58c4..0604f6836c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -43,7 +43,6 @@ class SparkContext(object):
     _gateway = None
     _jvm = None
     _writeToFile = None
-    _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -134,8 +133,6 @@ class SparkContext(object):
                 SparkContext._jvm = SparkContext._gateway.jvm
                 SparkContext._writeToFile = \
                     SparkContext._jvm.PythonRDD.writeToFile
-                SparkContext._takePartition = \
-                    SparkContext._jvm.PythonRDD.takePartition
 
             if instance:
                 if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:

diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 7cbc66d3c9..f87923e6fa 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -579,8 +579,13 @@ class RDD(object):
         # Take only up to num elements from each partition we try
         mapped = self.mapPartitions(takeUpToNum)
         items = []
+        # TODO(shivaram): Similar to the scala implementation, update the take
+        # method to scan multiple splits based on an estimate of how many elements
+        # we have per-split.
         for partition in range(mapped._jrdd.splits().size()):
-            iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
+            partitionsToTake = self.ctx._gateway.new_array(self.ctx._jvm.int, 1)
+            partitionsToTake[0] = partition
+            iterator = mapped._jrdd.collectPartitions(partitionsToTake)[0].iterator()
             items.extend(mapped._collect_iterator_through_file(iterator))
             if len(items) >= num:
                 break
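
On the Python side, `take` now drives everything through the Java API over Py4J: it builds a one-element `int[]` with `gateway.new_array` and fetches a single partition per job. A hedged Scala sketch of the equivalent JVM-side loop; the function and variable names are illustrative, not from this commit:

```scala
import org.apache.spark.api.java.JavaRDD
import scala.collection.JavaConversions._

// Illustrative equivalent of the new rdd.py loop: fetch one partition at a
// time through collectPartitions and stop once `num` elements are gathered.
// `javaRdd` and `num` are placeholder names.
def takeViaCollectPartitions[T](javaRdd: JavaRDD[T], num: Int): Seq[T] = {
  val items = scala.collection.mutable.ArrayBuffer.empty[T]
  var partition = 0
  while (items.size < num && partition < javaRdd.splits().size()) {
    // Mirrors rdd.py's one-element partition-id array.
    items ++= javaRdd.collectPartitions(Array(partition))(0)
    partition += 1
  }
  items.take(num)
}
```

Fetching one partition per job keeps `take` from computing the whole RDD, at the cost of one job per scanned partition; the TODO in the diff notes the planned improvement of estimating how many partitions to scan at once.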