-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java  9
-rw-r--r--  core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java  13
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala  17
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala  102
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala  25
-rw-r--r--  core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala  11
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/BlockManager.scala  9
-rw-r--r--  core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala  5
-rw-r--r--  core/src/main/scala/org/apache/spark/util/Utils.scala  12
-rw-r--r--  core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala  1
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala  114
-rw-r--r--  core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java  9
-rw-r--r--  core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java  3
-rw-r--r--  core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java  3
-rw-r--r--  core/src/test/scala/org/apache/spark/ShuffleSuite.scala  107
-rw-r--r--  core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala  14
16 files changed, 402 insertions(+), 52 deletions(-)
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index ee82d67993..a1a1fb0142 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -125,7 +125,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
assert (partitionWriters == null);
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
- shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
@@ -155,9 +155,10 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
writer.commitAndClose();
}
- partitionLengths =
- writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
- shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+ File tmp = Utils.tempFileWith(output);
+ partitionLengths = writePartitionedFile(tmp);
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 6a0a89e81c..744c3008ca 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -41,7 +41,7 @@ import org.apache.spark.annotation.Private;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
@@ -53,7 +53,7 @@ import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
@Private
public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -206,8 +206,10 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
final long[] partitionLengths;
+ final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+ final File tmp = Utils.tempFileWith(output);
try {
- partitionLengths = mergeSpills(spills);
+ partitionLengths = mergeSpills(spills, tmp);
} finally {
for (SpillInfo spill : spills) {
if (spill.file.exists() && ! spill.file.delete()) {
@@ -215,7 +217,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
}
}
}
- shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
@@ -248,8 +250,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
*
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpills(SpillInfo[] spills) throws IOException {
- final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index cd253a78c2..39fadd8783 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -21,13 +21,13 @@ import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.JavaConverters._
-import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
@@ -84,17 +84,8 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
val blockFile = blockManager.diskBlockManager.getFile(blockId)
- // Because of previous failures, the shuffle file may already exist on this machine.
- // If so, remove it.
- if (blockFile.exists) {
- if (blockFile.delete()) {
- logInfo(s"Removed existing shuffle file $blockFile")
- } else {
- logWarning(s"Failed to remove existing shuffle file $blockFile")
- }
- }
- blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize,
- writeMetrics)
+ val tmp = Utils.tempFileWith(blockFile)
+ blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics)
}
}
// Creating the file to write to and creating a disk writer both involve interacting with
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 5e4c2b5d0a..05b1eed7f3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -21,13 +21,12 @@ import java.io._
import com.google.common.io.ByteStreams
-import org.apache.spark.{SparkConf, SparkEnv, Logging}
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
import org.apache.spark.storage._
import org.apache.spark.util.Utils
-
-import IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.{SparkEnv, Logging, SparkConf}
/**
* Create and maintain the shuffle blocks' mapping between logic block and physical file location.
@@ -40,10 +39,13 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID
*/
// Note: Changes to the format in this file should be kept in sync with
// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
-private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver
+private[spark] class IndexShuffleBlockResolver(
+ conf: SparkConf,
+ _blockManager: BlockManager = null)
+ extends ShuffleBlockResolver
with Logging {
- private lazy val blockManager = SparkEnv.get.blockManager
+ private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)
private val transportConf = SparkTransportConf.fromSparkConf(conf)
@@ -75,13 +77,68 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
}
/**
+ * Check whether the given index and data files match each other.
+ * If so, return the partition lengths in the data file. Otherwise return null.
+ */
+ private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = {
+ // the index file should have `blocks + 1` longs as offsets.
+ if (index.length() != (blocks + 1) * 8) {
+ return null
+ }
+ val lengths = new Array[Long](blocks)
+ // Read the lengths of blocks
+ val in = try {
+ new DataInputStream(new BufferedInputStream(new FileInputStream(index)))
+ } catch {
+ case e: IOException =>
+ return null
+ }
+ try {
+ // Convert the offsets into lengths of each block
+ var offset = in.readLong()
+ if (offset != 0L) {
+ return null
+ }
+ var i = 0
+ while (i < blocks) {
+ val off = in.readLong()
+ lengths(i) = off - offset
+ offset = off
+ i += 1
+ }
+ } catch {
+ case e: IOException =>
+ return null
+ } finally {
+ in.close()
+ }
+
+ // the size of the data file should match the index file
+ if (data.length() == lengths.sum) {
+ lengths
+ } else {
+ null
+ }
+ }
+
+ /**
* Write an index file with the offsets of each block, plus a final offset at the end for the
* end of the output file. This will be used by getBlockData to figure out where each block
* begins and ends.
+ *
+ * It will commit the data and index file as an atomic operation, use the existing ones, or
+ * replace them with new ones.
+ *
+ * Note: the `lengths` will be updated to match the existing index file if the existing ones are used.
* */
- def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
+ def writeIndexFileAndCommit(
+ shuffleId: Int,
+ mapId: Int,
+ lengths: Array[Long],
+ dataTmp: File): Unit = {
val indexFile = getIndexFile(shuffleId, mapId)
- val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+ val indexTmp = Utils.tempFileWith(indexFile)
+ val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
Utils.tryWithSafeFinally {
// We take in lengths of each block, need to convert it to offsets.
var offset = 0L
@@ -93,6 +150,37 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
} {
out.close()
}
+
+ val dataFile = getDataFile(shuffleId, mapId)
+ // There is only one IndexShuffleBlockResolver per executor; this synchronization makes sure
+ // the following check and rename are atomic.
+ synchronized {
+ val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
+ if (existingLengths != null) {
+ // Another attempt for the same task has already written our map outputs successfully,
+ // so just use the existing partition lengths and delete our temporary map outputs.
+ System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+ if (dataTmp != null && dataTmp.exists()) {
+ dataTmp.delete()
+ }
+ indexTmp.delete()
+ } else {
+ // This is the first successful attempt in writing the map outputs for this task,
+ // so override any existing index and data files with the ones we wrote.
+ if (indexFile.exists()) {
+ indexFile.delete()
+ }
+ if (dataFile.exists()) {
+ dataFile.delete()
+ }
+ if (!indexTmp.renameTo(indexFile)) {
+ throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
+ }
+ if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
+ throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
+ }
+ }
+ }
}
override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
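Note (illustrative only, not part of the patch): the hunks above define the new commit protocol -- a writer produces its map output in a temporary sibling file, and writeIndexFileAndCommit then either publishes that file atomically or discards it when another attempt of the same task has already committed. A minimal caller-side sketch, assuming `resolver`, `shuffleId`, `mapId`, and a hypothetical `writePartitions` helper are in scope:

    // Sketch only: write to a temporary file next to the final data file, then let
    // writeIndexFileAndCommit publish it (or adopt an earlier attempt's output).
    val dataFile = resolver.getDataFile(shuffleId, mapId)   // final location
    val dataTmp = Utils.tempFileWith(dataFile)              // "<dataFile>.<uuid>"
    val lengths = writePartitions(dataTmp)                  // hypothetical: bytes written per partition
    // On return, `lengths` either describes dataTmp (now renamed to dataFile) or has been
    // overwritten with the lengths recorded by a previously committed attempt.
    resolver.writeIndexFileAndCommit(shuffleId, mapId, lengths, dataTmp)

For concreteness, an index file for partition lengths (10, 0, 20) holds the four longs 0, 10, 10, 30, and checkIndexAndDataFile accepts it only if the matching data file is exactly 30 bytes long.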
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 41df70c602..412bf70000 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.hash
+import java.io.IOException
+
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
@@ -106,6 +108,29 @@ private[spark] class HashShuffleWriter[K, V](
writer.commitAndClose()
writer.fileSegment().length
}
+ // rename all shuffle files to their final paths
+ // Note: there is only one ShuffleBlockResolver per executor
+ shuffleBlockResolver.synchronized {
+ shuffle.writers.zipWithIndex.foreach { case (writer, i) =>
+ val output = blockManager.diskBlockManager.getFile(writer.blockId)
+ if (sizes(i) > 0) {
+ if (output.exists()) {
+ // Use length of existing file and delete our own temporary one
+ sizes(i) = output.length()
+ writer.file.delete()
+ } else {
+ // Commit by renaming our temporary file to something the fetcher expects
+ if (!writer.file.renameTo(output)) {
+ throw new IOException(s"fail to rename ${writer.file} to $output")
+ }
+ }
+ } else {
+ if (output.exists()) {
+ output.delete()
+ }
+ }
+ }
+ }
MapStatus(blockManager.shuffleServerId, sizes)
}
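Note (illustrative only, not part of the patch): the per-block commit decision in the synchronized loop above can be read as the following helper; the name `commitBlock` and its shape are assumptions introduced here for clarity.

    import java.io.{File, IOException}

    // Sketch: commit one reducer's temporary file, or adopt/clean up an existing output.
    def commitBlock(tmp: File, output: File, written: Long): Long = {
      if (written > 0) {
        if (output.exists()) {
          // Another attempt already committed this block: keep its file, drop ours.
          tmp.delete()
          output.length()
        } else if (tmp.renameTo(output)) {
          written
        } else {
          throw new IOException(s"failed to rename $tmp to $output")
        }
      } else {
        // Nothing was written for this reducer; remove any stale output from earlier failures.
        if (output.exists()) output.delete()
        0L
      }
    }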
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 808317b017..f83cf8859e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -20,8 +20,9 @@ package org.apache.spark.shuffle.sort
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.ExternalSorter
private[spark] class SortShuffleWriter[K, V, C](
@@ -65,11 +66,11 @@ private[spark] class SortShuffleWriter[K, V, C](
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
- val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+ val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+ val tmp = Utils.tempFileWith(output)
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
- val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
- shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
-
+ val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
+ shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c374b93766..661c706af3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -21,10 +21,10 @@ import java.io._
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.concurrent.{ExecutionContext, Await, Future}
import scala.concurrent.duration._
-import scala.util.control.NonFatal
+import scala.concurrent.{Await, ExecutionContext, Future}
import scala.util.Random
+import scala.util.control.NonFatal
import sun.nio.ch.DirectBuffer
@@ -38,9 +38,8 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
+import org.apache.spark.serializer.{Serializer, SerializerInstance}
import org.apache.spark.shuffle.ShuffleManager
-import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.util._
private[spark] sealed trait BlockValues
@@ -660,7 +659,7 @@ private[spark] class BlockManager(
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream,
- syncWrites, writeMetrics)
+ syncWrites, writeMetrics, blockId)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 80d426fadc..e2dd80f243 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -34,14 +34,15 @@ import org.apache.spark.util.Utils
* reopened again.
*/
private[spark] class DiskBlockObjectWriter(
- file: File,
+ val file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
compressStream: OutputStream => OutputStream,
syncWrites: Boolean,
// These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
- writeMetrics: ShuffleWriteMetrics)
+ writeMetrics: ShuffleWriteMetrics,
+ val blockId: BlockId = null)
extends OutputStream
with Logging {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 316c194ff3..1b3acb8ef7 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,8 +21,8 @@ import java.io._
import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
-import java.util.{Properties, Locale, Random, UUID}
import java.util.concurrent._
+import java.util.{Locale, Properties, Random, UUID}
import javax.net.ssl.HttpsURLConnection
import scala.collection.JavaConverters._
@@ -30,7 +30,7 @@ import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.reflect.ClassTag
-import scala.util.{Failure, Success, Try}
+import scala.util.Try
import scala.util.control.{ControlThrowable, NonFatal}
import com.google.common.io.{ByteStreams, Files}
@@ -42,7 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation
import org.apache.log4j.PropertyConfigurator
import org.eclipse.jetty.util.MultiException
import org.json4s._
-
import tachyon.TachyonURI
import tachyon.client.{TachyonFS, TachyonFile}
@@ -2169,6 +2168,13 @@ private[spark] object Utils extends Logging {
val resource = createResource
try f.apply(resource) finally resource.close()
}
+
+ /**
+ * Returns the path of a temporary file in the same directory as `path`.
+ */
+ def tempFileWith(path: File): File = {
+ new File(path.getAbsolutePath + "." + UUID.randomUUID())
+ }
}
/**
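Note (illustrative only): `tempFileWith` simply appends a random UUID to the given path, so the temporary file lives in the same directory as its final destination and the later rename stays on a single filesystem. The directory below is hypothetical.

    // e.g. /tmp/blockmgr-xyz/shuffle_0_1_0.data ->
    //      /tmp/blockmgr-xyz/shuffle_0_1_0.data.<random-uuid>
    val tmp = Utils.tempFileWith(new File("/tmp/blockmgr-xyz/shuffle_0_1_0.data"))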
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index bd6844d045..2440139ac9 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -638,7 +638,6 @@ private[spark] class ExternalSorter[K, V, C](
* called by the SortShuffleWriter.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
- * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
def writePartitionedFile(
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
new file mode 100644
index 0000000000..0b19861fc4
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.{File, FileInputStream, FileOutputStream}
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.mockito.{Mock, MockitoAnnotations}
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.storage._
+import org.apache.spark.util.Utils
+import org.apache.spark.{SparkConf, SparkFunSuite}
+
+
+class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
+
+ private var tempDir: File = _
+ private val conf: SparkConf = new SparkConf(loadDefaults = false)
+
+ override def beforeEach(): Unit = {
+ tempDir = Utils.createTempDir()
+ MockitoAnnotations.initMocks(this)
+
+ when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
+ when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
+ new Answer[File] {
+ override def answer(invocation: InvocationOnMock): File = {
+ new File(tempDir, invocation.getArguments.head.toString)
+ }
+ })
+ }
+
+ override def afterEach(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ }
+
+ test("commit shuffle files multiple times") {
+ val lengths = Array[Long](10, 0, 20)
+ val resolver = new IndexShuffleBlockResolver(conf, blockManager)
+ val dataTmp = File.createTempFile("shuffle", null, tempDir)
+ val out = new FileOutputStream(dataTmp)
+ out.write(new Array[Byte](30))
+ out.close()
+ resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp)
+
+ val dataFile = resolver.getDataFile(1, 2)
+ assert(dataFile.exists())
+ assert(dataFile.length() === 30)
+ assert(!dataTmp.exists())
+
+ val dataTmp2 = File.createTempFile("shuffle", null, tempDir)
+ val out2 = new FileOutputStream(dataTmp2)
+ val lengths2 = new Array[Long](3)
+ out2.write(Array[Byte](1))
+ out2.write(new Array[Byte](29))
+ out2.close()
+ resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2)
+ assert(lengths2.toSeq === lengths.toSeq)
+ assert(dataFile.exists())
+ assert(dataFile.length() === 30)
+ assert(!dataTmp2.exists())
+
+ // The dataFile should be the previous one
+ val in = new FileInputStream(dataFile)
+ val firstByte = new Array[Byte](1)
+ in.read(firstByte)
+ assert(firstByte(0) === 0)
+
+ // remove data file
+ dataFile.delete()
+
+ val dataTmp3 = File.createTempFile("shuffle", null, tempDir)
+ val out3 = new FileOutputStream(dataTmp3)
+ val lengths3 = Array[Long](10, 10, 15)
+ out3.write(Array[Byte](2))
+ out3.write(new Array[Byte](34))
+ out3.close()
+ resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3)
+ assert(lengths3.toSeq != lengths.toSeq)
+ assert(dataFile.exists())
+ assert(dataFile.length() === 35)
+ assert(!dataTmp3.exists())
+
+ // The dataFile should be the new one, since the previous data file was deleted
+ val in2 = new FileInputStream(dataFile)
+ val firstByte2 = new Array[Byte](1)
+ in2.read(firstByte2)
+ assert(firstByte2(0) === 2)
+ }
+}
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 0e0eca515a..bc85918c59 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -130,7 +130,8 @@ public class UnsafeShuffleWriterSuite {
(Integer) args[3],
new CompressStream(),
false,
- (ShuffleWriteMetrics) args[4]
+ (ShuffleWriteMetrics) args[4],
+ (BlockId) args[0]
);
}
});
@@ -169,9 +170,13 @@ public class UnsafeShuffleWriterSuite {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+ File tmp = (File) invocationOnMock.getArguments()[3];
+ mergedOutputFile.delete();
+ tmp.renameTo(mergedOutputFile);
return null;
}
- }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+ }).when(shuffleBlockResolver)
+ .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
new Answer<Tuple2<TempShuffleBlockId, File>>() {
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 3bca790f30..d87a1d2a56 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -117,7 +117,8 @@ public abstract class AbstractBytesToBytesMapSuite {
(Integer) args[3],
new CompressStream(),
false,
- (ShuffleWriteMetrics) args[4]
+ (ShuffleWriteMetrics) args[4],
+ (BlockId) args[0]
);
}
});
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 11c3a7be38..a1c9f6fab8 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -130,7 +130,8 @@ public class UnsafeExternalSorterSuite {
(Integer) args[3],
new CompressStream(),
false,
- (ShuffleWriteMetrics) args[4]
+ (ShuffleWriteMetrics) args[4],
+ (BlockId) args[0]
);
}
});
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 4a0877d86f..0de10ae485 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,12 +17,16 @@
package org.apache.spark
+import java.util.concurrent.{Callable, Executors, ExecutorService, CyclicBarrier}
+
import org.scalatest.Matchers
import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.scheduler.{MyRDD, MapStatus, SparkListener, SparkListenerTaskEnd}
import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId}
import org.apache.spark.util.MutablePair
@@ -317,6 +321,107 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
assert(metrics.bytesWritten === metrics.bytesRead)
assert(metrics.bytesWritten > 0)
}
+
+ test("multiple simultaneous attempts for one task (SPARK-8029)") {
+ sc = new SparkContext("local", "test", conf)
+ val mapTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+ val manager = sc.env.shuffleManager
+
+ val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0L)
+ val metricsSystem = sc.env.metricsSystem
+ val shuffleMapRdd = new MyRDD(sc, 1, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
+ val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+
+ // first attempt -- it's successful
+ val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
+ new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem,
+ InternalAccumulator.create(sc)))
+ val data1 = (1 to 10).map { x => x -> x}
+
+ // second attempt -- also successful. We'll write out different data,
+ // just to simulate the fact that the records may get written differently
+ // depending on what gets spilled, what gets combined, etc.
+ val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
+ new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem,
+ InternalAccumulator.create(sc)))
+ val data2 = (11 to 20).map { x => x -> x}
+
+ // interleave writes of both attempts -- we want to test that both attempts can occur
+ // simultaneously, and everything is still OK
+
+ def writeAndClose(
+ writer: ShuffleWriter[Int, Int])(
+ iter: Iterator[(Int, Int)]): Option[MapStatus] = {
+ writer.write(iter)
+ writer.stop(true)
+ }
+ val interleaver = new InterleaveIterators(
+ data1, writeAndClose(writer1), data2, writeAndClose(writer2))
+ val (mapOutput1, mapOutput2) = interleaver.run()
+
+ // check that we can read the map output and it has the right data
+ assert(mapOutput1.isDefined)
+ assert(mapOutput2.isDefined)
+ assert(mapOutput1.get.location === mapOutput2.get.location)
+ assert(mapOutput1.get.getSizeForBlock(0) === mapOutput2.get.getSizeForBlock(0))
+
+ // register one of the map outputs -- doesn't matter which one
+ mapOutput1.foreach { case mapStatus =>
+ mapTrackerMaster.registerMapOutputs(0, Array(mapStatus))
+ }
+
+ val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
+ new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem,
+ InternalAccumulator.create(sc)))
+ val readData = reader.read().toIndexedSeq
+ assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
+
+ manager.unregisterShuffle(0)
+ }
+}
+
+/**
+ * Utility to help tests make sure that we can process two different iterators simultaneously
+ * in different threads. This makes sure that in your test, you don't completely process data1 with
+ * f1 before processing data2 with f2 (or vice versa). It adds a barrier so that the functions only
+ * process one element, before pausing to wait for the other function to "catch up".
+ */
+class InterleaveIterators[T, R](
+ data1: Seq[T],
+ f1: Iterator[T] => R,
+ data2: Seq[T],
+ f2: Iterator[T] => R) {
+
+ require(data1.size == data2.size)
+
+ val barrier = new CyclicBarrier(2)
+ class BarrierIterator[E](id: Int, sub: Iterator[E]) extends Iterator[E] {
+ def hasNext: Boolean = sub.hasNext
+
+ def next: E = {
+ barrier.await()
+ sub.next()
+ }
+ }
+
+ val c1 = new Callable[R] {
+ override def call(): R = f1(new BarrierIterator(1, data1.iterator))
+ }
+ val c2 = new Callable[R] {
+ override def call(): R = f2(new BarrierIterator(2, data2.iterator))
+ }
+
+ val e: ExecutorService = Executors.newFixedThreadPool(2)
+
+ def run(): (R, R) = {
+ val future1 = e.submit(c1)
+ val future2 = e.submit(c2)
+ val r1 = future1.get()
+ val r2 = future2.get()
+ e.shutdown()
+ (r1, r2)
+ }
}
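A small usage sketch of this helper (illustrative only, not part of the test above), showing the lock-step contract: each function consumes one element and then waits at the barrier for the other side before proceeding.

    val interleaver = new InterleaveIterators[Int, Int](
      Seq(1, 2, 3), it => it.sum,
      Seq(4, 5, 6), it => it.sum)
    val (sum1, sum2) = interleaver.run()   // (6, 15), computed concurrently in lock-step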
object ShuffleSuite {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index b92a302806..d3b1b2b620 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -68,6 +68,17 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
+ doAnswer(new Answer[Void] {
+ def answer(invocationOnMock: InvocationOnMock): Void = {
+ val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ if (tmp != null) {
+ outputFile.delete
+ tmp.renameTo(outputFile)
+ }
+ null
+ }
+ }).when(blockResolver)
+ .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(blockManager.getDiskWriter(
any[BlockId],
@@ -84,7 +95,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
args(3).asInstanceOf[Int],
compressStream = identity,
syncWrites = false,
- args(4).asInstanceOf[ShuffleWriteMetrics]
+ args(4).asInstanceOf[ShuffleWriteMetrics],
+ blockId = args(0).asInstanceOf[BlockId]
)
}
})