Diffstat (limited to 'core/src')
-rw-r--r--  core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala | 14
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala         | 13
-rw-r--r--  core/src/test/scala/org/apache/spark/DistributedSuite.scala             | 77
3 files changed, 46 insertions(+), 58 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
index 9dc274c9fe..07caadbe40 100644
--- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -68,7 +68,7 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
* loaded yet. */
private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf)
- private def canUseKryo(ct: ClassTag[_]): Boolean = {
+ def canUseKryo(ct: ClassTag[_]): Boolean = {
primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
}
@@ -128,8 +128,18 @@ private[spark] class SerializerManager(defaultSerializer: Serializer, conf: Spar
/** Serializes into a chunked byte buffer. */
def dataSerialize[T: ClassTag](blockId: BlockId, values: Iterator[T]): ChunkedByteBuffer = {
+ dataSerializeWithExplicitClassTag(blockId, values, implicitly[ClassTag[T]])
+ }
+
+ /** Serializes into a chunked byte buffer. */
+ def dataSerializeWithExplicitClassTag(
+ blockId: BlockId,
+ values: Iterator[_],
+ classTag: ClassTag[_]): ChunkedByteBuffer = {
val bbos = new ChunkedByteBufferOutputStream(1024 * 1024 * 4, ByteBuffer.allocate)
- dataSerializeStream(blockId, bbos, values)
+ val byteStream = new BufferedOutputStream(bbos)
+ val ser = getSerializer(classTag).newInstance()
+ ser.serializeStream(wrapForCompression(blockId, byteStream)).writeAll(values).close()
bbos.toChunkedByteBuffer
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 015e71d126..fe84652798 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -498,7 +498,8 @@ private[spark] class BlockManager(
diskStore.getBytes(blockId)
} else if (level.useMemory && memoryStore.contains(blockId)) {
// The block was not found on disk, so serialize an in-memory copy:
- serializerManager.dataSerialize(blockId, memoryStore.getValues(blockId).get)
+ serializerManager.dataSerializeWithExplicitClassTag(
+ blockId, memoryStore.getValues(blockId).get, info.classTag)
} else {
handleLocalReadFailure(blockId)
}
@@ -973,8 +974,16 @@ private[spark] class BlockManager(
if (level.replication > 1) {
val remoteStartTime = System.currentTimeMillis
val bytesToReplicate = doGetLocalBytes(blockId, info)
+ // [SPARK-16550] Erase the typed classTag when using default serialization, since
+ // NettyBlockRpcServer crashes when deserializing repl-defined classes.
+ // TODO(ekl) remove this once the classloader issue on the remote end is fixed.
+ val remoteClassTag = if (!serializerManager.canUseKryo(classTag)) {
+ scala.reflect.classTag[Any]
+ } else {
+ classTag
+ }
try {
- replicate(blockId, bytesToReplicate, level, classTag)
+ replicate(blockId, bytesToReplicate, level, remoteClassTag)
} finally {
bytesToReplicate.dispose()
}
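
A small sketch of the class-tag erasure applied before replication above (ReplicationTagErasure and its simplified canUseKryo predicate are stand-ins, not the BlockManager or SerializerManager APIs): when the block's element type is not Kryo-safe, the tag is erased to Any so the receiving end falls back to the default serializer and never has to resolve the concrete, possibly repl-defined, class by name.

    import scala.reflect.{classTag, ClassTag}

    object ReplicationTagErasure {
      // Simplified stand-in for SerializerManager.canUseKryo: only a small set
      // of types counts as Kryo-safe in this sketch.
      def canUseKryo(ct: ClassTag[_]): Boolean =
        ct == classTag[Int] || ct == classTag[Array[Int]] || ct == classTag[String]

      // Mirror of the remoteClassTag logic in the hunk above: keep the tag when
      // Kryo can be used, otherwise erase it to Any.
      def remoteClassTag(ct: ClassTag[_]): ClassTag[_] =
        if (canUseKryo(ct)) ct else classTag[Any]

      def main(args: Array[String]): Unit = {
        println(remoteClassTag(classTag[Array[Int]]))       // kept: the Int array tag
        println(remoteClassTag(classTag[List[(Int, Int)]])) // erased: Any
      }
    }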
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 6beae842b0..4ee0e00fde 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -149,61 +149,16 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
sc.parallelize(1 to 10).count()
}
- test("caching") {
+ private def testCaching(storageLevel: StorageLevel): Unit = {
sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).cache()
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching on disk") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching in memory, replicated") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_2)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching in memory, serialized, replicated") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_ONLY_SER_2)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching on disk, replicated") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.DISK_ONLY_2)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching in memory and disk, replicated") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_2)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- }
-
- test("caching in memory and disk, serialized, replicated") {
- sc = new SparkContext(clusterUrl, "test")
- val data = sc.parallelize(1 to 1000, 10).persist(StorageLevel.MEMORY_AND_DISK_SER_2)
-
- assert(data.count() === 1000)
- assert(data.count() === 1000)
- assert(data.count() === 1000)
+ sc.jobProgressListener.waitUntilExecutorsUp(2, 30000)
+ val data = sc.parallelize(1 to 1000, 10)
+ val cachedData = data.persist(storageLevel)
+ assert(cachedData.count === 1000)
+ assert(sc.getExecutorStorageStatus.map(_.rddBlocksById(cachedData.id).size).sum ===
+ storageLevel.replication * data.getNumPartitions)
+ assert(cachedData.count === 1000)
+ assert(cachedData.count === 1000)
// Get all the locations of the first partition and try to fetch the partitions
// from those locations.
@@ -221,6 +176,20 @@ class DistributedSuite extends SparkFunSuite with Matchers with LocalSparkContex
}
}
+ Seq(
+ "caching" -> StorageLevel.MEMORY_ONLY,
+ "caching on disk" -> StorageLevel.DISK_ONLY,
+ "caching in memory, replicated" -> StorageLevel.MEMORY_ONLY_2,
+ "caching in memory, serialized, replicated" -> StorageLevel.MEMORY_ONLY_SER_2,
+ "caching on disk, replicated" -> StorageLevel.DISK_ONLY_2,
+ "caching in memory and disk, replicated" -> StorageLevel.MEMORY_AND_DISK_2,
+ "caching in memory and disk, serialized, replicated" -> StorageLevel.MEMORY_AND_DISK_SER_2
+ ).foreach { case (testName, storageLevel) =>
+ test(testName) {
+ testCaching(storageLevel)
+ }
+ }
+
test("compute without caching when no partitions fit in memory") {
val size = 10000
val conf = new SparkConf()
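
A framework-free sketch of the two ideas behind the DistributedSuite change (CachingCasesSketch and its helpers are illustrative, not the Spark test APIs): the per-storage-level tests become data that a single shared body runs against, and that body checks the cluster holds replication * numPartitions cached blocks in total, matching the assertion on sc.getExecutorStorageStatus added in testCaching.

    object CachingCasesSketch {
      final case class Case(name: String, replication: Int)

      // One cached block per partition per replica.
      def expectedBlocks(numPartitions: Int, replication: Int): Int =
        numPartitions * replication

      def runCase(c: Case): Unit = {
        val numPartitions = 10 // matches sc.parallelize(1 to 1000, 10) in the suite
        println(s"${c.name}: expect ${expectedBlocks(numPartitions, c.replication)} blocks")
      }

      def main(args: Array[String]): Unit = {
        Seq(
          Case("caching", replication = 1),
          Case("caching in memory, replicated", replication = 2),
          Case("caching on disk, replicated", replication = 2)
        ).foreach(runCase)
      }
    }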