author     Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-12-18 11:40:07 -0800
committer  Shivaram Venkataraman <shivaram@eecs.berkeley.edu>  2013-12-18 11:40:07 -0800
commit     af0cd6bd27dda73b326bcb6a66addceadebf5e54 (patch)
tree       e80a49fcc5e8c11325d580a47d0b75658c44afc6
parent     7a8169be9a0b6b3d0d53a98aa38940d47b201296 (diff)
Add collectPartition to JavaRDD interface.
Also remove takePartition from PythonRDD and use collectPartition in rdd.py.
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala  11
-rw-r--r--  core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala   4
-rw-r--r--  core/src/test/scala/org/apache/spark/JavaAPISuite.java           28
-rw-r--r--  python/pyspark/context.py                                          3
-rw-r--r--  python/pyspark/rdd.py                                              2
5 files changed, 39 insertions(+), 9 deletions(-)
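In short, collectPartition(partitionId) returns the contents of a single partition as a java.util.List without computing any of the other partitions. A minimal usage sketch from Scala via the Java API wrapper (assuming a live SparkContext `sc`; the expected values mirror the JavaAPISuite test in this patch):

    import org.apache.spark.api.java.JavaRDD

    // Wrap a 3-partition RDD in the Java API and fetch one partition at a time.
    val javaRdd = JavaRDD.fromRDD(sc.parallelize(1 to 7, 3))
    javaRdd.collectPartition(0)  // [1, 2]
    javaRdd.collectPartition(1)  // [3, 4]
    javaRdd.collectPartition(2)  // [5, 6, 7]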
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 9e912d3adb..1d71875ed1 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -26,7 +26,7 @@ import com.google.common.base.Optional
 import org.apache.hadoop.io.compress.CompressionCodec
 
 import org.apache.spark.{SparkContext, Partition, TaskContext}
-import org.apache.spark.rdd.RDD
+import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
 import org.apache.spark.api.java.JavaPairRDD._
 import org.apache.spark.api.java.function.{Function2 => JFunction2, Function => JFunction, _}
 import org.apache.spark.partial.{PartialResult, BoundedDouble}
@@ -245,6 +245,15 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
   }
 
   /**
+   * Return an array that contains all of the elements in a specific partition of this RDD.
+   */
+  def collectPartition(partitionId: Int): JList[T] = {
+    import scala.collection.JavaConversions._
+    val partition = new PartitionPruningRDD[T](rdd, _ == partitionId)
+    new java.util.ArrayList(partition.collect().toSeq)
+  }
+
+  /**
    * Reduces the elements of this RDD using the specified commutative and associative binary operator.
    */
   def reduce(f: JFunction2[T, T, T]): T = rdd.reduce(f)
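The new method wraps the underlying RDD in a PartitionPruningRDD whose predicate keeps only the requested partition index, so the subsequent collect() schedules a task for just that partition. A minimal sketch of the same idea on a plain Scala RDD (assuming a live SparkContext `sc`):

    import org.apache.spark.rdd.PartitionPruningRDD

    val rdd = sc.parallelize(1 to 7, 3)
    // Keep only partition 2; collect() now runs a single task.
    val pruned = new PartitionPruningRDD[Int](rdd, _ == 2)
    pruned.collect()  // Array(5, 6, 7)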
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index a659cc06c2..ca42c76928 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -235,10 +235,6 @@ private[spark] object PythonRDD {
     file.close()
   }
 
-  def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
-    implicit val cm : ClassTag[T] = rdd.elementClassTag
-    rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
-  }
 }
 
 private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
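For comparison, the helper removed above reached the same result by running a job directly over one partition with SparkContext.runJob; collectPartition exposes that capability through the public Java API instead. A sketch of the old approach (assuming a live SparkContext `sc`; the trailing `true` is the allowLocal flag in the runJob signature of this era):

    // Run a job over partition 2 only and keep its elements.
    val rdd = sc.parallelize(1 to 7, 3)
    val part2: Array[Int] =
      sc.runJob(rdd, (it: Iterator[Int]) => it.toArray, Seq(2), true).head
    // part2: Array(5, 6, 7)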
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 4234f6eac7..2862ed3019 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -897,4 +897,32 @@ public class JavaAPISuite implements Serializable {
       new Tuple2<Integer, Integer>(0, 4)), rdd3.collect());
   }
+
+  @Test
+  public void collectPartition() {
+    JavaRDD<Integer> rdd1 = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5, 6, 7), 3);
+
+    JavaPairRDD<Integer, Integer> rdd2 = rdd1.map(new PairFunction<Integer, Integer, Integer>() {
+      @Override
+      public Tuple2<Integer, Integer> call(Integer i) throws Exception {
+        return new Tuple2<Integer, Integer>(i, i % 2);
+      }
+    });
+
+    Assert.assertEquals(Arrays.asList(1, 2), rdd1.collectPartition(0));
+    Assert.assertEquals(Arrays.asList(3, 4), rdd1.collectPartition(1));
+    Assert.assertEquals(Arrays.asList(5, 6, 7), rdd1.collectPartition(2));
+
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(1, 1),
+                                      new Tuple2<Integer, Integer>(2, 0)),
+                        rdd2.collectPartition(0));
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(3, 1),
+                                      new Tuple2<Integer, Integer>(4, 0)),
+                        rdd2.collectPartition(1));
+    Assert.assertEquals(Arrays.asList(new Tuple2<Integer, Integer>(5, 1),
+                                      new Tuple2<Integer, Integer>(6, 0),
+                                      new Tuple2<Integer, Integer>(7, 1)),
+                        rdd2.collectPartition(2));
+  }
+
 }
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index cbd41e58c4..0604f6836c 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -43,7 +43,6 @@ class SparkContext(object):
     _gateway = None
     _jvm = None
     _writeToFile = None
-    _takePartition = None
     _next_accum_id = 0
     _active_spark_context = None
     _lock = Lock()
@@ -134,8 +133,6 @@ class SparkContext(object):
         SparkContext._jvm = SparkContext._gateway.jvm
         SparkContext._writeToFile = \
             SparkContext._jvm.PythonRDD.writeToFile
-        SparkContext._takePartition = \
-            SparkContext._jvm.PythonRDD.takePartition
 
         if instance:
             if SparkContext._active_spark_context and SparkContext._active_spark_context != instance:
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 61720dcf1a..d81b7c90c1 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -577,7 +577,7 @@ class RDD(object):
         mapped = self.mapPartitions(takeUpToNum)
         items = []
         for partition in range(mapped._jrdd.splits().size()):
-            iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
+            iterator = mapped._jrdd.collectPartition(partition).iterator()
             items.extend(mapped._collect_iterator_through_file(iterator))
            if len(items) >= num:
                break
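With this change, take() pulls partitions through the Java API's collectPartition one at a time and stops as soon as enough elements have been gathered. The same loop, transcribed to Scala for illustration (a sketch with hypothetical names, not code from this patch):

    import scala.collection.JavaConversions._
    import org.apache.spark.api.java.JavaRDD

    def takeViaCollectPartition[T](rdd: JavaRDD[T], num: Int): Seq[T] = {
      val items = scala.collection.mutable.ArrayBuffer.empty[T]
      var p = 0
      // One job per partition, in order, until we have enough elements.
      while (items.size < num && p < rdd.splits.size) {
        items ++= rdd.collectPartition(p)
        p += 1
      }
      items.take(num)
    }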