aboutsummaryrefslogtreecommitdiff
path: root/core/src
diff options
context:
space:
mode:
Diffstat (limited to 'core/src')
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java4
-rw-r--r--core/src/main/scala/org/apache/spark/io/CompressionCodec.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala6
3 files changed, 9 insertions, 6 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index e19b378642..6a0a89e81c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -254,8 +254,8 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
final CompressionCodec compressionCodec = CompressionCodec$.MODULE$.createCodec(sparkConf);
final boolean fastMergeEnabled =
sparkConf.getBoolean("spark.shuffle.unsafe.fastMergeEnabled", true);
- final boolean fastMergeIsSupported =
- !compressionEnabled || compressionCodec instanceof LZFCompressionCodec;
+ final boolean fastMergeIsSupported = !compressionEnabled ||
+ CompressionCodec$.MODULE$.supportsConcatenationOfSerializedStreams(compressionCodec);
try {
if (spills.length == 0) {
new FileOutputStream(outputFile).close(); // Create an empty file
diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
index 9dc36704a6..ca74eedf89 100644
--- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
+++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala
@@ -47,6 +47,11 @@ trait CompressionCodec {
private[spark] object CompressionCodec {
private val configKey = "spark.io.compression.codec"
+
+ private[spark] def supportsConcatenationOfSerializedStreams(codec: CompressionCodec): Boolean = {
+ codec.isInstanceOf[SnappyCompressionCodec] || codec.isInstanceOf[LZFCompressionCodec]
+ }
+
private val shortCompressionCodecNames = Map(
"lz4" -> classOf[LZ4CompressionCodec].getName,
"lzf" -> classOf[LZFCompressionCodec].getName,
diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
index cbdb33c89d..1553ab60bd 100644
--- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
+++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala
@@ -100,12 +100,10 @@ class CompressionCodecSuite extends SparkFunSuite {
testCodec(codec)
}
- test("snappy does not support concatenation of serialized streams") {
+ test("snappy supports concatenation of serialized streams") {
val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName)
assert(codec.getClass === classOf[SnappyCompressionCodec])
- intercept[Exception] {
- testConcatenationOfSerializedStreams(codec)
- }
+ testConcatenationOfSerializedStreams(codec)
}
test("bad compression codec") {