-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala  12
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala   7
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala    10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala   3
4 files changed, 23 insertions, 9 deletions
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 0846225e4f..c38b96528d 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -35,6 +35,7 @@ import org.apache.spark.Partitioner._
import org.apache.spark.SparkContext.rddToPairRDDFunctions
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, PairFunction}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.{OrderedRDDFunctions, RDD}
@@ -265,10 +266,10 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
* before sending results to a reducer, similarly to a "combiner" in MapReduce.
*/
def reduceByKeyLocally(func: JFunction2[V, V, V]): java.util.Map[K, V] =
- mapAsJavaMap(rdd.reduceByKeyLocally(func))
+ mapAsSerializableJavaMap(rdd.reduceByKeyLocally(func))
/** Count the number of elements for each key, and return the result to the master as a Map. */
- def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
+ def countByKey(): java.util.Map[K, Long] = mapAsSerializableJavaMap(rdd.countByKey())
/**
* :: Experimental ::
@@ -277,7 +278,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
*/
@Experimental
def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
- rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
+ rdd.countByKeyApprox(timeout).map(mapAsSerializableJavaMap)
/**
* :: Experimental ::
@@ -287,7 +288,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
@Experimental
def countByKeyApprox(timeout: Long, confidence: Double = 0.95)
: PartialResult[java.util.Map[K, BoundedDouble]] =
- rdd.countByKeyApprox(timeout, confidence).map(mapAsJavaMap)
+ rdd.countByKeyApprox(timeout, confidence).map(mapAsSerializableJavaMap)
/**
* Aggregate the values of each key, using given combine functions and a neutral "zero value".
@@ -614,7 +615,8 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])
/**
* Return the key-value pairs in this RDD to the master as a Map.
*/
- def collectAsMap(): java.util.Map[K, V] = mapAsJavaMap(rdd.collectAsMap())
+ def collectAsMap(): java.util.Map[K, V] = mapAsSerializableJavaMap(rdd.collectAsMap())
+
/**
* Pass each value in the key-value pair RDD through a map function without changing the keys;
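// --- Illustrative sketch (not part of the patch) --------------------------------
// SPARK-3926: before this change, Java-API methods such as collectAsMap() and
// countByKey() handed their result back through JavaConversions.mapAsJavaMap,
// i.e. as a scala.collection.convert.Wrappers$MapWrapper, which does not
// implement java.io.Serializable on Scala builds still affected by SI-8911, so
// serializing the returned map failed. The object below is hypothetical and uses
// only the JDK and the Scala standard library to reproduce that symptom.
import java.io.{ByteArrayOutputStream, ObjectOutputStream}
import scala.collection.JavaConversions.mapAsJavaMap

object NonSerializableMapSymptom {
  def main(args: Array[String]): Unit = {
    // Stand-in for the java.util.Map a Java caller used to get from collectAsMap().
    val javaMap: java.util.Map[String, Int] = mapAsJavaMap(Map("a" -> 1, "b" -> 2))

    val oos = new ObjectOutputStream(new ByteArrayOutputStream())
    try {
      // Throws java.io.NotSerializableException on Scala builds hit by SI-8911.
      oos.writeObject(javaMap)
      println("serialized fine: this Scala build already fixes SI-8911")
    } catch {
      case e: java.io.NotSerializableException => println(s"reproduced SPARK-3926: $e")
    } finally {
      oos.close()
    }
  }
}
// ---------------------------------------------------------------------------------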
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 545bc0e9e9..c744399483 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -30,6 +30,7 @@ import org.apache.spark.{FutureAction, Partition, SparkContext, TaskContext}
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaPairRDD._
import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2, _}
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.rdd.RDD
@@ -390,7 +391,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): java.util.Map[T, java.lang.Long] =
- mapAsJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
+ mapAsSerializableJavaMap(rdd.countByValue().map((x => (x._1, new java.lang.Long(x._2)))))
/**
* (Experimental) Approximate version of countByValue().
@@ -399,13 +400,13 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
timeout: Long,
confidence: Double
): PartialResult[java.util.Map[T, BoundedDouble]] =
- rdd.countByValueApprox(timeout, confidence).map(mapAsJavaMap)
+ rdd.countByValueApprox(timeout, confidence).map(mapAsSerializableJavaMap)
/**
* (Experimental) Approximate version of countByValue().
*/
def countByValueApprox(timeout: Long): PartialResult[java.util.Map[T, BoundedDouble]] =
- rdd.countByValueApprox(timeout).map(mapAsJavaMap)
+ rdd.countByValueApprox(timeout).map(mapAsSerializableJavaMap)
/**
* Take the first num elements of the RDD. This currently scans the partitions *one by one*, so
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index 22810cb1c6..b52d0a5028 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,10 +19,20 @@ package org.apache.spark.api.java
import com.google.common.base.Optional
+import scala.collection.convert.Wrappers.MapWrapper
+
private[spark] object JavaUtils {
def optionToOptional[T](option: Option[T]): Optional[T] =
option match {
case Some(value) => Optional.of(value)
case None => Optional.absent()
}
+
+ // Workaround for SPARK-3926 / SI-8911
+ def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
+ new SerializableMapWrapper(underlying)
+
+ class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
+ extends MapWrapper(underlying) with java.io.Serializable
+
}
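// --- Illustrative sketch (not part of the patch) --------------------------------
// Intended use of the new helper: wherever Spark's Java API hands a Scala map to
// Java callers, mapAsSerializableJavaMap wraps it in a MapWrapper subclass that
// also mixes in java.io.Serializable, so the returned java.util.Map carries the
// marker interface that java.io.ObjectOutputStream checks before writing an
// object. JavaUtils is private[spark], so a caller such as the hypothetical
// object below has to live under the org.apache.spark package.
package org.apache.spark.sketch   // hypothetical subpackage, only to satisfy private[spark]

import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap

object SerializableJavaMaps {
  def toJava[K, V](m: scala.collection.Map[K, V]): java.util.Map[K, V] = {
    val wrapped = mapAsSerializableJavaMap(m)
    // Reads still delegate to the underlying Scala map, exactly as MapWrapper does...
    assert(wrapped.size == m.size)
    // ...while the wrapper class additionally implements java.io.Serializable.
    assert(wrapped.isInstanceOf[java.io.Serializable])
    wrapped
  }
}
// ---------------------------------------------------------------------------------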
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
index e9d04ce7aa..df01411f60 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/api/java/Row.scala
@@ -22,6 +22,7 @@ import scala.collection.convert.Wrappers.{JListWrapper, JMapWrapper}
import scala.collection.JavaConversions
import scala.math.BigDecimal
+import org.apache.spark.api.java.JavaUtils.mapAsSerializableJavaMap
import org.apache.spark.sql.catalyst.expressions.{Row => ScalaRow}
/**
@@ -114,7 +115,7 @@ object Row {
// they are actually accessed.
case row: ScalaRow => new Row(row)
case map: scala.collection.Map[_, _] =>
- JavaConversions.mapAsJavaMap(
+ mapAsSerializableJavaMap(
map.map {
case (key, value) => (toJavaValue(key), toJavaValue(value))
}