author     Matei Zaharia <matei@databricks.com>    2014-05-10 12:10:24 -0700
committer  Patrick Wendell <pwendell@gmail.com>    2014-05-10 12:10:24 -0700
commit     7eefc9d2b3f6ebc0ecb5562da7323f1e06afbb35 (patch)
tree       32a35b0898c5710cdbd73b56ef9dcb9914e1cf02
parent     8e94d2721a9d3d36697e13f8cc6567ae8aeee78b (diff)
SPARK-1708. Add a ClassTag on Serializer and things that depend on it
This pull request contains a rebased patch from @heathermiller (https://github.com/heathermiller/spark/pull/1) to add ClassTags on Serializer and types that depend on it (Broadcast and AccumulableCollection). Putting these in the public API signatures now will allow us to use Scala Pickling for serialization down the line without breaking binary compatibility.

One question remaining is whether we also want them on Accumulator -- Accumulator is passed as part of a bigger Task or TaskResult object via the closure serializer, so it doesn't seem super useful to add the ClassTag there. Broadcast and AccumulableCollection, in contrast, were being serialized directly.

CC @rxin, @pwendell, @heathermiller

Author: Matei Zaharia <matei@databricks.com>

Closes #700 from mateiz/spark-1708 and squashes the following commits:

1a3d8b0 [Matei Zaharia] Use fake ClassTag in Java
3b449ed [Matei Zaharia] test fix
2209a27 [Matei Zaharia] Code style fixes
9d48830 [Matei Zaharia] Add a ClassTag on Serializer and things that depend on it
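For context, a minimal standalone sketch (not part of the patch) of what a ClassTag context bound buys a generic method: the callee can recover the runtime class of T, which is what pickling-style serializers need.

import scala.reflect.ClassTag

// Illustrative only: a ClassTag context bound makes T's runtime class
// available inside a generic method, unlike a plain type parameter [T].
def describe[T: ClassTag](value: T): String = {
  val cls = implicitly[ClassTag[T]].runtimeClass
  s"value of runtime class ${cls.getName}"
}

describe("hello")   // "value of runtime class java.lang.String"
describe(42)        // "value of runtime class int"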
-rw-r--r--  core/src/main/scala/org/apache/spark/Accumulators.scala                                  |  7
-rw-r--r--  core/src/main/scala/org/apache/spark/SparkContext.scala                                  |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala                     |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala                           |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala                    |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala                    |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala                       |  7
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala                |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala                    |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala             |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala                             |  4
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala                     |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala                         |  2
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala                     | 13
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala                     | 12
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/Serializer.scala                         | 17
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala                                    |  2
-rw-r--r--  core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala                | 11
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala | 12
-rw-r--r--  graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala                     | 45
-rw-r--r--  graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala                      |  5
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala          |  6
22 files changed, 103 insertions(+), 72 deletions(-)
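The JavaSparkContext change below passes a fakeClassTag so that Java callers, who cannot supply ClassTags, keep the same broadcast signature. As a hedged sketch (the helper's actual definition in the patch may differ), such a catch-all tag can be synthesized from ClassTag.AnyRef:

import scala.reflect.ClassTag

// Hypothetical sketch of a catch-all tag for Java interop; the real
// fakeClassTag helper used by this patch may be defined differently.
def fakeClassTag[T]: ClassTag[T] = ClassTag.AnyRef.asInstanceOf[ClassTag[T]]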
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 6d652faae1..cdfd338081 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -21,6 +21,7 @@ import java.io.{ObjectInputStream, Serializable}
import scala.collection.generic.Growable
import scala.collection.mutable.Map
+import scala.reflect.ClassTag
import org.apache.spark.serializer.JavaSerializer
@@ -164,9 +165,9 @@ trait AccumulableParam[R, T] extends Serializable {
def zero(initialValue: R): R
}
-private[spark]
-class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
- extends AccumulableParam[R,T] {
+private[spark] class
+GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
+ extends AccumulableParam[R, T] {
def addAccumulator(growable: R, elem: T): R = {
growable += elem
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 9d7c2c8d3d..c639b3e15d 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -756,7 +756,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
- def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
+ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
(initialValue: R): Accumulable[R, T] = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
@@ -767,7 +767,7 @@ class SparkContext(config: SparkConf) extends Logging {
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
- def broadcast[T](value: T): Broadcast[T] = {
+ def broadcast[T: ClassTag](value: T): Broadcast[T] = {
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 8b95cda511..a7cfee6d01 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -447,7 +447,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
- def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
+ def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)(fakeClassTag)
/** Shut down the SparkContext. */
def stop() {
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 738a3b1bed..76956f6a34 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -21,6 +21,8 @@ import java.io.Serializable
import org.apache.spark.SparkException
+import scala.reflect.ClassTag
+
/**
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
* cached on each machine rather than shipping a copy of it with tasks. They can be used, for
@@ -50,7 +52,7 @@ import org.apache.spark.SparkException
* @param id A unique identifier for the broadcast variable.
* @tparam T Type of the data contained in the broadcast variable.
*/
-abstract class Broadcast[T](val id: Long) extends Serializable {
+abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
/**
* Flag signifying whether the broadcast variable is valid
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 8c8ce9b169..a8c827030a 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -17,6 +17,8 @@
package org.apache.spark.broadcast
+import scala.reflect.ClassTag
+
import org.apache.spark.SecurityManager
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
@@ -31,7 +33,7 @@ import org.apache.spark.annotation.DeveloperApi
@DeveloperApi
trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
- def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+ def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index cf62aca4d4..c88be6aba6 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -19,6 +19,8 @@ package org.apache.spark.broadcast
import java.util.concurrent.atomic.AtomicLong
+import scala.reflect.ClassTag
+
import org.apache.spark._
private[spark] class BroadcastManager(
@@ -56,7 +58,7 @@ private[spark] class BroadcastManager(
private val nextBroadcastId = new AtomicLong(0)
- def newBroadcast[T](value_ : T, isLocal: Boolean) = {
+ def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = {
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 29372f16f2..78fc286e51 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -22,6 +22,8 @@ import java.io.{BufferedInputStream, BufferedOutputStream}
import java.net.{URL, URLConnection, URI}
import java.util.concurrent.TimeUnit
+import scala.reflect.ClassTag
+
import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
@@ -34,7 +36,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
* (through a HTTP server running at the driver) and stored in the BlockManager of the
* executor to speed up future accesses.
*/
-private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+private[spark] class HttpBroadcast[T: ClassTag](
+ @transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def getValue = value_
@@ -173,7 +176,7 @@ private[spark] object HttpBroadcast extends Logging {
files += file.getAbsolutePath
}
- def read[T](id: Long): T = {
+ def read[T: ClassTag](id: Long): T = {
logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
val url = serverUri + "/" + BroadcastBlockId(id).name
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
index e3f6cdc615..d5a031e2bb 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
@@ -17,6 +17,8 @@
package org.apache.spark.broadcast
+import scala.reflect.ClassTag
+
import org.apache.spark.{SecurityManager, SparkConf}
/**
@@ -29,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def stop() { HttpBroadcast.stop() }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 2659274c5e..734de37ba1 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.broadcast
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
+import scala.reflect.ClassTag
import scala.math
import scala.util.Random
@@ -44,7 +45,8 @@ import org.apache.spark.util.Utils
* copies of the broadcast data (one per executor) as done by the
* [[org.apache.spark.broadcast.HttpBroadcast]].
*/
-private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+private[spark] class TorrentBroadcast[T: ClassTag](
+ @transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def getValue = value_
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index d216b58718..1de8396a0e 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -17,6 +17,8 @@
package org.apache.spark.broadcast
+import scala.reflect.ClassTag
+
import org.apache.spark.{SecurityManager, SparkConf}
/**
@@ -30,7 +32,7 @@ class TorrentBroadcastFactory extends BroadcastFactory {
TorrentBroadcast.initialize(isDriver, conf)
}
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
def stop() { TorrentBroadcast.stop() }
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 888af541cf..34c51b8330 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -84,7 +84,7 @@ private[spark] object CheckpointRDD extends Logging {
"part-%05d".format(splitId)
}
- def writeToFile[T](
+ def writeToFile[T: ClassTag](
path: String,
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
blockSize: Int = -1
@@ -160,7 +160,7 @@ private[spark] object CheckpointRDD extends Logging {
val conf = SparkHadoopUtil.get.newConfiguration()
val fs = path.getFileSystem(conf)
val broadcastedConf = sc.broadcast(new SerializableWritable(conf))
- sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf, 1024) _)
+ sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 5f03d7d650..2425929fc7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -77,7 +77,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
slice = in.readInt()
val ser = sfactory.newInstance()
- Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject())
+ Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject[Seq[T]]())
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 953f0555e5..c3b2a33fb5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -92,7 +92,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableWritable(rdd.context.hadoopConfiguration))
- rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf) _)
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
if (newRDD.partitions.size != rdd.partitions.size) {
throw new SparkException(
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index e9163deaf2..0a7e1ec539 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -20,6 +20,8 @@ package org.apache.spark.serializer
import java.io._
import java.nio.ByteBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ByteBufferInputStream
@@ -36,7 +38,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
* But only call it every 10,000th time to avoid bloated serialization streams (when
* the stream 'resets' object class descriptions have to be re-written)
*/
- def writeObject[T](t: T): SerializationStream = {
+ def writeObject[T: ClassTag](t: T): SerializationStream = {
objOut.writeObject(t)
if (counterReset > 0 && counter >= counterReset) {
objOut.reset()
@@ -46,6 +48,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
}
this
}
+
def flush() { objOut.flush() }
def close() { objOut.close() }
}
@@ -57,12 +60,12 @@ extends DeserializationStream {
Class.forName(desc.getName, false, loader)
}
- def readObject[T](): T = objIn.readObject().asInstanceOf[T]
+ def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
def close() { objIn.close() }
}
private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
- def serialize[T](t: T): ByteBuffer = {
+ def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
out.writeObject(t)
@@ -70,13 +73,13 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
ByteBuffer.wrap(bos.toByteArray)
}
- def deserialize[T](bytes: ByteBuffer): T = {
+ def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
val bis = new ByteBufferInputStream(bytes)
val in = deserializeStream(bis)
in.readObject().asInstanceOf[T]
}
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
val bis = new ByteBufferInputStream(bytes)
val in = deserializeStream(bis, loader)
in.readObject().asInstanceOf[T]
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index c4daec7875..5286f7b4c2 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -31,6 +31,8 @@ import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage._
import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
+import scala.reflect.ClassTag
+
/**
* A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*
@@ -95,7 +97,7 @@ private[spark]
class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
val output = new KryoOutput(outStream)
- def writeObject[T](t: T): SerializationStream = {
+ def writeObject[T: ClassTag](t: T): SerializationStream = {
kryo.writeClassAndObject(output, t)
this
}
@@ -108,7 +110,7 @@ private[spark]
class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
val input = new KryoInput(inStream)
- def readObject[T](): T = {
+ def readObject[T: ClassTag](): T = {
try {
kryo.readClassAndObject(input).asInstanceOf[T]
} catch {
@@ -131,18 +133,18 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
lazy val output = ks.newKryoOutput()
lazy val input = new KryoInput()
- def serialize[T](t: T): ByteBuffer = {
+ def serialize[T: ClassTag](t: T): ByteBuffer = {
output.clear()
kryo.writeClassAndObject(output, t)
ByteBuffer.wrap(output.toBytes)
}
- def deserialize[T](bytes: ByteBuffer): T = {
+ def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
input.setBuffer(bytes.array)
kryo.readClassAndObject(input).asInstanceOf[T]
}
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
val oldClassLoader = kryo.getClassLoader
kryo.setClassLoader(loader)
input.setBuffer(bytes.array)
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index f2c8f9b621..ee26970a3d 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -20,6 +20,8 @@ package org.apache.spark.serializer
import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.SparkEnv
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
@@ -59,17 +61,17 @@ object Serializer {
*/
@DeveloperApi
trait SerializerInstance {
- def serialize[T](t: T): ByteBuffer
+ def serialize[T: ClassTag](t: T): ByteBuffer
- def deserialize[T](bytes: ByteBuffer): T
+ def deserialize[T: ClassTag](bytes: ByteBuffer): T
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
+ def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T
def serializeStream(s: OutputStream): SerializationStream
def deserializeStream(s: InputStream): DeserializationStream
- def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
+ def serializeMany[T: ClassTag](iterator: Iterator[T]): ByteBuffer = {
// Default implementation uses serializeStream
val stream = new ByteArrayOutputStream()
serializeStream(stream).writeAll(iterator)
@@ -85,18 +87,17 @@ trait SerializerInstance {
}
}
-
/**
* :: DeveloperApi ::
* A stream for writing serialized objects.
*/
@DeveloperApi
trait SerializationStream {
- def writeObject[T](t: T): SerializationStream
+ def writeObject[T: ClassTag](t: T): SerializationStream
def flush(): Unit
def close(): Unit
- def writeAll[T](iter: Iterator[T]): SerializationStream = {
+ def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
while (iter.hasNext) {
writeObject(iter.next())
}
@@ -111,7 +112,7 @@ trait SerializationStream {
*/
@DeveloperApi
trait DeserializationStream {
- def readObject[T](): T
+ def readObject[T: ClassTag](): T
def close(): Unit
/**
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 3f0ed61c5b..95777fbf57 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -850,7 +850,7 @@ private[spark] object Utils extends Logging {
/**
* Clone an object using a Spark serializer.
*/
- def clone[T](value: T, serializer: SerializerInstance): T = {
+ def clone[T: ClassTag](value: T, serializer: SerializerInstance): T = {
serializer.deserialize[T](serializer.serialize(value))
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 5d4673aebe..cdd6b3d8fe 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.serializer
import scala.collection.mutable
+import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite
@@ -31,7 +32,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
test("basic types") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T](t: T) {
+ def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check(1)
@@ -61,7 +62,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
test("pairs") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T](t: T) {
+ def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check((1, 1))
@@ -85,7 +86,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
test("Scala data structures") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T](t: T) {
+ def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check(List[Int]())
@@ -108,7 +109,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
test("ranges") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T](t: T) {
+ def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
// Check that very long ranges don't get written one element at a time
assert(ser.serialize(t).limit < 100)
@@ -129,7 +130,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
test("custom registrator") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T](t: T) {
+ def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala
index a197dac87d..576a3e371b 100644
--- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala
@@ -28,6 +28,8 @@ import org.apache.spark.serializer.{DeserializationStream, SerializationStream,
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
+import scala.reflect.ClassTag
+
object WikipediaPageRankStandalone {
def main(args: Array[String]) {
if (args.length < 4) {
@@ -143,15 +145,15 @@ class WPRSerializer extends org.apache.spark.serializer.Serializer {
}
class WPRSerializerInstance extends SerializerInstance {
- def serialize[T](t: T): ByteBuffer = {
+ def serialize[T: ClassTag](t: T): ByteBuffer = {
throw new UnsupportedOperationException()
}
- def deserialize[T](bytes: ByteBuffer): T = {
+ def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
throw new UnsupportedOperationException()
}
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
throw new UnsupportedOperationException()
}
@@ -167,7 +169,7 @@ class WPRSerializerInstance extends SerializerInstance {
class WPRSerializationStream(os: OutputStream) extends SerializationStream {
val dos = new DataOutputStream(os)
- def writeObject[T](t: T): SerializationStream = t match {
+ def writeObject[T: ClassTag](t: T): SerializationStream = t match {
case (id: String, wrapper: ArrayBuffer[_]) => wrapper(0) match {
case links: Array[String] => {
dos.writeInt(0) // links
@@ -200,7 +202,7 @@ class WPRSerializationStream(os: OutputStream) extends SerializationStream {
class WPRDeserializationStream(is: InputStream) extends DeserializationStream {
val dis = new DataInputStream(is)
- def readObject[T](): T = {
+ def readObject[T: ClassTag](): T = {
val typeId = dis.readInt()
typeId match {
case 0 => {
diff --git a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
index 2f0531ee5f..1de42eeca1 100644
--- a/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
+++ b/graphx/src/main/scala/org/apache/spark/graphx/impl/Serializers.scala
@@ -17,20 +17,22 @@
package org.apache.spark.graphx.impl
+import scala.language.existentials
+
import java.io.{EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.graphx._
import org.apache.spark.serializer._
-import scala.language.existentials
-
private[graphx]
class VertexIdMsgSerializer extends Serializer with Serializable {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T](t: T) = {
+ def writeObject[T: ClassTag](t: T) = {
val msg = t.asInstanceOf[(VertexId, _)]
writeVarLong(msg._1, optimizePositive = false)
this
@@ -38,7 +40,7 @@ class VertexIdMsgSerializer extends Serializer with Serializable {
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T](): T = {
+ override def readObject[T: ClassTag](): T = {
(readVarLong(optimizePositive = false), null).asInstanceOf[T]
}
}
@@ -51,7 +53,7 @@ class IntVertexBroadcastMsgSerializer extends Serializer with Serializable {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T](t: T) = {
+ def writeObject[T: ClassTag](t: T) = {
val msg = t.asInstanceOf[VertexBroadcastMsg[Int]]
writeVarLong(msg.vid, optimizePositive = false)
writeInt(msg.data)
@@ -60,7 +62,7 @@ class IntVertexBroadcastMsgSerializer extends Serializer with Serializable {
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T](): T = {
+ override def readObject[T: ClassTag](): T = {
val a = readVarLong(optimizePositive = false)
val b = readInt()
new VertexBroadcastMsg[Int](0, a, b).asInstanceOf[T]
@@ -75,7 +77,7 @@ class LongVertexBroadcastMsgSerializer extends Serializer with Serializable {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T](t: T) = {
+ def writeObject[T: ClassTag](t: T) = {
val msg = t.asInstanceOf[VertexBroadcastMsg[Long]]
writeVarLong(msg.vid, optimizePositive = false)
writeLong(msg.data)
@@ -84,7 +86,7 @@ class LongVertexBroadcastMsgSerializer extends Serializer with Serializable {
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T](): T = {
+ override def readObject[T: ClassTag](): T = {
val a = readVarLong(optimizePositive = false)
val b = readLong()
new VertexBroadcastMsg[Long](0, a, b).asInstanceOf[T]
@@ -99,7 +101,7 @@ class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T](t: T) = {
+ def writeObject[T: ClassTag](t: T) = {
val msg = t.asInstanceOf[VertexBroadcastMsg[Double]]
writeVarLong(msg.vid, optimizePositive = false)
writeDouble(msg.data)
@@ -108,7 +110,7 @@ class DoubleVertexBroadcastMsgSerializer extends Serializer with Serializable {
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- def readObject[T](): T = {
+ def readObject[T: ClassTag](): T = {
val a = readVarLong(optimizePositive = false)
val b = readDouble()
new VertexBroadcastMsg[Double](0, a, b).asInstanceOf[T]
@@ -123,7 +125,7 @@ class IntAggMsgSerializer extends Serializer with Serializable {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T](t: T) = {
+ def writeObject[T: ClassTag](t: T) = {
val msg = t.asInstanceOf[(VertexId, Int)]
writeVarLong(msg._1, optimizePositive = false)
writeUnsignedVarInt(msg._2)
@@ -132,7 +134,7 @@ class IntAggMsgSerializer extends Serializer with Serializable {
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T](): T = {
+ override def readObject[T: ClassTag](): T = {
val a = readVarLong(optimizePositive = false)
val b = readUnsignedVarInt()
(a, b).asInstanceOf[T]
@@ -147,7 +149,7 @@ class LongAggMsgSerializer extends Serializer with Serializable {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T](t: T) = {
+ def writeObject[T: ClassTag](t: T) = {
val msg = t.asInstanceOf[(VertexId, Long)]
writeVarLong(msg._1, optimizePositive = false)
writeVarLong(msg._2, optimizePositive = true)
@@ -156,7 +158,7 @@ class LongAggMsgSerializer extends Serializer with Serializable {
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- override def readObject[T](): T = {
+ override def readObject[T: ClassTag](): T = {
val a = readVarLong(optimizePositive = false)
val b = readVarLong(optimizePositive = true)
(a, b).asInstanceOf[T]
@@ -171,7 +173,7 @@ class DoubleAggMsgSerializer extends Serializer with Serializable {
override def newInstance(): SerializerInstance = new ShuffleSerializerInstance {
override def serializeStream(s: OutputStream) = new ShuffleSerializationStream(s) {
- def writeObject[T](t: T) = {
+ def writeObject[T: ClassTag](t: T) = {
val msg = t.asInstanceOf[(VertexId, Double)]
writeVarLong(msg._1, optimizePositive = false)
writeDouble(msg._2)
@@ -180,7 +182,7 @@ class DoubleAggMsgSerializer extends Serializer with Serializable {
}
override def deserializeStream(s: InputStream) = new ShuffleDeserializationStream(s) {
- def readObject[T](): T = {
+ def readObject[T: ClassTag](): T = {
val a = readVarLong(optimizePositive = false)
val b = readDouble()
(a, b).asInstanceOf[T]
@@ -196,7 +198,7 @@ class DoubleAggMsgSerializer extends Serializer with Serializable {
private[graphx]
abstract class ShuffleSerializationStream(s: OutputStream) extends SerializationStream {
// The implementation should override this one.
- def writeObject[T](t: T): SerializationStream
+ def writeObject[T: ClassTag](t: T): SerializationStream
def writeInt(v: Int) {
s.write(v >> 24)
@@ -309,7 +311,7 @@ abstract class ShuffleSerializationStream(s: OutputStream) extends Serialization
private[graphx]
abstract class ShuffleDeserializationStream(s: InputStream) extends DeserializationStream {
// The implementation should override this one.
- def readObject[T](): T
+ def readObject[T: ClassTag](): T
def readInt(): Int = {
val first = s.read()
@@ -398,11 +400,12 @@ abstract class ShuffleDeserializationStream(s: InputStream) extends Deserializat
private[graphx] sealed trait ShuffleSerializerInstance extends SerializerInstance {
- override def serialize[T](t: T): ByteBuffer = throw new UnsupportedOperationException
+ override def serialize[T: ClassTag](t: T): ByteBuffer = throw new UnsupportedOperationException
- override def deserialize[T](bytes: ByteBuffer): T = throw new UnsupportedOperationException
+ override def deserialize[T: ClassTag](bytes: ByteBuffer): T =
+ throw new UnsupportedOperationException
- override def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T =
+ override def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T =
throw new UnsupportedOperationException
// The implementation should override the following two.
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
index 73438d9535..91caa6b605 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
+++ b/graphx/src/test/scala/org/apache/spark/graphx/SerializerSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.graphx
import java.io.{EOFException, ByteArrayInputStream, ByteArrayOutputStream}
import scala.util.Random
+import scala.reflect.ClassTag
import org.scalatest.FunSuite
@@ -164,7 +165,7 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
def testVarLongEncoding(v: Long, optimizePositive: Boolean) {
val bout = new ByteArrayOutputStream
val stream = new ShuffleSerializationStream(bout) {
- def writeObject[T](t: T): SerializationStream = {
+ def writeObject[T: ClassTag](t: T): SerializationStream = {
writeVarLong(t.asInstanceOf[Long], optimizePositive = optimizePositive)
this
}
@@ -173,7 +174,7 @@ class SerializerSuite extends FunSuite with LocalSparkContext {
val bin = new ByteArrayInputStream(bout.toByteArray)
val dstream = new ShuffleDeserializationStream(bin) {
- def readObject[T](): T = {
+ def readObject[T: ClassTag](): T = {
readVarLong(optimizePositive).asInstanceOf[T]
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
index 5067c14ddf..1c6e29b3cd 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkSqlSerializer.scala
@@ -19,6 +19,8 @@ package org.apache.spark.sql.execution
import java.nio.ByteBuffer
+import scala.reflect.ClassTag
+
import com.esotericsoftware.kryo.io.{Input, Output}
import com.esotericsoftware.kryo.{Serializer, Kryo}
@@ -59,11 +61,11 @@ private[sql] object SparkSqlSerializer {
new KryoSerializer(sparkConf)
}
- def serialize[T](o: T): Array[Byte] = {
+ def serialize[T: ClassTag](o: T): Array[Byte] = {
ser.newInstance().serialize(o).array()
}
- def deserialize[T](bytes: Array[Byte]): T = {
+ def deserialize[T: ClassTag](bytes: Array[Byte]): T = {
ser.newInstance().deserialize[T](ByteBuffer.wrap(bytes))
}
}
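A hedged usage sketch of the new signatures, mirroring the round-trip pattern in KryoSerializerSuite above (e.g. in the Spark shell): the ClassTag is inferred at the call site for serialize and supplied via the explicit type parameter for deserialize. Assumes a default SparkConf is sufficient to construct the serializer, as in the test suite.

import org.apache.spark.SparkConf
import org.apache.spark.serializer.KryoSerializer

val conf = new SparkConf()
val ser = new KryoSerializer(conf).newInstance()

val bytes = ser.serialize(Map("a" -> 1, "b" -> 2))    // ClassTag[Map[String, Int]] inferred here
val back = ser.deserialize[Map[String, Int]](bytes)   // explicit type parameter supplies the ClassTag
assert(back == Map("a" -> 1, "b" -> 2))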