aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorReynold Xin <rxin@apache.org>2013-12-19 13:35:09 -0800
committerReynold Xin <rxin@apache.org>2013-12-19 13:35:09 -0800
commit7990c5637519ae2def30dfba19b7c83562c0ec00 (patch)
tree0c645fee598e3a9da68e5e8bb88bcafef651f04b /core
parent440e531a5e7720c42f0c53ce98425b63b4194b7b (diff)
parent9cc3a6d3c0a64b80af77ae358c58d4b29b18c534 (diff)
downloadspark-7990c5637519ae2def30dfba19b7c83562c0ec00.tar.gz
spark-7990c5637519ae2def30dfba19b7c83562c0ec00.tar.bz2
spark-7990c5637519ae2def30dfba19b7c83562c0ec00.zip
Merge pull request #276 from shivaram/collectPartition
Add collectPartition to JavaRDD interface. This interface is useful for implementing `take` from other language frontends where the data is serialized. Also remove `takePartition` from PythonRDD and use `collectPartition` in rdd.py. Thanks @concretevitamin for the original change and tests.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/JavaAPISuite.java33
3 files changed, 44 insertions, 4 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 9e912d3adb..f344804b4c 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -245,6 +245,17 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
+ * Return an array that contains all of the elements in a specific partition of this RDD.
+ */
+ def collectPartitions(partitionIds: Array[Int]): Array[JList[T]] = {
+ // This is useful for implementing `take` from other language frontends
+ // like Python where the data is serialized.
+ import scala.collection.JavaConversions._
+ val res = context.runJob(rdd, (it: Iterator[T]) => it.toArray, partitionIds, true)
+ res.map(x => new java.util.ArrayList(x.toSeq)).toArray
+ }
+
+ /**
* Reduces the elements of this RDD using the specified commutative and associative binary operator.
*/
def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index a659cc06c2..ca42c76928 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -235,10 +235,6 @@ private[spark] object PythonRDD {
file.close()
}
- def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
- implicit val cm : ClassTag[T] = rdd.elementClassTag
- rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
- }
}
private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 4234f6eac7..79913dc718 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -897,4 +897,37 @@ public class JavaAPISuite implements Serializable {
new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
}
+
+ @Test
+ public void collectPartitions() {
+ JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
+
+ JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+ @Override
+ public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+ return new Tuple2<Integer, Integer>(i, i % 2);
+ }
+ });
+
+ List[] parts = rdd1.collectPartitions(new int[] {0});
+ Assert.assertEquals(Arrays.asList(1, 2), parts[0]);
+
+ parts = rdd1.collectPartitions(new int[] {1, 2});
+ Assert.assertEquals(Arrays.asList(3, 4), parts[0]);
+ Assert.assertEquals(Arrays.asList(5, 6, 7), parts[1]);
+
+ Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
+ new Tuple2<Integer, Integer>(2, 0)),
+ rdd2.collectPartitions(new int[] {0})[0]);
+
+ parts = rdd2.collectPartitions(new int[] {1, 2});
+ Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
+ new Tuple2<Integer, Integer>(4, 0)),
+ parts[0]);
+ Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
+ new Tuple2<Integer, Integer>(6, 0),
+ new Tuple2<Integer, Integer>(7, 1)),
+ parts[1]);
+ }
+
}