aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorJosh Rosen <joshrosen@databricks.com>2016-03-16 22:52:55 -0700
committerReynold Xin <rxin@databricks.com>2016-03-16 22:52:55 -0700
commitde1a84e56e81347cb0d1ec67cc86944ea98bb9a9 (patch)
treea5a577ebb2049d55c46682161b95594ba2537201 /core
parentd1c193a2f1a5e2b98f5df1b86d7a7ec0ced13668 (diff)
downloadspark-de1a84e56e81347cb0d1ec67cc86944ea98bb9a9.tar.gz
spark-de1a84e56e81347cb0d1ec67cc86944ea98bb9a9.tar.bz2
spark-de1a84e56e81347cb0d1ec67cc86944ea98bb9a9.zip
[SPARK-13926] Automatically use Kryo serializer when shuffling RDDs with simple types
Because ClassTags are available when constructing ShuffledRDD we can use them to automatically use Kryo for shuffle serialization when the RDD's types are known to be compatible with Kryo. This patch introduces `SerializerManager`, a component which picks the "best" serializer for a shuffle given the elements' ClassTags. It will automatically pick a Kryo serializer for ShuffledRDDs whose key, value, and/or combiner types are primitives, arrays of primitives, or strings. In the future we can use this class as a narrow extension point to integrate specialized serializers for other types, such as ByteBuffers. In a planned followup patch, I will extend the BlockManager APIs so that we're able to use similar automatic serializer selection when caching RDDs (this is a little trickier because the ClassTags need to be threaded through many more places). Author: Josh Rosen <joshrosen@databricks.com> Closes #11755 from JoshRosen/automatically-pick-best-serializer.
Diffstat (limited to 'core')
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java2
-rw-r--r--core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java2
-rw-r--r--core/src/main/scala/org/apache/spark/Dependency.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/Serializer.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala71
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala5
-rw-r--r--core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java2
-rw-r--r--core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala8
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala34
18 files changed, 127 insertions, 70 deletions
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
index dc4f289ae7..052be54d8c 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriter.java
@@ -115,7 +115,7 @@ final class BypassMergeSortShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.partitioner = dep.partitioner();
this.numPartitions = partitioner.numPartitions();
this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics();
- this.serializer = Serializer.getSerializer(dep.serializer());
+ this.serializer = dep.serializer();
this.shuffleBlockResolver = shuffleBlockResolver;
}
diff --git a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
index 3f4402bd3a..cd06ce9fb9 100644
--- a/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
+++ b/core/src/main/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriter.java
@@ -116,7 +116,7 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
this.mapId = mapId;
final ShuffleDependency<K, V, V> dep = handle.dependency();
this.shuffleId = dep.shuffleId();
- this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
+ this.serializer = dep.serializer().newInstance();
this.partitioner = dep.partitioner();
this.writeMetrics = taskContext.taskMetrics().registerShuffleWriteMetrics();
this.taskContext = taskContext;
diff --git a/core/src/main/scala/org/apache/spark/Dependency.scala b/core/src/main/scala/org/apache/spark/Dependency.scala
index b65cfdc4df..ca52ecafa2 100644
--- a/core/src/main/scala/org/apache/spark/Dependency.scala
+++ b/core/src/main/scala/org/apache/spark/Dependency.scala
@@ -59,9 +59,9 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
*
* @param _rdd the parent RDD
* @param partitioner partitioner used to partition the shuffle output
- * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If set to None,
- * the default serializer, as specified by `spark.serializer` config option, will
- * be used.
+ * @param serializer [[org.apache.spark.serializer.Serializer Serializer]] to use. If not set
+ * explicitly then the default serializer, as specified by `spark.serializer`
+ * config option, will be used.
* @param keyOrdering key ordering for RDD's shuffles
* @param aggregator map/reduce-side aggregator for RDD's shuffle
* @param mapSideCombine whether to perform partial aggregation (also known as map-side combine)
@@ -70,7 +70,7 @@ abstract class NarrowDependency[T](_rdd: RDD[T]) extends Dependency[T] {
class ShuffleDependency[K: ClassTag, V: ClassTag, C: ClassTag](
@transient private val _rdd: RDD[_ <: Product2[K, V]],
val partitioner: Partitioner,
- val serializer: Option[Serializer] = None,
+ val serializer: Serializer = SparkEnv.get.serializer,
val keyOrdering: Option[Ordering[K]] = None,
val aggregator: Option[Aggregator[K, V, C]] = None,
val mapSideCombine: Boolean = false)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index b3b3729625..668a913a20 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -35,7 +35,7 @@ import org.apache.spark.network.netty.NettyBlockTransferService
import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
-import org.apache.spark.serializer.{JavaSerializer, Serializer}
+import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.storage._
import org.apache.spark.util.{RpcUtils, Utils}
@@ -56,6 +56,7 @@ class SparkEnv (
private[spark] val rpcEnv: RpcEnv,
val serializer: Serializer,
val closureSerializer: Serializer,
+ val serializerManager: SerializerManager,
val mapOutputTracker: MapOutputTracker,
val shuffleManager: ShuffleManager,
val broadcastManager: BroadcastManager,
@@ -276,6 +277,8 @@ object SparkEnv extends Logging {
"spark.serializer", "org.apache.spark.serializer.JavaSerializer")
logDebug(s"Using serializer: ${serializer.getClass}")
+ val serializerManager = new SerializerManager(serializer, conf)
+
val closureSerializer = new JavaSerializer(conf)
def registerOrLookupEndpoint(
@@ -368,6 +371,7 @@ object SparkEnv extends Logging {
rpcEnv,
serializer,
closureSerializer,
+ serializerManager,
mapOutputTracker,
shuffleManager,
broadcastManager,
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index d9b0824b38..e5ebc63082 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -86,11 +86,11 @@ class CoGroupedRDD[K: ClassTag](
private type CoGroupValue = (Any, Int) // Int is dependency number
private type CoGroupCombiner = Array[CoGroup]
- private var serializer: Option[Serializer] = None
+ private var serializer: Serializer = SparkEnv.get.serializer
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): CoGroupedRDD[K] = {
- this.serializer = Option(serializer)
+ this.serializer = serializer
this
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index 3ef506e156..800b42505d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -44,7 +44,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
part: Partitioner)
extends RDD[(K, C)](prev.context, Nil) {
- private var serializer: Option[Serializer] = None
+ private var userSpecifiedSerializer: Option[Serializer] = None
private var keyOrdering: Option[Ordering[K]] = None
@@ -54,7 +54,7 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
/** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
def setSerializer(serializer: Serializer): ShuffledRDD[K, V, C] = {
- this.serializer = Option(serializer)
+ this.userSpecifiedSerializer = Option(serializer)
this
}
@@ -77,6 +77,14 @@ class ShuffledRDD[K: ClassTag, V: ClassTag, C: ClassTag](
}
override def getDependencies: Seq[Dependency[_]] = {
+ val serializer = userSpecifiedSerializer.getOrElse {
+ val serializerManager = SparkEnv.get.serializerManager
+ if (mapSideCombine) {
+ serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[C]])
+ } else {
+ serializerManager.getSerializer(implicitly[ClassTag[K]], implicitly[ClassTag[V]])
+ }
+ }
List(new ShuffleDependency(prev, part, serializer, keyOrdering, aggregator, mapSideCombine))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 25ec685eff..a733eaa5d7 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -30,7 +30,6 @@ import org.apache.spark.Partitioner
import org.apache.spark.ShuffleDependency
import org.apache.spark.SparkEnv
import org.apache.spark.TaskContext
-import org.apache.spark.serializer.Serializer
/**
* An optimized version of cogroup for set difference/subtraction.
@@ -54,13 +53,6 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
part: Partitioner)
extends RDD[(K, V)](rdd1.context, Nil) {
- private var serializer: Option[Serializer] = None
-
- /** Set a serializer for this RDD's shuffle, or null to use the default (spark.serializer) */
- def setSerializer(serializer: Serializer): SubtractedRDD[K, V, W] = {
- this.serializer = Option(serializer)
- this
- }
override def getDependencies: Seq[Dependency[_]] = {
def rddDependency[T1: ClassTag, T2: ClassTag](rdd: RDD[_ <: Product2[T1, T2]])
@@ -70,7 +62,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
new OneToOneDependency(rdd)
} else {
logDebug("Adding shuffle dependency with " + rdd)
- new ShuffleDependency[T1, T2, Any](rdd, part, serializer)
+ new ShuffleDependency[T1, T2, Any](rdd, part)
}
}
Seq(rddDependency[K, V](rdd1), rddDependency[K, W](rdd2))
diff --git a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
index 95bdf0ce2d..5ead40e89e 100644
--- a/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/Serializer.scala
@@ -100,18 +100,6 @@ abstract class Serializer {
}
-@DeveloperApi
-object Serializer {
- def getSerializer(serializer: Serializer): Serializer = {
- if (serializer == null) SparkEnv.get.serializer else serializer
- }
-
- def getSerializer(serializer: Option[Serializer]): Serializer = {
- serializer.getOrElse(SparkEnv.get.serializer)
- }
-}
-
-
/**
* :: DeveloperApi ::
* An instance of a serializer, for use by one thread at a time.
diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
new file mode 100644
index 0000000000..b9f115463a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala
@@ -0,0 +1,71 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.serializer
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.SparkConf
+
+/**
+ * Component that selects which [[Serializer]] to use for shuffles.
+ */
+private[spark] class SerializerManager(defaultSerializer: Serializer, conf: SparkConf) {
+
+ private[this] val kryoSerializer = new KryoSerializer(conf)
+
+ private[this] val primitiveAndPrimitiveArrayClassTags: Set[ClassTag[_]] = {
+ val primitiveClassTags = Set[ClassTag[_]](
+ ClassTag.Boolean,
+ ClassTag.Byte,
+ ClassTag.Char,
+ ClassTag.Double,
+ ClassTag.Float,
+ ClassTag.Int,
+ ClassTag.Long,
+ ClassTag.Null,
+ ClassTag.Short
+ )
+ val arrayClassTags = primitiveClassTags.map(_.wrap)
+ primitiveClassTags ++ arrayClassTags
+ }
+
+ private[this] val stringClassTag: ClassTag[String] = implicitly[ClassTag[String]]
+
+ private def canUseKryo(ct: ClassTag[_]): Boolean = {
+ primitiveAndPrimitiveArrayClassTags.contains(ct) || ct == stringClassTag
+ }
+
+ def getSerializer(ct: ClassTag[_]): Serializer = {
+ if (canUseKryo(ct)) {
+ kryoSerializer
+ } else {
+ defaultSerializer
+ }
+ }
+
+ /**
+ * Pick the best serializer for shuffling an RDD of key-value pairs.
+ */
+ def getSerializer(keyClassTag: ClassTag[_], valueClassTag: ClassTag[_]): Serializer = {
+ if (canUseKryo(keyClassTag) && canUseKryo(valueClassTag)) {
+ kryoSerializer
+ } else {
+ defaultSerializer
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
index dc182f5963..69183d9936 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/BlockStoreShuffleReader.scala
@@ -54,8 +54,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
blockManager.wrapForCompression(blockId, inputStream)
}
- val ser = Serializer.getSerializer(dep.serializer)
- val serializerInstance = ser.newInstance()
+ val serializerInstance = dep.serializer.newInstance()
// Create a key/value iterator for each stream
val recordIter = wrappedStreams.flatMap { wrappedStream =>
@@ -100,7 +99,7 @@ private[spark] class BlockStoreShuffleReader[K, C](
// Create an ExternalSorter to sort the data. Note that if spark.shuffle.spill is disabled,
// the ExternalSorter won't spill to disk.
val sorter =
- new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = Some(ser))
+ new ExternalSorter[K, C, C](context, ordering = Some(keyOrd), serializer = dep.serializer)
sorter.insertAll(aggregatedIter)
context.taskMetrics().incMemoryBytesSpilled(sorter.memoryBytesSpilled)
context.taskMetrics().incDiskBytesSpilled(sorter.diskBytesSpilled)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 7694e950be..22b31994e7 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -21,7 +21,6 @@ import java.io.IOException
import org.apache.spark._
import org.apache.spark.scheduler.MapStatus
-import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
import org.apache.spark.storage.DiskBlockObjectWriter
@@ -44,9 +43,8 @@ private[spark] class HashShuffleWriter[K, V](
private val writeMetrics = metrics.registerShuffleWriteMetrics()
private val blockManager = SparkEnv.get.blockManager
- private val ser = Serializer.getSerializer(dep.serializer.getOrElse(null))
- private val shuffle = shuffleBlockResolver.forMapTask(dep.shuffleId, mapId, numOutputSplits, ser,
- writeMetrics)
+ private val shuffle = shuffleBlockResolver.forMapTask(dep.shuffleId, mapId, numOutputSplits,
+ dep.serializer, writeMetrics)
/** Write a bunch of records to this task's output */
override def write(records: Iterator[Product2[K, V]]): Unit = {
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index 9b1a279528..f7744d12c5 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -20,7 +20,6 @@ package org.apache.spark.shuffle.sort
import java.util.concurrent.ConcurrentHashMap
import org.apache.spark._
-import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle._
/**
@@ -186,10 +185,9 @@ private[spark] object SortShuffleManager extends Logging {
def canUseSerializedShuffle(dependency: ShuffleDependency[_, _, _]): Boolean = {
val shufId = dependency.shuffleId
val numPartitions = dependency.partitioner.numPartitions
- val serializer = Serializer.getSerializer(dependency.serializer)
- if (!serializer.supportsRelocationOfSerializedObjects) {
+ if (!dependency.serializer.supportsRelocationOfSerializedObjects) {
log.debug(s"Can't use serialized shuffle for shuffle $shufId because the serializer, " +
- s"${serializer.getClass.getName}, does not support object relocation")
+ s"${dependency.serializer.getClass.getName}, does not support object relocation")
false
} else if (dependency.aggregator.isDefined) {
log.debug(
diff --git a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
index 5afd6d6e22..4bcdcb0774 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/ExternalSorter.scala
@@ -91,7 +91,7 @@ private[spark] class ExternalSorter[K, V, C](
aggregator: Option[Aggregator[K, V, C]] = None,
partitioner: Option[Partitioner] = None,
ordering: Option[Ordering[K]] = None,
- serializer: Option[Serializer] = None)
+ serializer: Serializer = SparkEnv.get.serializer)
extends Logging
with Spillable[WritablePartitionedPairCollection[K, C]] {
@@ -107,8 +107,7 @@ private[spark] class ExternalSorter[K, V, C](
private val blockManager = SparkEnv.get.blockManager
private val diskBlockManager = blockManager.diskBlockManager
- private val ser = Serializer.getSerializer(serializer)
- private val serInstance = ser.newInstance()
+ private val serInstance = serializer.newInstance()
// Use getSizeAsKb (not bytes) to maintain backwards compatibility if no units are provided
private val fileBufferSize = conf.getSizeAsKb("spark.shuffle.file.buffer", "32k").toInt * 1024
diff --git a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
index ddea6f5a69..47c695ad4e 100644
--- a/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
+++ b/core/src/test/java/org/apache/spark/shuffle/sort/UnsafeShuffleWriterSuite.java
@@ -191,7 +191,7 @@ public class UnsafeShuffleWriterSuite {
});
when(taskContext.taskMetrics()).thenReturn(taskMetrics);
- when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(serializer));
+ when(shuffleDep.serializer()).thenReturn(serializer);
when(shuffleDep.partitioner()).thenReturn(hashPartitioner);
}
diff --git a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
index 26a372d6a9..08f52c92e1 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/BlockStoreShuffleReaderSuite.scala
@@ -127,7 +127,7 @@ class BlockStoreShuffleReaderSuite extends SparkFunSuite with LocalSparkContext
// Create a mocked shuffle handle to pass into HashShuffleReader.
val shuffleHandle = {
val dependency = mock(classOf[ShuffleDependency[Int, Int, Int]])
- when(dependency.serializer).thenReturn(Some(serializer))
+ when(dependency.serializer).thenReturn(serializer)
when(dependency.aggregator).thenReturn(None)
when(dependency.keyOrdering).thenReturn(None)
new BaseShuffleHandle(shuffleId, numMaps, dependency)
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
index cf9f9da1e6..16418f855b 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/BypassMergeSortShuffleWriterSuite.scala
@@ -66,7 +66,7 @@ class BypassMergeSortShuffleWriterSuite extends SparkFunSuite with BeforeAndAfte
dependency = dependency
)
when(dependency.partitioner).thenReturn(new HashPartitioner(7))
- when(dependency.serializer).thenReturn(Some(new JavaSerializer(conf)))
+ when(dependency.serializer).thenReturn(new JavaSerializer(conf))
when(taskContext.taskMetrics()).thenReturn(taskMetrics)
when(blockResolver.getDataFile(0, 0)).thenReturn(outputFile)
doAnswer(new Answer[Void] {
diff --git a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
index 8744a072cb..55cebe7c8b 100644
--- a/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/shuffle/sort/SortShuffleManagerSuite.scala
@@ -41,7 +41,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
private def shuffleDep(
partitioner: Partitioner,
- serializer: Option[Serializer],
+ serializer: Serializer,
keyOrdering: Option[Ordering[Any]],
aggregator: Option[Aggregator[Any, Any, Any]],
mapSideCombine: Boolean): ShuffleDependency[Any, Any, Any] = {
@@ -56,7 +56,7 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
}
test("supported shuffle dependencies for serialized shuffle") {
- val kryo = Some(new KryoSerializer(new SparkConf()))
+ val kryo = new KryoSerializer(new SparkConf())
assert(canUseSerializedShuffle(shuffleDep(
partitioner = new HashPartitioner(2),
@@ -88,8 +88,8 @@ class SortShuffleManagerSuite extends SparkFunSuite with Matchers {
}
test("unsupported shuffle dependencies for serialized shuffle") {
- val kryo = Some(new KryoSerializer(new SparkConf()))
- val java = Some(new JavaSerializer(new SparkConf()))
+ val kryo = new KryoSerializer(new SparkConf())
+ val java = new JavaSerializer(new SparkConf())
// We only support serializers that support object relocation
assert(!canUseSerializedShuffle(shuffleDep(
diff --git a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
index a62adf1c2c..a1a7ac97d9 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/ExternalSorterSuite.scala
@@ -110,7 +110,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
createCombiner _, mergeValue _, mergeCombiners _)
val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
- context, Some(agg), None, None, None)
+ context, Some(agg), None, None)
val collisionPairs = Seq(
("Aa", "BB"), // 2112
@@ -161,7 +161,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
sc = new SparkContext("local-cluster[1,1,1024]", "test", conf)
val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val agg = new Aggregator[FixedHashObject, Int, Int](_ => 1, _ + _, _ + _)
- val sorter = new ExternalSorter[FixedHashObject, Int, Int](context, Some(agg), None, None, None)
+ val sorter = new ExternalSorter[FixedHashObject, Int, Int](context, Some(agg), None, None)
// Insert 10 copies each of lots of objects whose hash codes are either 0 or 1. This causes
// problems if the map fails to group together the objects with the same code (SPARK-2043).
val toInsert = for (i <- 1 to 10; j <- 1 to size) yield (FixedHashObject(j, j % 2), 1)
@@ -192,7 +192,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val agg = new Aggregator[Int, Int, ArrayBuffer[Int]](createCombiner, mergeValue, mergeCombiners)
val sorter =
- new ExternalSorter[Int, Int, ArrayBuffer[Int]](context, Some(agg), None, None, None)
+ new ExternalSorter[Int, Int, ArrayBuffer[Int]](context, Some(agg), None, None)
sorter.insertAll(
(1 to size).iterator.map(i => (i, i)) ++ Iterator((Int.MaxValue, Int.MaxValue)))
assert(sorter.numSpills > 0, "sorter did not spill")
@@ -219,7 +219,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
createCombiner, mergeValue, mergeCombiners)
val sorter = new ExternalSorter[String, String, ArrayBuffer[String]](
- context, Some(agg), None, None, None)
+ context, Some(agg), None, None)
sorter.insertAll((1 to size).iterator.map(i => (i.toString, i.toString)) ++ Iterator(
(null.asInstanceOf[String], "1"),
@@ -283,25 +283,25 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
// Both aggregator and ordering
val sorter = new ExternalSorter[Int, Int, Int](
- context, Some(agg), Some(new HashPartitioner(3)), Some(ord), None)
+ context, Some(agg), Some(new HashPartitioner(3)), Some(ord))
assert(sorter.iterator.toSeq === Seq())
sorter.stop()
// Only aggregator
val sorter2 = new ExternalSorter[Int, Int, Int](
- context, Some(agg), Some(new HashPartitioner(3)), None, None)
+ context, Some(agg), Some(new HashPartitioner(3)), None)
assert(sorter2.iterator.toSeq === Seq())
sorter2.stop()
// Only ordering
val sorter3 = new ExternalSorter[Int, Int, Int](
- context, None, Some(new HashPartitioner(3)), Some(ord), None)
+ context, None, Some(new HashPartitioner(3)), Some(ord))
assert(sorter3.iterator.toSeq === Seq())
sorter3.stop()
// Neither aggregator nor ordering
val sorter4 = new ExternalSorter[Int, Int, Int](
- context, None, Some(new HashPartitioner(3)), None, None)
+ context, None, Some(new HashPartitioner(3)), None)
assert(sorter4.iterator.toSeq === Seq())
sorter4.stop()
}
@@ -320,28 +320,28 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
// Both aggregator and ordering
val sorter = new ExternalSorter[Int, Int, Int](
- context, Some(agg), Some(new HashPartitioner(7)), Some(ord), None)
+ context, Some(agg), Some(new HashPartitioner(7)), Some(ord))
sorter.insertAll(elements.iterator)
assert(sorter.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
sorter.stop()
// Only aggregator
val sorter2 = new ExternalSorter[Int, Int, Int](
- context, Some(agg), Some(new HashPartitioner(7)), None, None)
+ context, Some(agg), Some(new HashPartitioner(7)), None)
sorter2.insertAll(elements.iterator)
assert(sorter2.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
sorter2.stop()
// Only ordering
val sorter3 = new ExternalSorter[Int, Int, Int](
- context, None, Some(new HashPartitioner(7)), Some(ord), None)
+ context, None, Some(new HashPartitioner(7)), Some(ord))
sorter3.insertAll(elements.iterator)
assert(sorter3.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
sorter3.stop()
// Neither aggregator nor ordering
val sorter4 = new ExternalSorter[Int, Int, Int](
- context, None, Some(new HashPartitioner(7)), None, None)
+ context, None, Some(new HashPartitioner(7)), None)
sorter4.insertAll(elements.iterator)
assert(sorter4.partitionedIterator.map(p => (p._1, p._2.toSet)).toSet === expected)
sorter4.stop()
@@ -358,7 +358,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val elements = Iterator((1, 1), (5, 5)) ++ (0 until size).iterator.map(x => (2, 2))
val sorter = new ExternalSorter[Int, Int, Int](
- context, None, Some(new HashPartitioner(7)), Some(ord), None)
+ context, None, Some(new HashPartitioner(7)), Some(ord))
sorter.insertAll(elements)
assert(sorter.numSpills > 0, "sorter did not spill")
val iter = sorter.partitionedIterator.map(p => (p._1, p._2.toList))
@@ -442,7 +442,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val expectedSize = if (withFailures) size - 1 else size
val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val sorter = new ExternalSorter[Int, Int, Int](
- context, None, Some(new HashPartitioner(3)), Some(ord), None)
+ context, None, Some(new HashPartitioner(3)), Some(ord))
if (withFailures) {
intercept[SparkException] {
sorter.insertAll((0 until size).iterator.map { i =>
@@ -512,7 +512,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val ord = if (withOrdering) Some(implicitly[Ordering[Int]]) else None
val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val sorter =
- new ExternalSorter[Int, Int, Int](context, agg, Some(new HashPartitioner(3)), ord, None)
+ new ExternalSorter[Int, Int, Int](context, agg, Some(new HashPartitioner(3)), ord)
sorter.insertAll((0 until size).iterator.map { i => (i / 4, i) })
if (withSpilling) {
assert(sorter.numSpills > 0, "sorter did not spill")
@@ -551,7 +551,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
val context = MemoryTestingUtils.fakeTaskContext(sc.env)
val sorter1 = new ExternalSorter[String, String, String](
- context, None, None, Some(wrongOrdering), None)
+ context, None, None, Some(wrongOrdering))
val thrown = intercept[IllegalArgumentException] {
sorter1.insertAll(testData.iterator.map(i => (i, i)))
assert(sorter1.numSpills > 0, "sorter did not spill")
@@ -573,7 +573,7 @@ class ExternalSorterSuite extends SparkFunSuite with LocalSparkContext {
createCombiner, mergeValue, mergeCombiners)
val sorter2 = new ExternalSorter[String, String, ArrayBuffer[String]](
- context, Some(agg), None, None, None)
+ context, Some(agg), None, None)
sorter2.insertAll(testData.iterator.map(i => (i, i)))
assert(sorter2.numSpills > 0, "sorter did not spill")