Diffstat (limited to 'core/src/test/scala/org/apache/spark/SortShuffleSuite.scala')
 core/src/test/scala/org/apache/spark/SortShuffleSuite.scala | 65 ++++++++++++++++++++
 1 file changed, 65 insertions(+), 0 deletions(-)
diff --git a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
index 63358172ea..b8ab227517 100644
--- a/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/SortShuffleSuite.scala
@@ -17,13 +17,78 @@
 package org.apache.spark
+import java.io.File
+
+import scala.collection.JavaConverters._
+
+import org.apache.commons.io.FileUtils
+import org.apache.commons.io.filefilter.TrueFileFilter
 import org.scalatest.BeforeAndAfterAll
+
+import org.apache.spark.rdd.ShuffledRDD
+import org.apache.spark.serializer.{JavaSerializer, KryoSerializer}
+import org.apache.spark.shuffle.sort.SortShuffleManager
+import org.apache.spark.util.Utils
+
 class SortShuffleSuite extends ShuffleSuite with BeforeAndAfterAll {
   // This test suite should run all tests in ShuffleSuite with sort-based shuffle.
+  private var tempDir: File = _
+
   override def beforeAll() {
     conf.set("spark.shuffle.manager", "sort")
   }
+
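+  // Point spark.local.dir at a fresh temp directory for every test so that the shuffle files
+  // written by each test can be located and checked for cleanup.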
+  override def beforeEach(): Unit = {
+    super.beforeEach()
+    tempDir = Utils.createTempDir()
+    conf.set("spark.local.dir", tempDir.getAbsolutePath)
+  }
+
+  override def afterEach(): Unit = {
+    try {
+      Utils.deleteRecursively(tempDir)
+    } finally {
+      super.afterEach()
+    }
+  }
+
+ test("SortShuffleManager properly cleans up files for shuffles that use the serialized path") {
+ sc = new SparkContext("local", "test", conf)
+ // Create a shuffled RDD and verify that it actually uses the new serialized map output path
+ val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+ val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+ .setSerializer(new KryoSerializer(conf))
+ val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
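+    // KryoSerializer supports relocation of serialized objects, which the serialized
+    // (tungsten-sort) write path requires.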
+    assert(SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+    ensureFilesAreCleanedUp(shuffledRdd)
+  }
+
+ test("SortShuffleManager properly cleans up files for shuffles that use the deserialized path") {
+ sc = new SparkContext("local", "test", conf)
+ // Create a shuffled RDD and verify that it actually uses the old deserialized map output path
+ val rdd = sc.parallelize(1 to 10, 1).map(x => (x, x))
+ val shuffledRdd = new ShuffledRDD[Int, Int, Int](rdd, new HashPartitioner(4))
+ .setSerializer(new JavaSerializer(conf))
+ val shuffleDep = shuffledRdd.dependencies.head.asInstanceOf[ShuffleDependency[_, _, _]]
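+    // JavaSerializer does not support relocation of serialized objects, so this shuffle must
+    // fall back to the deserialized write path.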
+    assert(!SortShuffleManager.canUseSerializedShuffle(shuffleDep))
+    ensureFilesAreCleanedUp(shuffledRdd)
+  }
+
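+  /**
+   * Runs the given shuffle, checks that it wrote the expected data and index files under the
+   * temp local directory, and then verifies that removing the shuffle deletes those files.
+   */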
+  private def ensureFilesAreCleanedUp(shuffledRdd: ShuffledRDD[_, _, _]): Unit = {
+    def getAllFiles: Set[File] =
+      FileUtils.listFiles(tempDir, TrueFileFilter.INSTANCE, TrueFileFilter.INSTANCE).asScala.toSet
+    val filesBeforeShuffle = getAllFiles
+    // Force the shuffle to be performed
+    shuffledRdd.count()
+    // Ensure that the shuffle actually created files that will need to be cleaned up
+    val filesCreatedByShuffle = getAllFiles -- filesBeforeShuffle
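+    // Sort-based shuffle writes one data file and one index file per map task, named
+    // shuffle_<shuffleId>_<mapId>_<reduceId>.{data,index}.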
+    filesCreatedByShuffle.map(_.getName) should be(
+      Set("shuffle_0_0_0.data", "shuffle_0_0_0.index"))
+    // Check that the cleanup actually removes the files
+    sc.env.blockManager.master.removeShuffle(0, blocking = true)
+    for (file <- filesCreatedByShuffle) {
+      assert(!file.exists(), s"Shuffle file $file was not cleaned up")
+    }
+  }
 }