author    Josh Rosen <rosenville@gmail.com>  2012-07-26 12:04:18 -0700
committer Josh Rosen <rosenville@gmail.com>  2012-07-26 12:46:47 -0700
commit    c5e2810dc75ac0bae94b1e50ff8c0d198d185b52
tree      766d3f24b3659e89d636e4bbab363ecd9999bedf
parent    2a60c998cc9ca8b9a90d8c7865f7494963395b3b
Add persist(), splits(), glom(), and mapPartitions() to Java API.
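
For orientation, a minimal sketch of how the new methods look from the Java side. This is a hypothetical driver program, not part of the commit; it assumes a local JavaSparkContext and uses only calls that appear in this patch or its tests.

    import java.util.Arrays;
    import java.util.List;

    import spark.Split;
    import spark.api.java.JavaRDD;
    import spark.api.java.JavaSparkContext;
    import spark.storage.StorageLevel;

    public class NewJavaApiExample {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "example");
        JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);

        // persist() pins the RDD at an explicit storage level;
        // cache() remains the memory-only shorthand.
        rdd = rdd.persist(StorageLevel.DISK_ONLY());

        // splits() exposes the partitions backing the RDD.
        List<Split> splits = rdd.splits();
        System.out.println(splits.size());     // 2

        // glom() collapses each partition into a single List element.
        JavaRDD<List<Integer>> chunks = rdd.glom();
        System.out.println(chunks.collect());  // [[1, 2], [3, 4]]

        sc.stop();
      }
    }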
 core/src/main/scala/spark/api/java/JavaDoubleRDD.scala |  3
 core/src/main/scala/spark/api/java/JavaPairRDD.scala   |  5
 core/src/main/scala/spark/api/java/JavaRDD.scala       |  3
 core/src/main/scala/spark/api/java/JavaRDDLike.scala   | 25
 core/src/test/scala/spark/JavaAPISuite.java            | 50
 5 files changed, 83 insertions(+), 3 deletions(-)
diff --git a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
index 9a90d0af79..7c0b17c45e 100644
--- a/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaDoubleRDD.scala
@@ -5,6 +5,7 @@ import spark.SparkContext.doubleRDDToDoubleRDDFunctions
import spark.api.java.function.{Function => JFunction}
import spark.util.StatCounter
import spark.partial.{BoundedDouble, PartialResult}
+import spark.storage.StorageLevel
import java.lang.Double
@@ -23,6 +24,8 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
+ def persist(newLevel: StorageLevel): JavaDoubleRDD = fromRDD(srdd.persist(newLevel))
+
// first() has to be overridden here in order for its return type to be Double instead of Object.
override def first(): Double = srdd.first()
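
persist() is re-declared on each concrete wrapper rather than once in JavaRDDLike so that it can return the wrapper's own type and keep call chains fully typed. A hypothetical illustration (assumes a local JavaSparkContext; not part of this commit):

    import java.util.Arrays;

    import spark.api.java.JavaDoubleRDD;
    import spark.api.java.JavaSparkContext;
    import spark.storage.StorageLevel;

    public class PersistTyping {
      public static void main(String[] args) {
        JavaSparkContext sc = new JavaSparkContext("local", "persist-typing");
        JavaDoubleRDD doubles = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0));
        // Because persist() returns JavaDoubleRDD rather than a generic
        // JavaRDDLike, double-specific methods such as sum() stay available:
        double total = doubles.persist(StorageLevel.DISK_ONLY()).sum();
        System.out.println(total);  // 6.0
        sc.stop();
      }
    }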
diff --git a/core/src/main/scala/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
index 99d1b1e208..c28a13b061 100644
--- a/core/src/main/scala/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaPairRDD.scala
@@ -5,13 +5,13 @@ import spark.api.java.function.{Function2 => JFunction2}
import spark.api.java.function.{Function => JFunction}
import spark.partial.BoundedDouble
import spark.partial.PartialResult
+import spark.storage.StorageLevel
import spark._
import java.util.{List => JList}
import java.util.Comparator
import scala.Tuple2
-import scala.collection.Map
import scala.collection.JavaConversions._
import org.apache.hadoop.mapred.JobConf
@@ -33,6 +33,9 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
+ def persist(newLevel: StorageLevel): JavaPairRDD[K, V] =
+ new JavaPairRDD[K, V](rdd.persist(newLevel))
+
// Transformations (return a new RDD)
def distinct(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.distinct())
diff --git a/core/src/main/scala/spark/api/java/JavaRDD.scala b/core/src/main/scala/spark/api/java/JavaRDD.scala
index 598d4cf15b..541aa1e60b 100644
--- a/core/src/main/scala/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDD.scala
@@ -2,6 +2,7 @@ package spark.api.java
import spark._
import spark.api.java.function.{Function => JFunction}
+import spark.storage.StorageLevel
class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends
JavaRDDLike[T, JavaRDD[T]] {
@@ -12,6 +13,8 @@ JavaRDDLike[T, JavaRDD[T]] {
def cache(): JavaRDD[T] = wrapRDD(rdd.cache())
+ def persist(newLevel: StorageLevel): JavaRDD[T] = wrapRDD(rdd.persist(newLevel))
+
// Transformations (return a new RDD)
def distinct(): JavaRDD[T] = wrapRDD(rdd.distinct())
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 1c6948eb7f..785dd96394 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -9,7 +9,7 @@ import spark.storage.StorageLevel
import java.util.{List => JList}
import scala.collection.JavaConversions._
-import java.lang
+import java.{util, lang}
import scala.Tuple2
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
@@ -19,6 +19,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def rdd: RDD[T]
+ def splits: JList[Split] = new java.util.ArrayList(rdd.splits.toSeq)
+
def context: SparkContext = rdd.context
def id: Int = rdd.id
@@ -56,9 +58,28 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
- new JavaPairRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
+ JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
+ }
+
+ def mapPartitions[U](f: FlatMapFunction[java.util.Iterator[T], U]): JavaRDD[U] = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ JavaRDD.fromRDD(rdd.mapPartitions(fn)(f.elementType()))(f.elementType())
+ }
+
+ def mapPartitions(f: DoubleFlatMapFunction[java.util.Iterator[T]]): JavaDoubleRDD = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ new JavaDoubleRDD(rdd.mapPartitions(fn).map((x: java.lang.Double) => x.doubleValue()))
+ }
+
+ def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]):
+ JavaPairRDD[K, V] = {
+ def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
+ JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
}
+ def glom(): JavaRDD[JList[T]] =
+ new JavaRDD(rdd.glom().map(x => new java.util.ArrayList[T](x.toSeq)))
+
def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
other.classManifest)
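
Java callers cannot supply the implicit ClassManifests that let Scala select a specialized RDD, so mapPartitions dispatches on the function subclass instead: FlatMapFunction yields a JavaRDD, DoubleFlatMapFunction a JavaDoubleRDD, and PairFlatMapFunction a JavaPairRDD. The test suite below exercises only the first overload; here is a hypothetical sketch of the pair variant, assuming a JavaRDD<String> of text lines:

    import java.util.ArrayList;
    import java.util.Iterator;
    import java.util.List;

    import scala.Tuple2;
    import spark.api.java.JavaPairRDD;
    import spark.api.java.JavaRDD;
    import spark.api.java.function.PairFlatMapFunction;

    class WordPairs {
      // Emits one (word, 1) pair per whitespace-separated token, consuming
      // each partition's iterator in a single pass.
      static JavaPairRDD<String, Integer> wordPairs(JavaRDD<String> lines) {
        return lines.mapPartitions(
          new PairFlatMapFunction<Iterator<String>, String, Integer>() {
            @Override
            public Iterable<Tuple2<String, Integer>> apply(Iterator<String> iter) {
              List<Tuple2<String, Integer>> out = new ArrayList<Tuple2<String, Integer>>();
              while (iter.hasNext()) {
                for (String word : iter.next().split(" ")) {
                  out.add(new Tuple2<String, Integer>(word, 1));
                }
              }
              return out;
            }
          });
      }
    }

Because the function receives an entire partition's iterator at once, any per-partition setup happens once per partition rather than once per element.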
diff --git a/core/src/test/scala/spark/JavaAPISuite.java b/core/src/test/scala/spark/JavaAPISuite.java
index f6c0e539e6..436a8ab0c7 100644
--- a/core/src/test/scala/spark/JavaAPISuite.java
+++ b/core/src/test/scala/spark/JavaAPISuite.java
@@ -21,6 +21,7 @@ import spark.api.java.JavaSparkContext;
import spark.api.java.function.*;
import spark.partial.BoundedDouble;
import spark.partial.PartialResult;
+import spark.storage.StorageLevel;
import spark.util.StatCounter;
import java.io.File;
@@ -337,6 +338,55 @@ public class JavaAPISuite implements Serializable {
Assert.assertEquals(11, pairs.count());
}
+ @Test
+ public void mapPartitions() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+ JavaRDD<Integer> partitionSums = rdd.mapPartitions(
+ new FlatMapFunction<Iterator<Integer>, Integer>() {
+ @Override
+ public Iterable<Integer> apply(Iterator<Integer> iter) {
+ int sum = 0;
+ while (iter.hasNext()) {
+ sum += iter.next();
+ }
+ return Collections.singletonList(sum);
+ }
+ });
+ Assert.assertEquals("[3, 7]", partitionSums.collect().toString());
+ }
+
+ @Test
+ public void persist() {
+ JavaDoubleRDD doubleRDD = sc.parallelizeDoubles(Arrays.asList(1.0, 1.0, 2.0, 3.0, 5.0, 8.0));
+ doubleRDD = doubleRDD.persist(StorageLevel.DISK_ONLY());
+ Assert.assertEquals(20, doubleRDD.sum(), 0.1);
+
+ List<Tuple2<Integer, String>> pairs = Arrays.asList(
+ new Tuple2<Integer, String>(1, "a"),
+ new Tuple2<Integer, String>(2, "aa"),
+ new Tuple2<Integer, String>(3, "aaa")
+ );
+ JavaPairRDD<Integer, String> pairRDD = sc.parallelizePairs(pairs);
+ pairRDD = pairRDD.persist(StorageLevel.DISK_ONLY());
+ Assert.assertEquals("a", pairRDD.first()._2());
+
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
+ rdd = rdd.persist(StorageLevel.DISK_ONLY());
+ Assert.assertEquals(1, rdd.first().intValue());
+ }
+
+ @Test
+ public void iterator() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5), 2);
+ Assert.assertEquals(1, rdd.iterator(rdd.splits().get(0)).next().intValue());
+ }
+
+ @Test
+ public void glom() {
+ JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4), 2);
+ Assert.assertEquals("[1, 2]", rdd.glom().first().toString());
+ }
+
// File input / output tests are largely adapted from FileSuite:
@Test