aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--LICENSE3
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala62
-rw-r--r--core/src/test/java/org/apache/spark/JavaAPISuite.java13
3 files changed, 75 insertions, 3 deletions
diff --git a/LICENSE b/LICENSE
index 3c667bf450..0a42d389e4 100644
--- a/LICENSE
+++ b/LICENSE
@@ -646,7 +646,8 @@ THE SOFTWARE.
========================================================================
For Scala Interpreter classes (all .scala files in repl/src/main/scala
-except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala):
+except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala),
+and for SerializableMapWrapper in JavaUtils.scala:
========================================================================
Copyright (c) 2002-2013 EPFL
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index b52d0a5028..86e9493130 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,7 +19,8 @@ package org.apache.spark.api.java
import com.google.common.base.Optional
-import scala.collection.convert.Wrappers.MapWrapper
+import java.{util => ju}
+import scala.collection.mutable
private[spark] object JavaUtils {
def optionToOptional[T](option: Option[T]): Optional[T] =
@@ -32,7 +33,64 @@ private[spark] object JavaUtils {
def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
new SerializableMapWrapper(underlying)
+ // Implementation is copied from scala.collection.convert.Wrappers.MapWrapper,
+ // but implements java.io.Serializable. It can't just be subclassed to make it
+ // Serializable since the MapWrapper class has no no-arg constructor. This class
+ // doesn't need a no-arg constructor though.
class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
- extends MapWrapper(underlying) with java.io.Serializable
+ extends ju.AbstractMap[A, B] with java.io.Serializable { self =>
+ override def size = underlying.size
+
+ override def get(key: AnyRef): B = try {
+ underlying get key.asInstanceOf[A] match {
+ case None => null.asInstanceOf[B]
+ case Some(v) => v
+ }
+ } catch {
+ case ex: ClassCastException => null.asInstanceOf[B]
+ }
+
+ override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] {
+ def size = self.size
+
+ def iterator = new ju.Iterator[ju.Map.Entry[A, B]] {
+ val ui = underlying.iterator
+ var prev : Option[A] = None
+
+ def hasNext = ui.hasNext
+
+ def next() = {
+ val (k, v) = ui.next
+ prev = Some(k)
+ new ju.Map.Entry[A, B] {
+ import scala.util.hashing.byteswap32
+ def getKey = k
+ def getValue = v
+ def setValue(v1 : B) = self.put(k, v1)
+ override def hashCode = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16)
+ override def equals(other: Any) = other match {
+ case e: ju.Map.Entry[_, _] => k == e.getKey && v == e.getValue
+ case _ => false
+ }
+ }
+ }
+
+ def remove() {
+ prev match {
+ case Some(k) =>
+ underlying match {
+ case mm: mutable.Map[a, _] =>
+ mm remove k
+ prev = None
+ case _ =>
+ throw new UnsupportedOperationException("remove")
+ }
+ case _ =>
+ throw new IllegalStateException("next must be called at least once before remove")
+ }
+ }
+ }
+ }
+ }
}
diff --git a/core/src/test/java/org/apache/spark/JavaAPISuite.java b/core/src/test/java/org/apache/spark/JavaAPISuite.java
index 3ad4f2f193..e5bdad6bda 100644
--- a/core/src/test/java/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/java/org/apache/spark/JavaAPISuite.java
@@ -1357,6 +1357,19 @@ public class JavaAPISuite implements Serializable {
pairRDD.collectAsMap(); // Used to crash with ClassCastException
}
+ @SuppressWarnings("unchecked")
+ @Test
+ public void collectAsMapAndSerialize() throws Exception {
+ JavaPairRDD<String,Integer> rdd =
+ sc.parallelizePairs(Arrays.asList(new Tuple2<String,Integer>("foo", 1)));
+ Map<String,Integer> map = rdd.collectAsMap();
+ ByteArrayOutputStream bytes = new ByteArrayOutputStream();
+ new ObjectOutputStream(bytes).writeObject(map);
+ Map<String,Integer> deserializedMap = (Map<String,Integer>)
+ new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray())).readObject();
+ Assert.assertEquals(1, deserializedMap.get("foo").intValue());
+ }
+
@Test
@SuppressWarnings("unchecked")
public void sampleByKey() {