aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-11-12 22:44:57 -0800
committerDavies Liu <davies.liu@gmail.com>2015-11-12 22:44:57 -0800
commitad960885bfee7850c18eb5338546cecf2b2e9876 (patch)
tree07f56b97c6e7f38a7400dabb98be6f1942fab184
parentea5ae2705afa4eaadd4192c37d74c97364378cf9 (diff)
downloadspark-ad960885bfee7850c18eb5338546cecf2b2e9876.tar.gz
spark-ad960885bfee7850c18eb5338546cecf2b2e9876.tar.bz2
spark-ad960885bfee7850c18eb5338546cecf2b2e9876.zip
[SPARK-8029] Robust shuffle writer
Currently, all the shuffle writer will write to target path directly, the file could be corrupted by other attempt of the same partition on the same executor. They should write to temporary file then rename to target path, as what we do in output committer. In order to make the rename atomic, the temporary file should be created in the same local directory (FileSystem). This PR is based on #9214 , thanks to squito . Closes #9214 Author: Davies Liu <davies@databricks.com> Closes #9610 from davies/safe_shuffle.
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java9
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java13
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala102
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala25
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala1
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala114
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java9
-rw-r--r--core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java3
-rw-r--r--core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java3
-rw-r--r--core/src/test/scala/org/apache/spark/ShuffleSuite.scala107
-rw-r--r--core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala14
16 files changed, 402 insertions, 52 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index ee82d67993..a1a1fb0142 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -125,7 +125,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
assert (partitionWriters == null);
if (!records.hasNext()) {
partitionLengths = new long[numPartitions];
- shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, null);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
return;
}
@@ -155,9 +155,10 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
writer.commitAndClose();
}
- partitionLengths =
- writePartitionedFile(shuffleBlockResolver.getDataFile(shuffleId, mapId));
- shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+ File tmp = Utils.tempFileWith(output);
+ partitionLengths = writePartitionedFile(tmp);
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 6a0a89e81c..744c3008ca 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -41,7 +41,7 @@ import org.apache.spark.annotation.Private;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.io.CompressionCodec;
import org.apache.spark.io.CompressionCodec$;
-import org.apache.spark.io.LZFCompressionCodec;
+import org.apache.spark.memory.TaskMemoryManager;
import org.apache.spark.network.util.LimitedInputStream;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.scheduler.MapStatus$;
@@ -53,7 +53,7 @@ import org.apache.spark.shuffle.ShuffleWriter;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.TimeTrackingOutputStream;
import org.apache.spark.unsafe.Platform;
-import org.apache.spark.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
@Private
public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
@@ -206,8 +206,10 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
final SpillInfo[] spills = sorter.closeAndGetSpills();
sorter = null;
final long[] partitionLengths;
+ final File output = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+ final File tmp = Utils.tempFileWith(output);
try {
- partitionLengths = mergeSpills(spills);
+ partitionLengths = mergeSpills(spills, tmp);
} finally {
for (SpillInfo spill : spills) {
if (spill.file.exists() && ! spill.file.delete()) {
@@ -215,7 +217,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
}
}
}
- shuffleBlockResolver.writeIndexFile(shuffleId, mapId, partitionLengths);
+ shuffleBlockResolver.writeIndexFileAndCommit(shuffleId, mapId, partitionLengths, tmp);
mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
}
@@ -248,8 +250,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
*
* @return the partition lengths in the merged file.
*/
- private long[] mergeSpills(SpillInfo[] spills) throws IOException {
- final File outputFile = shuffleBlockResolver.getDataFile(shuffleId, mapId);
+ private long[] mergeSpills(SpillInfo[] spills, File outputFile) throws IOException {
final boolean compressionEnabled = sparkConf.getBoolean("spark.shuffle.compress", true);
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
index cd253a78c2..39fadd8783 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockResolver.scala
@@ -21,13 +21,13 @@ import java.util.concurrent.ConcurrentLinkedQueue
import scala.collection.JavaConverters._
-import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage._
-import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap}
+import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedHashMap, Utils}
+import org.apache.spark.{Logging, SparkConf, SparkEnv}
/** A group of writers for a ShuffleMapTask, one writer per reducer. */
private[spark] trait ShuffleWriterGroup {
@@ -84,17 +84,8 @@ private[spark] class FileShuffleBlockResolver(conf: SparkConf)
Array.tabulate[DiskBlockObjectWriter](numReducers) { bucketId =>
val blockId = ShuffleBlockId(shuffleId, mapId, bucketId)
val blockFile = blockManager.diskBlockManager.getFile(blockId)
- // Because of previous failures, the shuffle file may already exist on this machine.
- // If so, remove it.
- if (blockFile.exists) {
- if (blockFile.delete()) {
- logInfo(s"Removed existing shuffle file $blockFile")
- } else {
- logWarning(s"Failed to remove existing shuffle file $blockFile")
- }
- }
- blockManager.getDiskWriter(blockId, blockFile, serializerInstance, bufferSize,
- writeMetrics)
+ val tmp = Utils.tempFileWith(blockFile)
+ blockManager.getDiskWriter(blockId, tmp, serializerInstance, bufferSize, writeMetrics)
}
}
// Creating the file to write to and creating a disk writer both involve interacting with
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
index 5e4c2b5d0a..05b1eed7f3 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockResolver.scala
@@ -21,13 +21,12 @@ import java.io._
import com.google.common.io.ByteStreams
-import org.apache.spark.{SparkConf, SparkEnv, Logging}
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.shuffle.IndexShuffleBlockResolver.NOOP_REDUCE_ID
import org.apache.spark.storage._
import org.apache.spark.util.Utils
-
-import IndexShuffleBlockResolver.NOOP_REDUCE_ID
+import org.apache.spark.{SparkEnv, Logging, SparkConf}
/**
* Create and maintain the shuffle blocks' mapping between logic block and physical file location.
@@ -40,10 +39,13 @@ import IndexShuffleBlockResolver.NOOP_REDUCE_ID
*/
// Note: Changes to the format in this file should be kept in sync with
// org.apache.spark.network.shuffle.ExternalShuffleBlockResolver#getSortBasedShuffleBlockData().
-private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleBlockResolver
+private[spark] class IndexShuffleBlockResolver(
+ conf: SparkConf,
+ _blockManager: BlockManager = null)
+ extends ShuffleBlockResolver
with Logging {
- private lazy val blockManager = SparkEnv.get.blockManager
+ private lazy val blockManager = Option(_blockManager).getOrElse(SparkEnv.get.blockManager)
private val transportConf = SparkTransportConf.fromSparkConf(conf)
@@ -75,13 +77,68 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
}
/**
+ * Check whether the given index and data files match each other.
+ * If so, return the partition lengths in the data file. Otherwise return null.
+ */
+ private def checkIndexAndDataFile(index: File, data: File, blocks: Int): Array[Long] = {
+ // the index file should have `block + 1` longs as offset.
+ if (index.length() != (blocks + 1) * 8) {
+ return null
+ }
+ val lengths = new Array[Long](blocks)
+ // Read the lengths of blocks
+ val in = try {
+ new DataInputStream(new BufferedInputStream(new FileInputStream(index)))
+ } catch {
+ case e: IOException =>
+ return null
+ }
+ try {
+ // Convert the offsets into lengths of each block
+ var offset = in.readLong()
+ if (offset != 0L) {
+ return null
+ }
+ var i = 0
+ while (i < blocks) {
+ val off = in.readLong()
+ lengths(i) = off - offset
+ offset = off
+ i += 1
+ }
+ } catch {
+ case e: IOException =>
+ return null
+ } finally {
+ in.close()
+ }
+
+ // the size of data file should match with index file
+ if (data.length() == lengths.sum) {
+ lengths
+ } else {
+ null
+ }
+ }
+
+ /**
* Write an index file with the offsets of each block, plus a final offset at the end for the
* end of the output file. This will be used by getBlockData to figure out where each block
* begins and ends.
+ *
+ * It will commit the data and index file as an atomic operation, use the existing ones, or
+ * replace them with new ones.
+ *
+ * Note: the `lengths` will be updated to match the existing index file if use the existing ones.
* */
- def writeIndexFile(shuffleId: Int, mapId: Int, lengths: Array[Long]): Unit = {
+ def writeIndexFileAndCommit(
+ shuffleId: Int,
+ mapId: Int,
+ lengths: Array[Long],
+ dataTmp: File): Unit = {
val indexFile = getIndexFile(shuffleId, mapId)
- val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexFile)))
+ val indexTmp = Utils.tempFileWith(indexFile)
+ val out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexTmp)))
Utils.tryWithSafeFinally {
// We take in lengths of each block, need to convert it to offsets.
var offset = 0L
@@ -93,6 +150,37 @@ private[spark] class IndexShuffleBlockResolver(conf: SparkConf) extends ShuffleB
} {
out.close()
}
+
+ val dataFile = getDataFile(shuffleId, mapId)
+ // There is only one IndexShuffleBlockResolver per executor, this synchronization make sure
+ // the following check and rename are atomic.
+ synchronized {
+ val existingLengths = checkIndexAndDataFile(indexFile, dataFile, lengths.length)
+ if (existingLengths != null) {
+ // Another attempt for the same task has already written our map outputs successfully,
+ // so just use the existing partition lengths and delete our temporary map outputs.
+ System.arraycopy(existingLengths, 0, lengths, 0, lengths.length)
+ if (dataTmp != null && dataTmp.exists()) {
+ dataTmp.delete()
+ }
+ indexTmp.delete()
+ } else {
+ // This is the first successful attempt in writing the map outputs for this task,
+ // so override any existing index and data files with the ones we wrote.
+ if (indexFile.exists()) {
+ indexFile.delete()
+ }
+ if (dataFile.exists()) {
+ dataFile.delete()
+ }
+ if (!indexTmp.renameTo(indexFile)) {
+ throw new IOException("fail to rename file " + indexTmp + " to " + indexFile)
+ }
+ if (dataTmp != null && dataTmp.exists() && !dataTmp.renameTo(dataFile)) {
+ throw new IOException("fail to rename file " + dataTmp + " to " + dataFile)
+ }
+ }
+ }
}
override def getBlockData(blockId: ShuffleBlockId): ManagedBuffer = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 41df70c602..412bf70000 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -17,6 +17,8 @@
package org.apache.spark.shuffle.hash
+import java.io.IOException
+
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
@@ -106,6 +108,29 @@ private[spark] class HashShuffleWriter[K, V](
writer.commitAndClose()
writer.fileSegment().length
}
+ // rename all shuffle files to final paths
+ // Note: there is only one ShuffleBlockResolver in executor
+ shuffleBlockResolver.synchronized {
+ shuffle.writers.zipWithIndex.foreach { case (writer, i) =>
+ val output = blockManager.diskBlockManager.getFile(writer.blockId)
+ if (sizes(i) > 0) {
+ if (output.exists()) {
+ // Use length of existing file and delete our own temporary one
+ sizes(i) = output.length()
+ writer.file.delete()
+ } else {
+ // Commit by renaming our temporary file to something the fetcher expects
+ if (!writer.file.renameTo(output)) {
+ throw new IOException(s"fail to rename ${writer.file} to $output")
+ }
+ }
+ } else {
+ if (output.exists()) {
+ output.delete()
+ }
+ }
+ }
+ }
MapStatus(blockManager.shuffleServerId, sizes)
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index 808317b017..f83cf8859e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -20,8 +20,9 @@ package org.apache.spark.shuffle.sort
import org.apache.spark._
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.shuffle.{IndexShuffleBlockResolver, ShuffleWriter, BaseShuffleHandle}
+import org.apache.spark.shuffle.{BaseShuffleHandle, IndexShuffleBlockResolver, ShuffleWriter}
import org.apache.spark.storage.ShuffleBlockId
+import org.apache.spark.util.Utils
import org.apache.spark.util.collection.ExternalSorter
private[spark] class SortShuffleWriter[K, V, C](
@@ -65,11 +66,11 @@ private[spark] class SortShuffleWriter[K, V, C](
// Don't bother including the time to open the merged output file in the shuffle write time,
// because it just opens a single file, so is typically too fast to measure accurately
// (see SPARK-3570).
- val outputFile = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+ val output = shuffleBlockResolver.getDataFile(dep.shuffleId, mapId)
+ val tmp = Utils.tempFileWith(output)
val blockId = ShuffleBlockId(dep.shuffleId, mapId, IndexShuffleBlockResolver.NOOP_REDUCE_ID)
- val partitionLengths = sorter.writePartitionedFile(blockId, outputFile)
- shuffleBlockResolver.writeIndexFile(dep.shuffleId, mapId, partitionLengths)
-
+ val partitionLengths = sorter.writePartitionedFile(blockId, tmp)
+ shuffleBlockResolver.writeIndexFileAndCommit(dep.shuffleId, mapId, partitionLengths, tmp)
mapStatus = MapStatus(blockManager.shuffleServerId, partitionLengths)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index c374b93766..661c706af3 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -21,10 +21,10 @@ import java.io._
import java.nio.{ByteBuffer, MappedByteBuffer}
import scala.collection.mutable.{ArrayBuffer, HashMap}
-import scala.concurrent.{ExecutionContext, Await, Future}
import scala.concurrent.duration._
-import scala.util.control.NonFatal
+import scala.concurrent.{Await, ExecutionContext, Future}
import scala.util.Random
+import scala.util.control.NonFatal
import sun.nio.ch.DirectBuffer
@@ -38,9 +38,8 @@ import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.network.shuffle.ExternalShuffleClient
import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.rpc.RpcEnv
-import org.apache.spark.serializer.{SerializerInstance, Serializer}
+import org.apache.spark.serializer.{Serializer, SerializerInstance}
import org.apache.spark.shuffle.ShuffleManager
-import org.apache.spark.shuffle.hash.HashShuffleManager
import org.apache.spark.util._
private[spark] sealed trait BlockValues
@@ -660,7 +659,7 @@ private[spark] class BlockManager(
val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _)
val syncWrites = conf.getBoolean("spark.shuffle.sync", false)
new DiskBlockObjectWriter(file, serializerInstance, bufferSize, compressStream,
- syncWrites, writeMetrics)
+ syncWrites, writeMetrics, blockId)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
index 80d426fadc..e2dd80f243 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockObjectWriter.scala
@@ -34,14 +34,15 @@ import org.apache.spark.util.Utils
* reopened again.
*/
private[spark] class DiskBlockObjectWriter(
- file: File,
+ val file: File,
serializerInstance: SerializerInstance,
bufferSize: Int,
compressStream: OutputStream => OutputStream,
syncWrites: Boolean,
// These write metrics concurrently shared with other active DiskBlockObjectWriters who
// are themselves performing writes. All updates must be relative.
- writeMetrics: ShuffleWriteMetrics)
+ writeMetrics: ShuffleWriteMetrics,
+ val blockId: BlockId = null)
extends OutputStream
with Logging {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index 316c194ff3..1b3acb8ef7 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -21,8 +21,8 @@ import java.io._
import java.lang.management.ManagementFactory
import java.net._
import java.nio.ByteBuffer
-import java.util.{Properties, Locale, Random, UUID}
import java.util.concurrent._
+import java.util.{Locale, Properties, Random, UUID}
import javax.net.ssl.HttpsURLConnection
import scala.collection.JavaConverters._
@@ -30,7 +30,7 @@ import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
import scala.reflect.ClassTag
-import scala.util.{Failure, Success, Try}
+import scala.util.Try
import scala.util.control.{ControlThrowable, NonFatal}
import com.google.common.io.{ByteStreams, Files}
@@ -42,7 +42,6 @@ import org.apache.hadoop.security.UserGroupInformation
import org.apache.log4j.PropertyConfigurator
import org.eclipse.jetty.util.MultiException
import org.json4s._
-
import tachyon.TachyonURI
import tachyon.client.{TachyonFS, TachyonFile}
@@ -2169,6 +2168,13 @@ private[spark] object Utils extends Logging {
val resource = createResource
try f.apply(resource) finally resource.close()
}
+
+ /**
+ * Returns a path of temporary file which is in the same directory with `path`.
+ */
+ def tempFileWith(path: File): File = {
+ new File(path.getAbsolutePath + "." + UUID.randomUUID())
+ }
}
/**
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index bd6844d045..2440139ac9 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -638,7 +638,6 @@ private[spark] class ExternalSorter[K, V, C](
* called by the SortShuffleWriter.
*
* @param blockId block ID to write to. The index file will be blockId.name + ".index".
- * @param context a TaskContext for a running Spark task, for us to update shuffle metrics.
* @return array of lengths, in bytes, of each partition of the file (used by map output tracker)
*/
def writePartitionedFile(
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
new file mode 100644
index 0000000000..0b19861fc4
--- /dev/null
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/IndexShuffleBlockResolverSuite.scala
@@ -0,0 +1,114 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.sort
+
+import java.io.{File, FileInputStream, FileOutputStream}
+
+import org.mockito.Answers.RETURNS_SMART_NULLS
+import org.mockito.Matchers._
+import org.mockito.Mockito._
+import org.mockito.invocation.InvocationOnMock
+import org.mockito.stubbing.Answer
+import org.mockito.{Mock, MockitoAnnotations}
+import org.scalatest.BeforeAndAfterEach
+
+import org.apache.spark.shuffle.IndexShuffleBlockResolver
+import org.apache.spark.storage._
+import org.apache.spark.util.Utils
+import org.apache.spark.{SparkConf, SparkFunSuite}
+
+
+class IndexShuffleBlockResolverSuite extends SparkFunSuite with BeforeAndAfterEach {
+
+ @Mock(answer = RETURNS_SMART_NULLS) private var blockManager: BlockManager = _
+ @Mock(answer = RETURNS_SMART_NULLS) private var diskBlockManager: DiskBlockManager = _
+
+ private var tempDir: File = _
+ private val conf: SparkConf = new SparkConf(loadDefaults = false)
+
+ override def beforeEach(): Unit = {
+ tempDir = Utils.createTempDir()
+ MockitoAnnotations.initMocks(this)
+
+ when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
+ when(diskBlockManager.getFile(any[BlockId])).thenAnswer(
+ new Answer[File] {
+ override def answer(invocation: InvocationOnMock): File = {
+ new File(tempDir, invocation.getArguments.head.toString)
+ }
+ })
+ }
+
+ override def afterEach(): Unit = {
+ Utils.deleteRecursively(tempDir)
+ }
+
+ test("commit shuffle files multiple times") {
+ val lengths = Array[Long](10, 0, 20)
+ val resolver = new IndexShuffleBlockResolver(conf, blockManager)
+ val dataTmp = File.createTempFile("shuffle", null, tempDir)
+ val out = new FileOutputStream(dataTmp)
+ out.write(new Array[Byte](30))
+ out.close()
+ resolver.writeIndexFileAndCommit(1, 2, lengths, dataTmp)
+
+ val dataFile = resolver.getDataFile(1, 2)
+ assert(dataFile.exists())
+ assert(dataFile.length() === 30)
+ assert(!dataTmp.exists())
+
+ val dataTmp2 = File.createTempFile("shuffle", null, tempDir)
+ val out2 = new FileOutputStream(dataTmp2)
+ val lengths2 = new Array[Long](3)
+ out2.write(Array[Byte](1))
+ out2.write(new Array[Byte](29))
+ out2.close()
+ resolver.writeIndexFileAndCommit(1, 2, lengths2, dataTmp2)
+ assert(lengths2.toSeq === lengths.toSeq)
+ assert(dataFile.exists())
+ assert(dataFile.length() === 30)
+ assert(!dataTmp2.exists())
+
+ // The dataFile should be the previous one
+ val in = new FileInputStream(dataFile)
+ val firstByte = new Array[Byte](1)
+ in.read(firstByte)
+ assert(firstByte(0) === 0)
+
+ // remove data file
+ dataFile.delete()
+
+ val dataTmp3 = File.createTempFile("shuffle", null, tempDir)
+ val out3 = new FileOutputStream(dataTmp3)
+ val lengths3 = Array[Long](10, 10, 15)
+ out3.write(Array[Byte](2))
+ out3.write(new Array[Byte](34))
+ out3.close()
+ resolver.writeIndexFileAndCommit(1, 2, lengths3, dataTmp3)
+ assert(lengths3.toSeq != lengths.toSeq)
+ assert(dataFile.exists())
+ assert(dataFile.length() === 35)
+ assert(!dataTmp2.exists())
+
+ // The dataFile should be the previous one
+ val in2 = new FileInputStream(dataFile)
+ val firstByte2 = new Array[Byte](1)
+ in2.read(firstByte2)
+ assert(firstByte2(0) === 2)
+ }
+}
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index 0e0eca515a..bc85918c59 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -130,7 +130,8 @@ public class UnsafeShuffleWriterSuite {
(Integer) args[3],
new CompressStream(),
false,
- (ShuffleWriteMetrics) args[4]
+ (ShuffleWriteMetrics) args[4],
+ (BlockId) args[0]
);
}
});
@@ -169,9 +170,13 @@ public class UnsafeShuffleWriterSuite {
@Override
public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
partitionSizesInMergedFile = (long[]) invocationOnMock.getArguments()[2];
+ File tmp = (File) invocationOnMock.getArguments()[3];
+ mergedOutputFile.delete();
+ tmp.renameTo(mergedOutputFile);
return null;
}
- }).when(shuffleBlockResolver).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+ }).when(shuffleBlockResolver)
+ .writeIndexFileAndCommit(anyInt(), anyInt(), any(long[].class), any(File.class));
when(diskBlockManager.createTempShuffleBlock()).thenAnswer(
new Answer<Tuple2<TempShuffleBlockId, File>>() {
diff --git a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
index 3bca790f30..d87a1d2a56 100644
--- a/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
+++ b/core/src/test/java/org/apache/spark/unsafe/map/AbstractBytesToBytesMapSuite.java
@@ -117,7 +117,8 @@ public abstract class AbstractBytesToBytesMapSuite {
(Integer) args[3],
new CompressStream(),
false,
- (ShuffleWriteMetrics) args[4]
+ (ShuffleWriteMetrics) args[4],
+ (BlockId) args[0]
);
}
});
diff --git a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
index 11c3a7be38..a1c9f6fab8 100644
--- a/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
+++ b/core/src/test/java/org/apache/spark/util/collection/unsafe/sort/UnsafeExternalSorterSuite.java
@@ -130,7 +130,8 @@ public class UnsafeExternalSorterSuite {
(Integer) args[3],
new CompressStream(),
false,
- (ShuffleWriteMetrics) args[4]
+ (ShuffleWriteMetrics) args[4],
+ (BlockId) args[0]
);
}
});
diff --git a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
index 4a0877d86f..0de10ae485 100644
--- a/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
+++ b/core/src/test/scala/org/apache/spark/ShuffleSuite.scala
@@ -17,12 +17,16 @@
package org.apache.spark
+import java.util.concurrent.{Callable, Executors, ExecutorService, CyclicBarrier}
+
import org.scalatest.Matchers
import org.apache.spark.ShuffleSuite.NonJavaSerializableClass
+import org.apache.spark.memory.TaskMemoryManager
import org.apache.spark.rdd.{CoGroupedRDD, OrderedRDDFunctions, RDD, ShuffledRDD, SubtractedRDD}
-import org.apache.spark.scheduler.{SparkListener, SparkListenerTaskEnd}
+import org.apache.spark.scheduler.{MyRDD, MapStatus, SparkListener, SparkListenerTaskEnd}
import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.shuffle.ShuffleWriter
import org.apache.spark.storage.{ShuffleDataBlockId, ShuffleBlockId}
import org.apache.spark.util.MutablePair
@@ -317,6 +321,107 @@ abstract class ShuffleSuite extends SparkFunSuite with Matchers with LocalSparkC
assert(metrics.bytesWritten === metrics.byresRead)
assert(metrics.bytesWritten > 0)
}
+
+ test("multiple simultaneous attempts for one task (SPARK-8029)") {
+ sc = new SparkContext("local", "test", conf)
+ val mapTrackerMaster = sc.env.mapOutputTracker.asInstanceOf[MapOutputTrackerMaster]
+ val manager = sc.env.shuffleManager
+
+ val taskMemoryManager = new TaskMemoryManager(sc.env.memoryManager, 0L)
+ val metricsSystem = sc.env.metricsSystem
+ val shuffleMapRdd = new MyRDD(sc, 1, Nil)
+ val shuffleDep = new ShuffleDependency(shuffleMapRdd, new HashPartitioner(1))
+ val shuffleHandle = manager.registerShuffle(0, 1, shuffleDep)
+
+ // first attempt -- its successful
+ val writer1 = manager.getWriter[Int, Int](shuffleHandle, 0,
+ new TaskContextImpl(0, 0, 0L, 0, taskMemoryManager, metricsSystem,
+ InternalAccumulator.create(sc)))
+ val data1 = (1 to 10).map { x => x -> x}
+
+ // second attempt -- also successful. We'll write out different data,
+ // just to simulate the fact that the records may get written differently
+ // depending on what gets spilled, what gets combined, etc.
+ val writer2 = manager.getWriter[Int, Int](shuffleHandle, 0,
+ new TaskContextImpl(0, 0, 1L, 0, taskMemoryManager, metricsSystem,
+ InternalAccumulator.create(sc)))
+ val data2 = (11 to 20).map { x => x -> x}
+
+ // interleave writes of both attempts -- we want to test that both attempts can occur
+ // simultaneously, and everything is still OK
+
+ def writeAndClose(
+ writer: ShuffleWriter[Int, Int])(
+ iter: Iterator[(Int, Int)]): Option[MapStatus] = {
+ val files = writer.write(iter)
+ writer.stop(true)
+ }
+ val interleaver = new InterleaveIterators(
+ data1, writeAndClose(writer1), data2, writeAndClose(writer2))
+ val (mapOutput1, mapOutput2) = interleaver.run()
+
+ // check that we can read the map output and it has the right data
+ assert(mapOutput1.isDefined)
+ assert(mapOutput2.isDefined)
+ assert(mapOutput1.get.location === mapOutput2.get.location)
+ assert(mapOutput1.get.getSizeForBlock(0) === mapOutput1.get.getSizeForBlock(0))
+
+ // register one of the map outputs -- doesn't matter which one
+ mapOutput1.foreach { case mapStatus =>
+ mapTrackerMaster.registerMapOutputs(0, Array(mapStatus))
+ }
+
+ val reader = manager.getReader[Int, Int](shuffleHandle, 0, 1,
+ new TaskContextImpl(1, 0, 2L, 0, taskMemoryManager, metricsSystem,
+ InternalAccumulator.create(sc)))
+ val readData = reader.read().toIndexedSeq
+ assert(readData === data1.toIndexedSeq || readData === data2.toIndexedSeq)
+
+ manager.unregisterShuffle(0)
+ }
+}
+
+/**
+ * Utility to help tests make sure that we can process two different iterators simultaneously
+ * in different threads. This makes sure that in your test, you don't completely process data1 with
+ * f1 before processing data2 with f2 (or vice versa). It adds a barrier so that the functions only
+ * process one element, before pausing to wait for the other function to "catch up".
+ */
+class InterleaveIterators[T, R](
+ data1: Seq[T],
+ f1: Iterator[T] => R,
+ data2: Seq[T],
+ f2: Iterator[T] => R) {
+
+ require(data1.size == data2.size)
+
+ val barrier = new CyclicBarrier(2)
+ class BarrierIterator[E](id: Int, sub: Iterator[E]) extends Iterator[E] {
+ def hasNext: Boolean = sub.hasNext
+
+ def next: E = {
+ barrier.await()
+ sub.next()
+ }
+ }
+
+ val c1 = new Callable[R] {
+ override def call(): R = f1(new BarrierIterator(1, data1.iterator))
+ }
+ val c2 = new Callable[R] {
+ override def call(): R = f2(new BarrierIterator(2, data2.iterator))
+ }
+
+ val e: ExecutorService = Executors.newFixedThreadPool(2)
+
+ def run(): (R, R) = {
+ val future1 = e.submit(c1)
+ val future2 = e.submit(c2)
+ val r1 = future1.get()
+ val r2 = future2.get()
+ e.shutdown()
+ (r1, r2)
+ }
}
object ShuffleSuite {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index b92a302806..d3b1b2b620 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -68,6 +68,17 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
+ doAnswer(new Answer[Void] {
+ def answer(invocationOnMock: InvocationOnMock): Void = {
+ val tmp: File = invocationOnMock.getArguments()(3).asInstanceOf[File]
+ if (tmp != null) {
+ outputFile.delete
+ tmp.renameTo(outputFile)
+ }
+ null
+ }
+ }).when(blockResolver)
+ .writeIndexFileAndCommit(anyInt, anyInt, any(classOf[Array[Long]]), any(classOf[File]))
when(blockManager.diskBlockManager).thenReturn(diskBlockManager)
when(blockManager.getDiskWriter(
any[BlockId],
@@ -84,7 +95,8 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
args(3).asInstanceOf[Int],
compressStream = identity,
syncWrites = false,
- args(4).asInstanceOf[ShuffleWriteMetrics]
+ args(4).asInstanceOf[ShuffleWriteMetrics],
+ blockId = args(0).asInstanceOf[BlockId]
)
}
})