aboutsummaryrefslogtreecommitdiff
path: root/core/src/main/scala/spark/api/java/JavaRDDLike.scala
diff options
context:
space:
mode:
Diffstat (limited to 'core/src/main/scala/spark/api/java/JavaRDDLike.scala')
-rw-r--r--core/src/main/scala/spark/api/java/JavaRDDLike.scala13
1 files changed, 7 insertions, 6 deletions
diff --git a/core/src/main/scala/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
index 90b45cf875..d884529d7a 100644
--- a/core/src/main/scala/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/spark/api/java/JavaRDDLike.scala
@@ -12,7 +12,7 @@ import spark.storage.StorageLevel
import com.google.common.base.Optional
-trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround[T] {
+trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
implicit val classManifest: ClassManifest[T]
@@ -82,12 +82,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
}
/**
- * Part of the workaround for SPARK-668; called in PairFlatMapWorkaround.java.
+ * Return a new RDD by first applying a function to all elements of this
+ * RDD, and then flattening the results.
*/
- private[spark] def doFlatMap[K, V](f: PairFlatMapFunction[T, K, V]): JavaPairRDD[K, V] = {
+ def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+ def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
@@ -110,8 +111,8 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends PairFlatMapWorkaround
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[K, V](f: PairFlatMapFunction[java.util.Iterator[T], K, V]):
- JavaPairRDD[K, V] = {
+ def mapPartitions[K2, V2](f: PairFlatMapFunction[java.util.Iterator[T], K2, V2]):
+ JavaPairRDD[K2, V2] = {
def fn = (x: Iterator[T]) => asScalaIterator(f.apply(asJavaIterator(x)).iterator())
JavaPairRDD.fromRDD(rdd.mapPartitions(fn))(f.keyType(), f.valueType())
}