-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDD.scala      | 17
-rw-r--r--  core/src/test/java/org/apache/spark/JavaAPISuite.java   | 17
-rw-r--r--  core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala |  8
3 files changed, 42 insertions, 0 deletions
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 74ac97091f..e1c49e35ab 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -1236,6 +1236,23 @@ abstract class RDD[T: ClassTag](
   /** The [[org.apache.spark.SparkContext]] that this RDD was created on. */
   def context = sc
+  /**
+   * Private API for changing an RDD's ClassTag.
+   * Used for internal Java <-> Scala API compatibility.
+   */
+  private[spark] def retag(cls: Class[T]): RDD[T] = {
+    val classTag: ClassTag[T] = ClassTag.apply(cls)
+    this.retag(classTag)
+  }
+
+  /**
+   * Private API for changing an RDD's ClassTag.
+   * Used for internal Java <-> Scala API compatibility.
+   */
+  private[spark] def retag(implicit classTag: ClassTag[T]): RDD[T] = {
+    this.mapPartitions(identity, preservesPartitioning = true)(classTag)
+  }
+
   // Avoid handling doCheckpoint multiple times to prevent excessive recursion
   @transient private var doCheckpointCalled = false
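
Why the ClassTag needs correcting: RDDs created through the Java API carry a ClassTag for Object (Java callers cannot supply a real one), so collect() materializes an Array[AnyRef] rather than an array of the actual element type. The sketch below spells out the same identity-mapPartitions trick used by the retag(cls) overload above; retagLike and Widget are illustrative names, not part of the patch.

    import scala.reflect.ClassTag
    import org.apache.spark.rdd.RDD

    // Build a ClassTag from a runtime Class and re-apply it with an identity
    // mapPartitions pass: data and partitioning are untouched, only the tag
    // that drives array allocation in collect() changes.
    def retagLike[T](rdd: RDD[T], cls: Class[T]): RDD[T] = {
      implicit val ct: ClassTag[T] = ClassTag(cls)
      rdd.mapPartitions(identity, preservesPartitioning = true)
    }

    // e.g. for some rdd: RDD[Widget] whose tag was erased to AnyRef:
    //   retagLike(rdd, classOf[Widget]).collect()   // yields an Array[Widget]
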
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index e8bd65f8e4..fab64a54e2 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1245,4 +1245,21 @@ public class JavaAPISuite implements Serializable {
     Assert.assertTrue(worExactCounts.get(0) == 2);
     Assert.assertTrue(worExactCounts.get(1) == 4);
   }
+
+  private static class SomeCustomClass implements Serializable {
+    public SomeCustomClass() {
+      // Intentionally left blank
+    }
+  }
+
+  @Test
+  public void collectUnderlyingScalaRDD() {
+    List<SomeCustomClass> data = new ArrayList<SomeCustomClass>();
+    for (int i = 0; i < 100; i++) {
+      data.add(new SomeCustomClass());
+    }
+    JavaRDD<SomeCustomClass> rdd = sc.parallelize(data);
+    SomeCustomClass[] collected = (SomeCustomClass[]) rdd.rdd().retag(SomeCustomClass.class).collect();
+    Assert.assertEquals(data.size(), collected.length);
+  }
 }
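
What the new Java test exercises: without the retag call, rdd.rdd().collect() returns an Object[] (the underlying Scala RDD was created with a fake ClassTag for Object), so the cast to SomeCustomClass[] would fail with a ClassCastException at runtime; retag(SomeCustomClass.class) makes collect() allocate a SomeCustomClass[] directly. A rough Scala sketch of the same situation follows; it assumes a live SparkContext named sc and, because retag is private[spark], code that lives in an org.apache.spark package such as a test suite.

    import scala.collection.JavaConverters._
    import org.apache.spark.api.java.JavaSparkContext

    val jsc = new JavaSparkContext(sc)
    val jrdd = jsc.parallelize(Seq("a", "b").asJava)
    jrdd.classTag                               // effectively a ClassTag for Object: the Java API cannot supply a real one
    jrdd.rdd.retag(classOf[String]).collect()   // the retagged RDD builds a String[] when collecting
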
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index ae6e525875..b31e3a09e5 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -18,6 +18,7 @@
 package org.apache.spark.rdd
 import scala.collection.mutable.{ArrayBuffer, HashMap}
+import scala.collection.JavaConverters._
 import scala.reflect.ClassTag
 import org.scalatest.FunSuite
@@ -26,6 +27,7 @@ import org.apache.spark._
 import org.apache.spark.SparkContext._
 import org.apache.spark.util.Utils
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
 import org.apache.spark.rdd.RDDSuiteUtils._
 class RDDSuite extends FunSuite with SharedSparkContext {
@@ -718,6 +720,12 @@ class RDDSuite extends FunSuite with SharedSparkContext {
     assert(ids.length === n)
   }
+  test("retag with implicit ClassTag") {
+    val jsc: JavaSparkContext = new JavaSparkContext(sc)
+    val jrdd: JavaRDD[String] = jsc.parallelize(Seq("A", "B", "C").asJava)
+    jrdd.rdd.retag.collect()
+  }
+
   test("getNarrowAncestors") {
     val rdd1 = sc.parallelize(1 to 100, 4)
     val rdd2 = rdd1.filter(_ % 2 == 0).map(_ + 1)
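
The new RDDSuite test uses the implicit overload: in jrdd.rdd.retag.collect() the compiler materializes the ClassTag[String] at the call site, where the element type is statically known, even though the underlying RDD was built by the Java API with a fake tag. Spelled out by hand (again only valid from code inside an org.apache.spark package, since retag is private[spark]), the call is equivalent to the sketch below, with jrdd being the JavaRDD from the test.

    import scala.reflect.ClassTag

    // The implicit ClassTag[String] written out as an explicit argument:
    jrdd.rdd.retag(ClassTag(classOf[String])).collect()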