aboutsummaryrefslogtreecommitdiff
path: root/core/src
diff options
context:
space:
mode:
authorMatei Zaharia <matei@databricks.com>2014-05-10 12:10:24 -0700
committerPatrick Wendell <pwendell@gmail.com>2014-05-10 12:10:24 -0700
commit7eefc9d2b3f6ebc0ecb5562da7323f1e06afbb35 (patch)
tree32a35b0898c5710cdbd73b56ef9dcb9914e1cf02 /core/src
parent8e94d2721a9d3d36697e13f8cc6567ae8aeee78b (diff)
downloadspark-7eefc9d2b3f6ebc0ecb5562da7323f1e06afbb35.tar.gz
spark-7eefc9d2b3f6ebc0ecb5562da7323f1e06afbb35.tar.bz2
spark-7eefc9d2b3f6ebc0ecb5562da7323f1e06afbb35.zip
SPARK-1708. Add a ClassTag on Serializer and things that depend on it
This pull request contains a rebased patch from @heathermiller (https://github.com/heathermiller/spark/pull/1) to add ClassTags on Serializer and types that depend on it (Broadcast and AccumulableCollection). Putting these in the public API signatures now will allow us to use Scala Pickling for serialization down the line without breaking binary compatibility. One question remaining is whether we also want them on Accumulator -- Accumulator is passed as part of a bigger Task or TaskResult object via the closure serializer so it doesn't seem super useful to add the ClassTag there. Broadcast and AccumulableCollection in contrast were being serialized directly. CC @rxin, @pwendell, @heathermiller Author: Matei Zaharia <matei@databricks.com> Closes #700 from mateiz/spark-1708 and squashes the following commits: 1a3d8b0 [Matei Zaharia] Use fake ClassTag in Java 3b449ed [Matei Zaharia] test fix 2209a27 [Matei Zaharia] Code style fixes 9d48830 [Matei Zaharia] Add a ClassTag on Serializer and things that depend on it
Diffstat (limited to 'core/src')
-rw-r--r--core/src/main/scala/org/apache/spark/Accumulators.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/Serializer.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala11
18 files changed, 65 insertions, 42 deletions
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 6d652faae1..cdfd338081 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -21,6 +21,7 @@ import java.io.{ObjectInputStream, Serializable}
import scala.collection.generic.Growable
import scala.collection.mutable.Map
+import scala.reflect.ClassTag
import org.apache.spark.serializer.JavaSerializer
@@ -164,9 +165,9 @@ trait AccumulableParam[R, T] extends Serializable {
def zero(initialValue: R): R
}
-private[spark]
-class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
- extends AccumulableParam[R,T] {
+private[spark] class
+GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
+ extends AccumulableParam[R, T] {
def addAccumulator(growable: R, elem: T): R = {
growable += elem
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 9d7c2c8d3d..c639b3e15d 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -756,7 +756,7 @@ class SparkContext(config: SparkConf) extends Logging {
* Growable and TraversableOnce are the standard APIs that guarantee += and ++=, implemented by
* standard mutable collections. So you can use this with mutable Map, Set, etc.
*/
- def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable, T]
+ def accumulableCollection[R <% Growable[T] with TraversableOnce[T] with Serializable: ClassTag, T]
(initialValue: R): Accumulable[R, T] = {
val param = new GrowableAccumulableParam[R,T]
new Accumulable(initialValue, param)
@@ -767,7 +767,7 @@ class SparkContext(config: SparkConf) extends Logging {
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
- def broadcast[T](value: T): Broadcast[T] = {
+ def broadcast[T: ClassTag](value: T): Broadcast[T] = {
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
cleaner.foreach(_.registerBroadcastForCleanup(bc))
bc
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 8b95cda511..a7cfee6d01 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -447,7 +447,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* [[org.apache.spark.broadcast.Broadcast]] object for reading it in distributed functions.
* The variable will be sent to each cluster only once.
*/
- def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)
+ def broadcast[T](value: T): Broadcast[T] = sc.broadcast(value)(fakeClassTag)
/** Shut down the SparkContext. */
def stop() {
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 738a3b1bed..76956f6a34 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -21,6 +21,8 @@ import java.io.Serializable
import org.apache.spark.SparkException
+import scala.reflect.ClassTag
+
/**
* A broadcast variable. Broadcast variables allow the programmer to keep a read-only variable
* cached on each machine rather than shipping a copy of it with tasks. They can be used, for
@@ -50,7 +52,7 @@ import org.apache.spark.SparkException
* @param id A unique identifier for the broadcast variable.
* @tparam T Type of the data contained in the broadcast variable.
*/
-abstract class Broadcast[T](val id: Long) extends Serializable {
+abstract class Broadcast[T: ClassTag](val id: Long) extends Serializable {
/**
* Flag signifying whether the broadcast variable is valid
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
index 8c8ce9b169..a8c827030a 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastFactory.scala
@@ -17,6 +17,8 @@
package org.apache.spark.broadcast
+import scala.reflect.ClassTag
+
import org.apache.spark.SecurityManager
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
@@ -31,7 +33,7 @@ import org.apache.spark.annotation.DeveloperApi
@DeveloperApi
trait BroadcastFactory {
def initialize(isDriver: Boolean, conf: SparkConf, securityMgr: SecurityManager): Unit
- def newBroadcast[T](value: T, isLocal: Boolean, id: Long): Broadcast[T]
+ def newBroadcast[T: ClassTag](value: T, isLocal: Boolean, id: Long): Broadcast[T]
def unbroadcast(id: Long, removeFromDriver: Boolean, blocking: Boolean): Unit
def stop(): Unit
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
index cf62aca4d4..c88be6aba6 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/BroadcastManager.scala
@@ -19,6 +19,8 @@ package org.apache.spark.broadcast
import java.util.concurrent.atomic.AtomicLong
+import scala.reflect.ClassTag
+
import org.apache.spark._
private[spark] class BroadcastManager(
@@ -56,7 +58,7 @@ private[spark] class BroadcastManager(
private val nextBroadcastId = new AtomicLong(0)
- def newBroadcast[T](value_ : T, isLocal: Boolean) = {
+ def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean) = {
broadcastFactory.newBroadcast[T](value_, isLocal, nextBroadcastId.getAndIncrement())
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 29372f16f2..78fc286e51 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -22,6 +22,8 @@ import java.io.{BufferedInputStream, BufferedOutputStream}
import java.net.{URL, URLConnection, URI}
import java.util.concurrent.TimeUnit
+import scala.reflect.ClassTag
+
import org.apache.spark.{HttpServer, Logging, SecurityManager, SparkConf, SparkEnv}
import org.apache.spark.io.CompressionCodec
import org.apache.spark.storage.{BroadcastBlockId, StorageLevel}
@@ -34,7 +36,8 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH
* (through a HTTP server running at the driver) and stored in the BlockManager of the
* executor to speed up future accesses.
*/
-private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+private[spark] class HttpBroadcast[T: ClassTag](
+ @transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def getValue = value_
@@ -173,7 +176,7 @@ private[spark] object HttpBroadcast extends Logging {
files += file.getAbsolutePath
}
- def read[T](id: Long): T = {
+ def read[T: ClassTag](id: Long): T = {
logDebug("broadcast read server: " + serverUri + " id: broadcast-" + id)
val url = serverUri + "/" + BroadcastBlockId(id).name
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
index e3f6cdc615..d5a031e2bb 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcastFactory.scala
@@ -17,6 +17,8 @@
package org.apache.spark.broadcast
+import scala.reflect.ClassTag
+
import org.apache.spark.{SecurityManager, SparkConf}
/**
@@ -29,7 +31,7 @@ class HttpBroadcastFactory extends BroadcastFactory {
HttpBroadcast.initialize(isDriver, conf, securityMgr)
}
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
new HttpBroadcast[T](value_, isLocal, id)
def stop() { HttpBroadcast.stop() }
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
index 2659274c5e..734de37ba1 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.broadcast
import java.io.{ByteArrayInputStream, ObjectInputStream, ObjectOutputStream}
+import scala.reflect.ClassTag
import scala.math
import scala.util.Random
@@ -44,7 +45,8 @@ import org.apache.spark.util.Utils
* copies of the broadcast data (one per executor) as done by the
* [[org.apache.spark.broadcast.HttpBroadcast]].
*/
-private[spark] class TorrentBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long)
+private[spark] class TorrentBroadcast[T: ClassTag](
+ @transient var value_ : T, isLocal: Boolean, id: Long)
extends Broadcast[T](id) with Logging with Serializable {
def getValue = value_
diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
index d216b58718..1de8396a0e 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcastFactory.scala
@@ -17,6 +17,8 @@
package org.apache.spark.broadcast
+import scala.reflect.ClassTag
+
import org.apache.spark.{SecurityManager, SparkConf}
/**
@@ -30,7 +32,7 @@ class TorrentBroadcastFactory extends BroadcastFactory {
TorrentBroadcast.initialize(isDriver, conf)
}
- def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) =
+ def newBroadcast[T: ClassTag](value_ : T, isLocal: Boolean, id: Long) =
new TorrentBroadcast[T](value_, isLocal, id)
def stop() { TorrentBroadcast.stop() }
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index 888af541cf..34c51b8330 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -84,7 +84,7 @@ private[spark] object CheckpointRDD extends Logging {
"part-%05d".format(splitId)
}
- def writeToFile[T](
+ def writeToFile[T: ClassTag](
path: String,
broadcastedConf: Broadcast[SerializableWritable[Configuration]],
blockSize: Int = -1
@@ -160,7 +160,7 @@ private[spark] object CheckpointRDD extends Logging {
val conf = SparkHadoopUtil.get.newConfiguration()
val fs = path.getFileSystem(conf)
val broadcastedConf = sc.broadcast(new SerializableWritable(conf))
- sc.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf, 1024) _)
+ sc.runJob(rdd, CheckpointRDD.writeToFile[Int](path.toString, broadcastedConf, 1024) _)
val cpRDD = new CheckpointRDD[Int](sc, path.toString)
assert(cpRDD.partitions.length == rdd.partitions.length, "Number of partitions is not the same")
assert(cpRDD.collect.toList == rdd.collect.toList, "Data of partitions not the same")
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 5f03d7d650..2425929fc7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -77,7 +77,7 @@ private[spark] class ParallelCollectionPartition[T: ClassTag](
slice = in.readInt()
val ser = sfactory.newInstance()
- Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject())
+ Utils.deserializeViaNestedStream(in, ser)(ds => values = ds.readObject[Seq[T]]())
}
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 953f0555e5..c3b2a33fb5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -92,7 +92,7 @@ private[spark] class RDDCheckpointData[T: ClassTag](@transient rdd: RDD[T])
// Save to file, and reload it as an RDD
val broadcastedConf = rdd.context.broadcast(
new SerializableWritable(rdd.context.hadoopConfiguration))
- rdd.context.runJob(rdd, CheckpointRDD.writeToFile(path.toString, broadcastedConf) _)
+ rdd.context.runJob(rdd, CheckpointRDD.writeToFile[T](path.toString, broadcastedConf) _)
val newRDD = new CheckpointRDD[T](rdd.context, path.toString)
if (newRDD.partitions.size != rdd.partitions.size) {
throw new SparkException(
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index e9163deaf2..0a7e1ec539 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -20,6 +20,8 @@ package org.apache.spark.serializer
import java.io._
import java.nio.ByteBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.SparkConf
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.ByteBufferInputStream
@@ -36,7 +38,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
* But only call it every 10,000th time to avoid bloated serialization streams (when
* the stream 'resets' object class descriptions have to be re-written)
*/
- def writeObject[T](t: T): SerializationStream = {
+ def writeObject[T: ClassTag](t: T): SerializationStream = {
objOut.writeObject(t)
if (counterReset > 0 && counter >= counterReset) {
objOut.reset()
@@ -46,6 +48,7 @@ private[spark] class JavaSerializationStream(out: OutputStream, counterReset: In
}
this
}
+
def flush() { objOut.flush() }
def close() { objOut.close() }
}
@@ -57,12 +60,12 @@ extends DeserializationStream {
Class.forName(desc.getName, false, loader)
}
- def readObject[T](): T = objIn.readObject().asInstanceOf[T]
+ def readObject[T: ClassTag](): T = objIn.readObject().asInstanceOf[T]
def close() { objIn.close() }
}
private[spark] class JavaSerializerInstance(counterReset: Int) extends SerializerInstance {
- def serialize[T](t: T): ByteBuffer = {
+ def serialize[T: ClassTag](t: T): ByteBuffer = {
val bos = new ByteArrayOutputStream()
val out = serializeStream(bos)
out.writeObject(t)
@@ -70,13 +73,13 @@ private[spark] class JavaSerializerInstance(counterReset: Int) extends Serialize
ByteBuffer.wrap(bos.toByteArray)
}
- def deserialize[T](bytes: ByteBuffer): T = {
+ def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
val bis = new ByteBufferInputStream(bytes)
val in = deserializeStream(bis)
in.readObject().asInstanceOf[T]
}
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
val bis = new ByteBufferInputStream(bytes)
val in = deserializeStream(bis, loader)
in.readObject().asInstanceOf[T]
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index c4daec7875..5286f7b4c2 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -31,6 +31,8 @@ import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage._
import org.apache.spark.storage.{GetBlock, GotBlock, PutBlock}
+import scala.reflect.ClassTag
+
/**
* A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]].
*
@@ -95,7 +97,7 @@ private[spark]
class KryoSerializationStream(kryo: Kryo, outStream: OutputStream) extends SerializationStream {
val output = new KryoOutput(outStream)
- def writeObject[T](t: T): SerializationStream = {
+ def writeObject[T: ClassTag](t: T): SerializationStream = {
kryo.writeClassAndObject(output, t)
this
}
@@ -108,7 +110,7 @@ private[spark]
class KryoDeserializationStream(kryo: Kryo, inStream: InputStream) extends DeserializationStream {
val input = new KryoInput(inStream)
- def readObject[T](): T = {
+ def readObject[T: ClassTag](): T = {
try {
kryo.readClassAndObject(input).asInstanceOf[T]
} catch {
@@ -131,18 +133,18 @@ private[spark] class KryoSerializerInstance(ks: KryoSerializer) extends Serializ
lazy val output = ks.newKryoOutput()
lazy val input = new KryoInput()
- def serialize[T](t: T): ByteBuffer = {
+ def serialize[T: ClassTag](t: T): ByteBuffer = {
output.clear()
kryo.writeClassAndObject(output, t)
ByteBuffer.wrap(output.toBytes)
}
- def deserialize[T](bytes: ByteBuffer): T = {
+ def deserialize[T: ClassTag](bytes: ByteBuffer): T = {
input.setBuffer(bytes.array)
kryo.readClassAndObject(input).asInstanceOf[T]
}
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T = {
+ def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T = {
val oldClassLoader = kryo.getClassLoader
kryo.setClassLoader(loader)
input.setBuffer(bytes.array)
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index f2c8f9b621..ee26970a3d 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -20,6 +20,8 @@ package org.apache.spark.serializer
import java.io.{ByteArrayOutputStream, EOFException, InputStream, OutputStream}
import java.nio.ByteBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.SparkEnv
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.util.{ByteBufferInputStream, NextIterator}
@@ -59,17 +61,17 @@ object Serializer {
*/
@DeveloperApi
trait SerializerInstance {
- def serialize[T](t: T): ByteBuffer
+ def serialize[T: ClassTag](t: T): ByteBuffer
- def deserialize[T](bytes: ByteBuffer): T
+ def deserialize[T: ClassTag](bytes: ByteBuffer): T
- def deserialize[T](bytes: ByteBuffer, loader: ClassLoader): T
+ def deserialize[T: ClassTag](bytes: ByteBuffer, loader: ClassLoader): T
def serializeStream(s: OutputStream): SerializationStream
def deserializeStream(s: InputStream): DeserializationStream
- def serializeMany[T](iterator: Iterator[T]): ByteBuffer = {
+ def serializeMany[T: ClassTag](iterator: Iterator[T]): ByteBuffer = {
// Default implementation uses serializeStream
val stream = new ByteArrayOutputStream()
serializeStream(stream).writeAll(iterator)
@@ -85,18 +87,17 @@ trait SerializerInstance {
}
}
-
/**
* :: DeveloperApi ::
* A stream for writing serialized objects.
*/
@DeveloperApi
trait SerializationStream {
- def writeObject[T](t: T): SerializationStream
+ def writeObject[T: ClassTag](t: T): SerializationStream
def flush(): Unit
def close(): Unit
- def writeAll[T](iter: Iterator[T]): SerializationStream = {
+ def writeAll[T: ClassTag](iter: Iterator[T]): SerializationStream = {
while (iter.hasNext) {
writeObject(iter.next())
}
@@ -111,7 +112,7 @@ trait SerializationStream {
*/
@DeveloperApi
trait DeserializationStream {
- def readObject[T](): T
+ def readObject[T: ClassTag](): T
def close(): Unit
/**
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 3f0ed61c5b..95777fbf57 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -850,7 +850,7 @@ private[spark] object Utils extends Logging {
/**
* Clone an object using a Spark serializer.
*/
- def clone[T](value: T, serializer: SerializerInstance): T = {
+ def clone[T: ClassTag](value: T, serializer: SerializerInstance): T = {
serializer.deserialize[T](serializer.serialize(value))
}
diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
index 5d4673aebe..cdd6b3d8fe 100644
--- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala
@@ -18,6 +18,7 @@
package org.apache.spark.serializer
import scala.collection.mutable
+import scala.reflect.ClassTag
import com.esotericsoftware.kryo.Kryo
import org.scalatest.FunSuite
@@ -31,7 +32,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
test("basic types") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T](t: T) {
+ def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check(1)
@@ -61,7 +62,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
test("pairs") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T](t: T) {
+ def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check((1, 1))
@@ -85,7 +86,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
test("Scala data structures") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T](t: T) {
+ def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}
check(List[Int]())
@@ -108,7 +109,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
test("ranges") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T](t: T) {
+ def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
// Check that very long ranges don't get written one element at a time
assert(ser.serialize(t).limit < 100)
@@ -129,7 +130,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext {
test("custom registrator") {
val ser = new KryoSerializer(conf).newInstance()
- def check[T](t: T) {
+ def check[T: ClassTag](t: T) {
assert(ser.deserialize[T](ser.serialize(t)) === t)
}