diff options
-rw-r--r-- | core/src/main/scala/spark/storage/DiskStore.scala | 2 | ||||
-rw-r--r-- | core/src/test/scala/spark/ShuffleSuite.scala | 22 |
2 files changed, 19 insertions, 5 deletions
diff --git a/core/src/main/scala/spark/storage/DiskStore.scala b/core/src/main/scala/spark/storage/DiskStore.scala index 57d4dafefc..1829c2f92e 100644 --- a/core/src/main/scala/spark/storage/DiskStore.scala +++ b/core/src/main/scala/spark/storage/DiskStore.scala @@ -59,6 +59,8 @@ private class DiskStore(blockManager: BlockManager, rootDirs: String) // Flush the partial writes, and set valid length to be the length of the entire file. // Return the number of bytes written for this commit. override def commit(): Long = { + // NOTE: Flush the serializer first and then the compressed/buffered output stream + objOut.flush() bs.flush() val prevPos = lastValidPosition lastValidPosition = channel.position() diff --git a/core/src/test/scala/spark/ShuffleSuite.scala b/core/src/test/scala/spark/ShuffleSuite.scala index a4fe14b9ae..271f4a4e44 100644 --- a/core/src/test/scala/spark/ShuffleSuite.scala +++ b/core/src/test/scala/spark/ShuffleSuite.scala @@ -305,15 +305,27 @@ class ShuffleSuite extends FunSuite with ShouldMatchers with LocalSparkContext { assert(c.partitioner.get === p) } - test("shuffle local cluster") { - // Use a local cluster with 2 processes to make sure there are both local and remote blocks + test("shuffle non-zero block size") { sc = new SparkContext("local-cluster[2,1,512]", "test") + val NUM_BLOCKS = 3 + val a = sc.parallelize(1 to 10, 2) - val b = a.map { - x => (x, x * 2) + val b = a.map { x => + (x, new ShuffleSuite.NonJavaSerializableClass(x * 2)) } - val c = new ShuffledRDD(b, new HashPartitioner(3)) + // If the Kryo serializer is not used correctly, the shuffle would fail because the + // default Java serializer cannot handle the non serializable class. + val c = new ShuffledRDD(b, new HashPartitioner(NUM_BLOCKS), + classOf[spark.KryoSerializer].getName) + val shuffleId = c.dependencies.head.asInstanceOf[ShuffleDependency[Int, Int]].shuffleId + assert(c.count === 10) + + // All blocks must have non-zero size + (0 until NUM_BLOCKS).foreach { id => + val statuses = SparkEnv.get.mapOutputTracker.getServerStatuses(shuffleId, id) + assert(statuses.forall(s => s._2 > 0)) + } } test("shuffle serializer") { |