aboutsummaryrefslogtreecommitdiff
path: root/core
diff options
context:
space:
mode:
authorTathagata Das <tathagata.das1565@gmail.com>2013-12-19 12:05:08 -0800
committerTathagata Das <tathagata.das1565@gmail.com>2013-12-19 12:05:08 -0800
commitde41c436a0088efc83bc4705dcd279e61b085759 (patch)
treed725da59145e243c1b1934567dbb61c7ddb2abcb /core
parent03ef6e889929befa968d35fa3757c687edc3a38b (diff)
parentec71b445ad0440e84c4b4909e4faf75aba0f13d7 (diff)
downloadspark-de41c436a0088efc83bc4705dcd279e61b085759.tar.gz
spark-de41c436a0088efc83bc4705dcd279e61b085759.tar.bz2
spark-de41c436a0088efc83bc4705dcd279e61b085759.zip
Merge branch 'scheduler-update' into window-improvement
Conflicts: streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
Diffstat (limited to 'core')
-rw-r--r--core/pom.xml47
-rw-r--r--core/src/main/java/org/apache/spark/network/netty/FileClient.java2
-rw-r--r--core/src/main/java/org/apache/spark/network/netty/FileServer.java1
-rw-r--r--core/src/main/scala/org/apache/spark/FutureAction.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/MapOutputTracker.scala37
-rw-r--r--core/src/main/scala/org/apache/spark/Partitioner.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala288
-rw-r--r--core/src/main/scala/org/apache/spark/SparkEnv.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/TaskState.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala49
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala77
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala32
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala60
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java1
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function.java8
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function2.java8
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/Function3.java8
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java12
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java12
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala159
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala28
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/client/Client.scala48
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/Master.scala69
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala48
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/executor/Executor.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala82
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionManager.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala127
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala12
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala41
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala33
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala8
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala14
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala120
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala57
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala6
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala404
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala5
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala13
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala4
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala29
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManager.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala16
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StorageLevel.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala17
-rw-r--r--core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala67
-rw-r--r--core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/AkkaUtils.scala79
-rw-r--r--core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala93
-rw-r--r--core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala68
-rw-r--r--core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala29
-rw-r--r--core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala94
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala3
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala118
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala7
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/AccumulatorSuite.scala32
-rw-r--r--core/src/test/scala/org/apache/spark/CheckpointSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/DistributedSuite.scala5
-rw-r--r--core/src/test/scala/org/apache/spark/DriverSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/JavaAPISuite.java14
-rw-r--r--core/src/test/scala/org/apache/spark/JobCancellationSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/LocalSparkContext.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala14
-rw-r--r--core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala45
-rw-r--r--core/src/test/scala/org/apache/spark/PartitioningSuite.scala10
-rw-r--r--core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala140
-rw-r--r--core/src/test/scala/org/apache/spark/UnpersistSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala17
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala26
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala271
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala86
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala8
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala43
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala4
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala31
-rw-r--r--core/src/test/scala/org/apache/spark/ui/UISuite.scala1
-rw-r--r--core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala72
-rw-r--r--core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala76
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala33
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala37
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala (renamed from core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala)31
149 files changed, 2721 insertions, 1263 deletions
diff --git a/core/pom.xml b/core/pom.xml
index 8621d257e5..043f6cf68d 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -26,7 +26,7 @@
</parent>
<groupId>org.apache.spark</groupId>
- <artifactId>spark-core_2.9.3</artifactId>
+ <artifactId>spark-core_2.10</artifactId>
<packaging>jar</packaging>
<name>Spark Project Core</name>
<url>http://spark.incubator.apache.org/</url>
@@ -81,12 +81,8 @@
<artifactId>asm</artifactId>
</dependency>
<dependency>
- <groupId>com.google.protobuf</groupId>
- <artifactId>protobuf-java</artifactId>
- </dependency>
- <dependency>
<groupId>com.twitter</groupId>
- <artifactId>chill_2.9.3</artifactId>
+ <artifactId>chill_${scala.binary.version}</artifactId>
<version>0.3.1</version>
</dependency>
<dependency>
@@ -95,20 +91,12 @@
<version>0.3.1</version>
</dependency>
<dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-actor</artifactId>
- </dependency>
- <dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-remote</artifactId>
+ <groupId>${akka.group}</groupId>
+ <artifactId>akka-remote_${scala.binary.version}</artifactId>
</dependency>
<dependency>
- <groupId>com.typesafe.akka</groupId>
- <artifactId>akka-slf4j</artifactId>
- </dependency>
- <dependency>
- <groupId>org.scala-lang</groupId>
- <artifactId>scalap</artifactId>
+ <groupId>${akka.group}</groupId>
+ <artifactId>akka-slf4j_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>org.scala-lang</groupId>
@@ -116,7 +104,7 @@
</dependency>
<dependency>
<groupId>net.liftweb</groupId>
- <artifactId>lift-json_2.9.2</artifactId>
+ <artifactId>lift-json_${scala.binary.version}</artifactId>
</dependency>
<dependency>
<groupId>it.unimi.dsi</groupId>
@@ -127,10 +115,6 @@
<artifactId>colt</artifactId>
</dependency>
<dependency>
- <groupId>com.github.scala-incubator.io</groupId>
- <artifactId>scala-io-file_2.9.2</artifactId>
- </dependency>
- <dependency>
<groupId>org.apache.mesos</groupId>
<artifactId>mesos</artifactId>
</dependency>
@@ -159,18 +143,27 @@
<artifactId>metrics-ganglia</artifactId>
</dependency>
<dependency>
+ <groupId>com.codahale.metrics</groupId>
+ <artifactId>metrics-graphite</artifactId>
+ </dependency>
+ <dependency>
<groupId>org.apache.derby</groupId>
<artifactId>derby</artifactId>
<scope>test</scope>
</dependency>
<dependency>
+ <groupId>commons-io</groupId>
+ <artifactId>commons-io</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
<groupId>org.scalatest</groupId>
- <artifactId>scalatest_2.9.3</artifactId>
+ <artifactId>scalatest_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
<groupId>org.scalacheck</groupId>
- <artifactId>scalacheck_2.9.3</artifactId>
+ <artifactId>scalacheck_${scala.binary.version}</artifactId>
<scope>test</scope>
</dependency>
<dependency>
@@ -190,8 +183,8 @@
</dependency>
</dependencies>
<build>
- <outputDirectory>target/scala-${scala.version}/classes</outputDirectory>
- <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory>
+ <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory>
+ <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory>
<plugins>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java
index 20a7a3aa8c..edd0fc56f8 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileClient.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileClient.java
@@ -19,8 +19,6 @@ package org.apache.spark.network.netty;
import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
-import io.netty.channel.ChannelFuture;
-import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
import io.netty.channel.socket.oio.OioSocketChannel;
diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java
index 666432474d..a99af348ce 100644
--- a/core/src/main/java/org/apache/spark/network/netty/FileServer.java
+++ b/core/src/main/java/org/apache/spark/network/netty/FileServer.java
@@ -20,7 +20,6 @@ package org.apache.spark.network.netty;
import java.net.InetSocketAddress;
import io.netty.bootstrap.ServerBootstrap;
-import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelOption;
import io.netty.channel.oio.OioEventLoopGroup;
diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala
index 1ad9240cfa..c6b4ac5192 100644
--- a/core/src/main/scala/org/apache/spark/FutureAction.scala
+++ b/core/src/main/scala/org/apache/spark/FutureAction.scala
@@ -99,7 +99,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc:
override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = {
if (!atMost.isFinite()) {
awaitResult()
- } else {
+ } else jobWaiter.synchronized {
val finishTime = System.currentTimeMillis() + atMost.toMillis
while (!isCompleted) {
val time = System.currentTimeMillis()
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 5e465fa22c..10fae5af9f 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -21,12 +21,11 @@ import java.io._
import java.util.zip.{GZIPInputStream, GZIPOutputStream}
import scala.collection.mutable.HashSet
+import scala.concurrent.Await
+import scala.concurrent.duration._
import akka.actor._
-import akka.dispatch._
import akka.pattern.ask
-import akka.util.Duration
-
import org.apache.spark.scheduler.MapStatus
import org.apache.spark.storage.BlockManagerId
@@ -55,9 +54,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
private[spark] class MapOutputTracker extends Logging {
private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
-
+
// Set to the MapOutputTrackerActor living on the driver
- var trackerActor: ActorRef = _
+ var trackerActor: Either[ActorRef, ActorSelection] = _
protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]]
@@ -73,8 +72,18 @@ private[spark] class MapOutputTracker extends Logging {
// throw a SparkException if this fails.
private def askTracker(message: Any): Any = {
try {
- val future = trackerActor.ask(message)(timeout)
- return Await.result(future, timeout)
+ /*
+ The difference between ActorRef and ActorSelection is well explained here:
+ http://doc.akka.io/docs/akka/2.2.3/project/migration-guide-2.1.x-2.2.x.html#Use_actorSelection_instead_of_actorFor
+ In spark a map output tracker can be either started on Driver where it is created which
+ is an ActorRef or it can be on executor from where it is looked up which is an
+ actorSelection.
+ */
+ val future = trackerActor match {
+ case Left(a: ActorRef) => a.ask(message)(timeout)
+ case Right(b: ActorSelection) => b.ask(message)(timeout)
+ }
+ Await.result(future, timeout)
} catch {
case e: Exception =>
throw new SparkException("Error communicating with MapOutputTracker", e)
@@ -117,7 +126,7 @@ private[spark] class MapOutputTracker extends Logging {
fetching += shuffleId
}
}
-
+
if (fetchedStatuses == null) {
// We won the race to fetch the output locs; do so
logInfo("Doing the fetch; tracker actor = " + trackerActor)
@@ -144,7 +153,7 @@ private[spark] class MapOutputTracker extends Logging {
else{
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing all output locations for shuffle " + shuffleId))
- }
+ }
} else {
statuses.synchronized {
return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses)
@@ -244,12 +253,12 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
case Some(bytes) =>
return bytes
case None =>
- statuses = mapStatuses(shuffleId)
+ statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]())
epochGotten = epoch
}
}
// If we got here, we failed to find the serialized locations in the cache, so we pulled
- // out a snapshot of the locations as "locs"; let's serialize and return that
+ // out a snapshot of the locations as "statuses"; let's serialize and return that
val bytes = MapOutputTracker.serializeMapStatuses(statuses)
logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length))
// Add them into the table only if the epoch hasn't changed while we were working
@@ -274,6 +283,10 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker {
override def updateEpoch(newEpoch: Long) {
// This might be called on the MapOutputTrackerMaster if we're running in local mode.
}
+
+ def has(shuffleId: Int): Boolean = {
+ cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId)
+ }
}
private[spark] object MapOutputTracker {
@@ -308,7 +321,7 @@ private[spark] object MapOutputTracker {
statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = {
assert (statuses != null)
statuses.map {
- status =>
+ status =>
if (status == null) {
throw new FetchFailedException(null, shuffleId, -1, reduceId,
new Exception("Missing an output location for shuffle " + shuffleId))
diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala
index 0e2c987a59..bcec41c439 100644
--- a/core/src/main/scala/org/apache/spark/Partitioner.scala
+++ b/core/src/main/scala/org/apache/spark/Partitioner.scala
@@ -17,8 +17,10 @@
package org.apache.spark
-import org.apache.spark.util.Utils
+import scala.reflect.ClassTag
+
import org.apache.spark.rdd.RDD
+import org.apache.spark.util.Utils
/**
* An object that defines how the elements in a key-value pair RDD are partitioned by key.
@@ -72,7 +74,7 @@ class HashPartitioner(partitions: Int) extends Partitioner {
case null => 0
case _ => Utils.nonNegativeMod(key.hashCode, numPartitions)
}
-
+
override def equals(other: Any): Boolean = other match {
case h: HashPartitioner =>
h.numPartitions == numPartitions
@@ -85,7 +87,7 @@ class HashPartitioner(partitions: Int) extends Partitioner {
* A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges.
* Determines the ranges by sampling the RDD passed in.
*/
-class RangePartitioner[K <% Ordered[K]: ClassManifest, V](
+class RangePartitioner[K <% Ordered[K]: ClassTag, V](
partitions: Int,
@transient rdd: RDD[_ <: Product2[K,V]],
private val ascending: Boolean = true)
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 42b2985b50..a0f794edfd 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -26,6 +26,7 @@ import scala.collection.Map
import scala.collection.generic.Growable
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
+import scala.reflect.{ClassTag, classTag}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path
@@ -81,7 +82,7 @@ class SparkContext(
val sparkHome: String = null,
val jars: Seq[String] = Nil,
val environment: Map[String, String] = Map(),
- // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc)
+ // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc)
// too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set
// of data-local splits on host
val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] =
@@ -153,98 +154,11 @@ class SparkContext(
executorEnvs("SPARK_USER") = sparkUser
// Create and start the scheduler
- private[spark] var taskScheduler: TaskScheduler = {
- // Regular expression used for local[N] master format
- val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
- // Regular expression for local[N, maxRetries], used in tests with failing tasks
- val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
- // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
- val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
- // Regular expression for connecting to Spark deploy clusters
- val SPARK_REGEX = """spark://(.*)""".r
- // Regular expression for connection to Mesos cluster
- val MESOS_REGEX = """mesos://(.*)""".r
- // Regular expression for connection to Simr cluster
- val SIMR_REGEX = """simr://(.*)""".r
-
- master match {
- case "local" =>
- new LocalScheduler(1, 0, this)
-
- case LOCAL_N_REGEX(threads) =>
- new LocalScheduler(threads.toInt, 0, this)
-
- case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
- new LocalScheduler(threads.toInt, maxFailures.toInt, this)
-
- case SPARK_REGEX(sparkUrl) =>
- val scheduler = new ClusterScheduler(this)
- val masterUrls = sparkUrl.split(",").map("spark://" + _)
- val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
- scheduler.initialize(backend)
- scheduler
-
- case SIMR_REGEX(simrUrl) =>
- val scheduler = new ClusterScheduler(this)
- val backend = new SimrSchedulerBackend(scheduler, this, simrUrl)
- scheduler.initialize(backend)
- scheduler
-
- case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
- // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
- val memoryPerSlaveInt = memoryPerSlave.toInt
- if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) {
- throw new SparkException(
- "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
- memoryPerSlaveInt, SparkContext.executorMemoryRequested))
- }
-
- val scheduler = new ClusterScheduler(this)
- val localCluster = new LocalSparkCluster(
- numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
- val masterUrls = localCluster.start()
- val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName)
- scheduler.initialize(backend)
- backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
- localCluster.stop()
- }
- scheduler
-
- case "yarn-standalone" =>
- val scheduler = try {
- val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
- val cons = clazz.getConstructor(classOf[SparkContext])
- cons.newInstance(this).asInstanceOf[ClusterScheduler]
- } catch {
- // TODO: Enumerate the exact reasons why it can fail
- // But irrespective of it, it means we cannot proceed !
- case th: Throwable => {
- throw new SparkException("YARN mode not available ?", th)
- }
- }
- val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem)
- scheduler.initialize(backend)
- scheduler
-
- case MESOS_REGEX(mesosUrl) =>
- MesosNativeLibrary.load()
- val scheduler = new ClusterScheduler(this)
- val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
- val backend = if (coarseGrained) {
- new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName)
- } else {
- new MesosSchedulerBackend(scheduler, this, mesosUrl, appName)
- }
- scheduler.initialize(backend)
- scheduler
-
- case _ =>
- throw new SparkException("Could not parse Master URL: '" + master + "'")
- }
- }
+ private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master, appName)
taskScheduler.start()
@volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler)
+ dagScheduler.start()
ui.start()
@@ -354,19 +268,19 @@ class SparkContext(
// Methods for creating RDDs
/** Distribute a local Scala collection to form an RDD. */
- def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
/** Distribute a local Scala collection to form an RDD. */
- def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
parallelize(seq, numSlices)
}
/** Distribute a local Scala collection to form an RDD, with one or more
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
- def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -419,7 +333,7 @@ class SparkContext(
}
/**
- * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
@@ -427,17 +341,17 @@ class SparkContext(
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int)
- (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F])
+ (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F])
: RDD[(K, V)] = {
hadoopFile(path,
- fm.erasure.asInstanceOf[Class[F]],
- km.erasure.asInstanceOf[Class[K]],
- vm.erasure.asInstanceOf[Class[V]],
+ fm.runtimeClass.asInstanceOf[Class[F]],
+ km.runtimeClass.asInstanceOf[Class[K]],
+ vm.runtimeClass.asInstanceOf[Class[V]],
minSplits)
}
/**
- * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys,
+ * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys,
* values and the InputFormat so that users don't need to pass them directly. Instead, callers
* can just write, for example,
* {{{
@@ -445,17 +359,17 @@ class SparkContext(
* }}}
*/
def hadoopFile[K, V, F <: InputFormat[K, V]](path: String)
- (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] =
+ (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] =
hadoopFile[K, V, F](path, defaultMinSplits)
/** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */
def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String)
- (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = {
+ (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = {
newAPIHadoopFile(
path,
- fm.erasure.asInstanceOf[Class[F]],
- km.erasure.asInstanceOf[Class[K]],
- vm.erasure.asInstanceOf[Class[V]])
+ fm.runtimeClass.asInstanceOf[Class[F]],
+ km.runtimeClass.asInstanceOf[Class[K]],
+ vm.runtimeClass.asInstanceOf[Class[V]])
}
/**
@@ -513,11 +427,11 @@ class SparkContext(
* IntWritable). The most natural thing would've been to have implicit objects for the
* converters, but then we couldn't have an object for every subclass of Writable (you can't
* have a parameterized singleton object). We use functions instead to create a new converter
- * for the appropriate type. In addition, we pass the converter a ClassManifest of its type to
+ * for the appropriate type. In addition, we pass the converter a ClassTag of its type to
* allow it to figure out the Writable class to use in the subclass case.
*/
def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits)
- (implicit km: ClassManifest[K], vm: ClassManifest[V],
+ (implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
val kc = kcf()
@@ -536,7 +450,7 @@ class SparkContext(
* slow if you use the default serializer (Java serialization), though the nice thing about it is
* that there's very little effort required to save arbitrary objects.
*/
- def objectFile[T: ClassManifest](
+ def objectFile[T: ClassTag](
path: String,
minSplits: Int = defaultMinSplits
): RDD[T] = {
@@ -545,17 +459,17 @@ class SparkContext(
}
- protected[spark] def checkpointFile[T: ClassManifest](
+ protected[spark] def checkpointFile[T: ClassTag](
path: String
): RDD[T] = {
new CheckpointRDD[T](this, path)
}
/** Build the union of a list of RDDs. */
- def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
+ def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds)
/** Build the union of a list of RDDs passed as variable-length arguments. */
- def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] =
+ def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] =
new UnionRDD(this, Seq(first) ++ rest)
// Methods for creating shared variables
@@ -798,7 +712,7 @@ class SparkContext(
* flag specifies whether the scheduler can run the computation on the driver rather than
* shipping it out to the cluster, for short actions like first().
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
@@ -819,7 +733,7 @@ class SparkContext(
* allowLocal flag specifies whether the scheduler can run the computation on the driver rather
* than shipping it out to the cluster, for short actions like first().
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
@@ -834,7 +748,7 @@ class SparkContext(
* Run a job on a given set of partitions of an RDD, but take a function of type
* `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`.
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: Iterator[T] => U,
partitions: Seq[Int],
@@ -846,21 +760,21 @@ class SparkContext(
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
- def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
+ def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = {
runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
* Run a job on all partitions in an RDD and return the results in an array.
*/
- def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
+ def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = {
runJob(rdd, func, 0 until rdd.partitions.size, false)
}
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
processPartition: (TaskContext, Iterator[T]) => U,
resultHandler: (Int, U) => Unit)
@@ -871,7 +785,7 @@ class SparkContext(
/**
* Run a job on all partitions in an RDD and pass the results to a handler function.
*/
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
processPartition: Iterator[T] => U,
resultHandler: (Int, U) => Unit)
@@ -1017,16 +931,16 @@ object SparkContext {
// TODO: Add AccumulatorParams for other types, e.g. lists and strings
- implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) =
+ implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) =
new PairRDDFunctions(rdd)
- implicit def rddToAsyncRDDActions[T: ClassManifest](rdd: RDD[T]) = new AsyncRDDActions(rdd)
+ implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd)
- implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest](
+ implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag](
rdd: RDD[(K, V)]) =
new SequenceFileRDDFunctions(rdd)
- implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest](
+ implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassTag, V: ClassTag](
rdd: RDD[(K, V)]) =
new OrderedRDDFunctions[K, V, (K, V)](rdd)
@@ -1051,16 +965,16 @@ object SparkContext {
implicit def stringToText(s: String) = new Text(s)
- private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = {
+ private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]): ArrayWritable = {
def anyToWritable[U <% Writable](u: U): Writable = u
- new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]],
+ new ArrayWritable(classTag[T].runtimeClass.asInstanceOf[Class[Writable]],
arr.map(x => anyToWritable(x)).toArray)
}
// Helper objects for converting common types to Writable
- private def simpleWritableConverter[T, W <: Writable: ClassManifest](convert: W => T) = {
- val wClass = classManifest[W].erasure.asInstanceOf[Class[W]]
+ private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) = {
+ val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]]
new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W]))
}
@@ -1079,7 +993,7 @@ object SparkContext {
implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString)
implicit def writableWritableConverter[T <: Writable]() =
- new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T])
+ new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T])
/**
* Find the JAR from which a given class was loaded, to make it easy for users to pass
@@ -1111,17 +1025,135 @@ object SparkContext {
.map(Utils.memoryStringToMb)
.getOrElse(512)
}
+
+ // Creates a task scheduler based on a given master URL. Extracted for testing.
+ private
+ def createTaskScheduler(sc: SparkContext, master: String, appName: String): TaskScheduler = {
+ // Regular expression used for local[N] master format
+ val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r
+ // Regular expression for local[N, maxRetries], used in tests with failing tasks
+ val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r
+ // Regular expression for simulating a Spark cluster of [N, cores, memory] locally
+ val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r
+ // Regular expression for connecting to Spark deploy clusters
+ val SPARK_REGEX = """spark://(.*)""".r
+ // Regular expression for connection to Mesos cluster by mesos:// or zk:// url
+ val MESOS_REGEX = """(mesos|zk)://.*""".r
+ // Regular expression for connection to Simr cluster
+ val SIMR_REGEX = """simr://(.*)""".r
+
+ master match {
+ case "local" =>
+ new LocalScheduler(1, 0, sc)
+
+ case LOCAL_N_REGEX(threads) =>
+ new LocalScheduler(threads.toInt, 0, sc)
+
+ case LOCAL_N_FAILURES_REGEX(threads, maxFailures) =>
+ new LocalScheduler(threads.toInt, maxFailures.toInt, sc)
+
+ case SPARK_REGEX(sparkUrl) =>
+ val scheduler = new ClusterScheduler(sc)
+ val masterUrls = sparkUrl.split(",").map("spark://" + _)
+ val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName)
+ scheduler.initialize(backend)
+ scheduler
+
+ case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) =>
+ // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang.
+ val memoryPerSlaveInt = memoryPerSlave.toInt
+ if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) {
+ throw new SparkException(
+ "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format(
+ memoryPerSlaveInt, SparkContext.executorMemoryRequested))
+ }
+
+ val scheduler = new ClusterScheduler(sc)
+ val localCluster = new LocalSparkCluster(
+ numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt)
+ val masterUrls = localCluster.start()
+ val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName)
+ scheduler.initialize(backend)
+ backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => {
+ localCluster.stop()
+ }
+ scheduler
+
+ case "yarn-standalone" =>
+ val scheduler = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(sc).asInstanceOf[ClusterScheduler]
+ } catch {
+ // TODO: Enumerate the exact reasons why it can fail
+ // But irrespective of it, it means we cannot proceed !
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+ val backend = new CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ scheduler.initialize(backend)
+ scheduler
+
+ case "yarn-client" =>
+ val scheduler = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(sc).asInstanceOf[ClusterScheduler]
+
+ } catch {
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+
+ val backend = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
+ val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext])
+ cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend]
+ } catch {
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+
+ scheduler.initialize(backend)
+ scheduler
+
+ case mesosUrl @ MESOS_REGEX(_) =>
+ MesosNativeLibrary.load()
+ val scheduler = new ClusterScheduler(sc)
+ val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean
+ val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs
+ val backend = if (coarseGrained) {
+ new CoarseMesosSchedulerBackend(scheduler, sc, url, appName)
+ } else {
+ new MesosSchedulerBackend(scheduler, sc, url, appName)
+ }
+ scheduler.initialize(backend)
+ scheduler
+
+ case SIMR_REGEX(simrUrl) =>
+ val scheduler = new ClusterScheduler(sc)
+ val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl)
+ scheduler.initialize(backend)
+ scheduler
+
+ case _ =>
+ throw new SparkException("Could not parse Master URL: '" + master + "'")
+ }
+ }
}
/**
* A class encapsulating how to convert some type T to Writable. It stores both the Writable class
* corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion.
- * The getter for the writable class takes a ClassManifest[T] in case this is a generic object
+ * The getter for the writable class takes a ClassTag[T] in case this is a generic object
* that doesn't know the type of T when it is created. This sounds strange but is necessary to
* support converting subclasses of Writable to themselves (writableWritableConverter).
*/
private[spark] class WritableConverter[T](
- val writableClass: ClassManifest[T] => Class[_ <: Writable],
+ val writableClass: ClassTag[T] => Class[_ <: Writable],
val convert: Writable => T)
extends Serializable
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index ff2df8fb6a..826f5c2d8c 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -20,7 +20,7 @@ package org.apache.spark
import collection.mutable
import serializer.Serializer
-import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem}
+import akka.actor._
import akka.remote.RemoteActorRefProvider
import org.apache.spark.broadcast.BroadcastManager
@@ -74,7 +74,8 @@ class SparkEnv (
actorSystem.shutdown()
// Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut
// down, but let's call it anyway in case it gets fixed in a later release
- actorSystem.awaitTermination()
+ // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it.
+ //actorSystem.awaitTermination()
}
def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = {
@@ -151,17 +152,17 @@ object SparkEnv extends Logging {
val closureSerializer = serializerManager.get(
System.getProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"))
- def registerOrLookup(name: String, newActor: => Actor): ActorRef = {
+ def registerOrLookup(name: String, newActor: => Actor): Either[ActorRef, ActorSelection] = {
if (isDriver) {
logInfo("Registering " + name)
- actorSystem.actorOf(Props(newActor), name = name)
+ Left(actorSystem.actorOf(Props(newActor), name = name))
} else {
val driverHost: String = System.getProperty("spark.driver.host", "localhost")
val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt
Utils.checkHost(driverHost, "Expected hostname")
- val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
+ val url = "akka.tcp://spark@%s:%s/user/%s".format(driverHost, driverPort, name)
logInfo("Connecting to " + name + ": " + url)
- actorSystem.actorFor(url)
+ Right(actorSystem.actorSelection(url))
}
}
diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala
index 19ce8369d9..0bf1e4a5e2 100644
--- a/core/src/main/scala/org/apache/spark/TaskState.scala
+++ b/core/src/main/scala/org/apache/spark/TaskState.scala
@@ -19,8 +19,7 @@ package org.apache.spark
import org.apache.mesos.Protos.{TaskState => MesosTaskState}
-private[spark] object TaskState
- extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") {
+private[spark] object TaskState extends Enumeration {
val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 043cb183ba..da30cf619a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -17,18 +17,23 @@
package org.apache.spark.api.java
+import scala.reflect.ClassTag
+
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.util.StatCounter
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.storage.StorageLevel
+
import java.lang.Double
import org.apache.spark.Partitioner
+import scala.collection.JavaConverters._
+
class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
- override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]]
+ override val classTag: ClassTag[Double] = implicitly[ClassTag[Double]]
override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x))
@@ -42,7 +47,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaDoubleRDD = fromRDD(srdd.cache())
- /**
+ /**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
@@ -106,7 +111,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/**
* Return an RDD with the elements from `this` that are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
@@ -182,6 +187,44 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0,50) [50,100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = {
+ val result = srdd.histogram(bucketCount)
+ (result._1, result._2)
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1,10,20,50] the buckets are [1,10) [10,20) [20,50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1,0,0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[scala.Double]): Array[Long] = {
+ srdd.histogram(buckets, false)
+ }
+
+ def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = {
+ srdd.histogram(buckets.map(_.toDouble), evenBuckets)
+ }
}
object JavaDoubleRDD {
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
index 2142fd7327..363667fa86 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala
@@ -22,6 +22,7 @@ import java.util.Comparator
import scala.Tuple2
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
@@ -43,13 +44,13 @@ import org.apache.spark.rdd.OrderedRDDFunctions
import org.apache.spark.storage.StorageLevel
-class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K],
- implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
+class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K],
+ implicit val vClassTag: ClassTag[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] {
override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd)
- override val classManifest: ClassManifest[(K, V)] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]]
+ override val classTag: ClassTag[(K, V)] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]]
import JavaPairRDD._
@@ -58,7 +59,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/** Persist this RDD with the default storage level (`MEMORY_ONLY`). */
def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache())
- /**
+ /**
* Set this RDD's storage level to persist its values across operations after the first time
* it is computed. Can only be called once on each RDD.
*/
@@ -138,14 +139,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
override def first(): (K, V) = rdd.first()
// Pair RDD functions
-
+
/**
- * Generic function to combine the elements for each key using a custom set of aggregation
- * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
- * "combined type" C * Note that V and C can be different -- for example, one might group an
- * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
+ * Generic function to combine the elements for each key using a custom set of aggregation
+ * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a
+ * "combined type" C * Note that V and C can be different -- for example, one might group an
+ * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three
* functions:
- *
+ *
* - `createCombiner`, which turns a V into a C (e.g., creates a one-element list)
* - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list)
* - `mergeCombiners`, to combine two C's into a single one.
@@ -157,8 +158,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C],
partitioner: Partitioner): JavaPairRDD[K, C] = {
- implicit val cm: ClassManifest[C] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
fromRDD(rdd.combineByKey(
createCombiner,
mergeValue,
@@ -195,14 +195,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/** Count the number of elements for each key, and return the result to the master as a Map. */
def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey())
- /**
+ /**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] =
rdd.countByKeyApprox(timeout).map(mapAsJavaMap)
- /**
+ /**
* (Experimental) Approximate version of countByKey that can return a partial result if it does
* not finish within a timeout.
*/
@@ -258,7 +258,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
/**
* Return an RDD with the elements from `this` that are not in `other`.
- *
+ *
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
@@ -315,15 +315,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)})
}
- /**
+ /**
* Simplified version of combineByKey that hash-partitions the resulting RDD using the existing
* partitioner/parallelism level.
*/
def combineByKey[C](createCombiner: JFunction[V, C],
mergeValue: JFunction2[C, V, C],
mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = {
- implicit val cm: ClassManifest[C] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]]
+ implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]]
fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd)))
}
@@ -414,8 +413,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
* this also retains the original RDD's partitioning.
*/
def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = {
- implicit val cm: ClassManifest[U] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
fromRDD(rdd.mapValues(f))
}
@@ -426,8 +424,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = {
import scala.collection.JavaConverters._
def fn = (x: V) => f.apply(x).asScala
- implicit val cm: ClassManifest[U] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]]
+ implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]]
fromRDD(rdd.flatMapValues(fn))
}
@@ -592,6 +589,20 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
}
/**
+ * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling
+ * `collect` or `save` on the resulting RDD will return or output an ordered list of records
+ * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in
+ * order of the keys).
+ */
+ def sortByKey(comp: Comparator[K], ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = {
+ class KeyOrdering(val a: K) extends Ordered[K] {
+ override def compare(b: K) = comp.compare(a, b)
+ }
+ implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x)
+ fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending, numPartitions))
+ }
+
+ /**
* Return an RDD with the keys of each tuple.
*/
def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1))
@@ -603,22 +614,22 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif
}
object JavaPairRDD {
- def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K],
- vcm: ClassManifest[T]): RDD[(K, JList[T])] =
+ def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassTag[K],
+ vcm: ClassTag[T]): RDD[(K, JList[T])] =
rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _)
- def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K],
- vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V],
- Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
+ def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassTag[K],
+ vcm: ClassTag[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd)
+ .mapValues((x: (Seq[V], Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2)))
def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1],
- Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1],
+ Seq[W2]))])(implicit kcm: ClassTag[K]) : RDD[(K, (JList[V], JList[W1],
JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues(
(x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1),
seqAsJavaList(x._2),
seqAsJavaList(x._3)))
- def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
+ def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] =
new JavaPairRDD[K, V](rdd)
implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd
@@ -626,10 +637,8 @@ object JavaPairRDD {
/** Convert a JavaRDD of key-value pairs to JavaPairRDD. */
def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = {
- implicit val cmk: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val cmv: ClassManifest[V] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ implicit val cmk: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
new JavaPairRDD[K, V](rdd.rdd)
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
index 3b359a8fd6..c47657f512 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala
@@ -17,12 +17,14 @@
package org.apache.spark.api.java
+import scala.reflect.ClassTag
+
import org.apache.spark._
import org.apache.spark.rdd.RDD
import org.apache.spark.api.java.function.{Function => JFunction}
import org.apache.spark.storage.StorageLevel
-class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends
+class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) extends
JavaRDDLike[T, JavaRDD[T]] {
override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd)
@@ -127,8 +129,7 @@ JavaRDDLike[T, JavaRDD[T]] {
object JavaRDD {
- implicit def fromRDD[T: ClassManifest](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
+ implicit def fromRDD[T: ClassTag](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd)
implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd
}
-
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index 7a3568c5ef..9e912d3adb 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -20,6 +20,7 @@ package org.apache.spark.api.java
import java.util.{List => JList, Comparator}
import scala.Tuple2
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.io.compress.CompressionCodec
@@ -35,7 +36,7 @@ import org.apache.spark.storage.StorageLevel
trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def wrapRDD(rdd: RDD[T]): This
- implicit val classManifest: ClassManifest[T]
+ implicit val classTag: ClassTag[T]
def rdd: RDD[T]
@@ -71,7 +72,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
- def mapPartitionsWithIndex[R: ClassManifest](
+ def mapPartitionsWithIndex[R: ClassTag](
f: JFunction2[Int, java.util.Iterator[T], java.util.Iterator[R]],
preservesPartitioning: Boolean = false): JavaRDD[R] =
new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))),
@@ -87,7 +88,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return a new RDD by applying a function to all elements of this RDD.
*/
def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]]
new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType())
}
@@ -118,7 +119,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = {
import scala.collection.JavaConverters._
def fn = (x: T) => f.apply(x).asScala
- def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]]
+ def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]]
JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType())
}
@@ -158,18 +159,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* elements (a, b) where a is in `this` and b is in `other`.
*/
def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] =
- JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest,
- other.classManifest)
+ JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classTag))(classTag, other.classTag)
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = {
- implicit val kcm: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val vcm: ClassManifest[JList[T]] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val vcm: ClassTag[JList[T]] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]]
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm)
}
@@ -178,10 +177,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* mapping to that key.
*/
def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = {
- implicit val kcm: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val vcm: ClassManifest[JList[T]] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val vcm: ClassTag[JList[T]] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]]
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm)
}
@@ -209,7 +207,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* a map on the other).
*/
def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = {
- JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest)
+ JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag)
}
/**
@@ -224,7 +222,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator(
f.apply(asJavaIterator(x), asJavaIterator(y)).iterator())
JavaRDD.fromRDD(
- rdd.zipPartitions(other.rdd)(fn)(other.classManifest, f.elementType()))(f.elementType())
+ rdd.zipPartitions(other.rdd)(fn)(other.classTag, f.elementType()))(f.elementType())
}
// Actions (launch a job to return a value to the user program)
@@ -356,7 +354,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Creates tuples of the elements in this RDD by applying `f`.
*/
def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
- implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index 8869e072bf..acf328aa6a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -21,6 +21,7 @@ import java.util.{Map => JMap}
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.mapred.InputFormat
@@ -82,8 +83,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)
}
@@ -94,10 +94,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
/** Distribute a local Scala collection to form an RDD. */
def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int)
: JavaPairRDD[K, V] = {
- implicit val kcm: ClassManifest[K] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]]
- implicit val vcm: ClassManifest[V] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]]
+ implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]]
+ implicit val vcm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]]
JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices))
}
@@ -132,16 +130,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits))
}
/**Get an RDD for a Hadoop SequenceFile. */
def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]):
JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass))
}
@@ -153,8 +151,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
sc.objectFile(path, minSplits)(cm)
}
@@ -166,8 +163,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
* that there's very little effort required to save arbitrary objects.
*/
def objectFile[T](path: String): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
sc.objectFile(path)(cm)
}
@@ -183,8 +179,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits))
}
@@ -199,8 +195,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
keyClass: Class[K],
valueClass: Class[V]
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass))
}
@@ -212,8 +208,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
valueClass: Class[V],
minSplits: Int
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits))
}
@@ -224,8 +220,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
keyClass: Class[K],
valueClass: Class[V]
): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(keyClass)
- implicit val vcm = ClassManifest.fromClass(valueClass)
+ implicit val kcm: ClassTag[K] = ClassTag(keyClass)
+ implicit val vcm: ClassTag[V] = ClassTag(valueClass)
new JavaPairRDD(sc.hadoopFile(path,
inputFormatClass, keyClass, valueClass))
}
@@ -240,8 +236,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
kClass: Class[K],
vClass: Class[V],
conf: Configuration): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(kClass)
- implicit val vcm = ClassManifest.fromClass(vClass)
+ implicit val kcm: ClassTag[K] = ClassTag(kClass)
+ implicit val vcm: ClassTag[V] = ClassTag(vClass)
new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf))
}
@@ -254,15 +250,15 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): JavaPairRDD[K, V] = {
- implicit val kcm = ClassManifest.fromClass(kClass)
- implicit val vcm = ClassManifest.fromClass(vClass)
+ implicit val kcm: ClassTag[K] = ClassTag(kClass)
+ implicit val vcm: ClassTag[V] = ClassTag(vClass)
new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass))
}
/** Build the union of two or more RDDs. */
override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = {
val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
- implicit val cm: ClassManifest[T] = first.classManifest
+ implicit val cm: ClassTag[T] = first.classTag
sc.union(rdds)(cm)
}
@@ -270,9 +266,9 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]])
: JavaPairRDD[K, V] = {
val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd)
- implicit val cm: ClassManifest[(K, V)] = first.classManifest
- implicit val kcm: ClassManifest[K] = first.kManifest
- implicit val vcm: ClassManifest[V] = first.vManifest
+ implicit val cm: ClassTag[(K, V)] = first.classTag
+ implicit val kcm: ClassTag[K] = first.kClassTag
+ implicit val vcm: ClassTag[V] = first.vClassTag
new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm)
}
@@ -405,8 +401,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork
}
protected def checkpointFile[T](path: String): JavaRDD[T] = {
- implicit val cm: ClassManifest[T] =
- implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]]
+ implicit val cm: ClassTag[T] =
+ implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]]
new JavaRDD(sc.checkpointFile(path))
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
index c9cbce5624..2090efd3b9 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java
@@ -17,7 +17,6 @@
package org.apache.spark.api.java;
-import java.util.Arrays;
import java.util.ArrayList;
import java.util.List;
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
index 2dfda8b09a..bdb01f7670 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala
@@ -17,9 +17,11 @@
package org.apache.spark.api.java.function
+import scala.reflect.ClassTag
+
/**
* A function that returns zero or more output records from each input record.
*/
abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] {
- def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]]
+ def elementType(): ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
index 528e1c0a7c..aae1349c5e 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala
@@ -17,9 +17,11 @@
package org.apache.spark.api.java.function
+import scala.reflect.ClassTag
+
/**
* A function that takes two inputs and returns zero or more output records.
*/
abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] {
- def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]]
+ def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]]
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
index ce368ee01b..537439ef53 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java
@@ -17,8 +17,8 @@
package org.apache.spark.api.java.function;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -29,8 +29,8 @@ import java.io.Serializable;
* when mapping RDDs of other types.
*/
public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable {
- public ClassManifest<R> returnType() {
- return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<R> returnType() {
+ return ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
index 44ad559d48..a2d1214fb4 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java
@@ -17,8 +17,8 @@
package org.apache.spark.api.java.function;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -28,8 +28,8 @@ import java.io.Serializable;
public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R>
implements Serializable {
- public ClassManifest<R> returnType() {
- return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<R> returnType() {
+ return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
index ac6178924a..fb1deceab5 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java
@@ -17,8 +17,8 @@
package org.apache.spark.api.java.function;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import scala.runtime.AbstractFunction2;
import java.io.Serializable;
@@ -29,8 +29,8 @@ import java.io.Serializable;
public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R>
implements Serializable {
- public ClassManifest<R> returnType() {
- return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<R> returnType() {
+ return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
index 6d76a8f970..ca485b3cc2 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java
@@ -18,8 +18,8 @@
package org.apache.spark.api.java.function;
import scala.Tuple2;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -33,11 +33,11 @@ public abstract class PairFlatMapFunction<T, K, V>
extends WrappedFunction1<T, Iterable<Tuple2<K, V>>>
implements Serializable {
- public ClassManifest<K> keyType() {
- return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<K> keyType() {
+ return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}
- public ClassManifest<V> valueType() {
- return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<V> valueType() {
+ return (ClassTag<V>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
index ede7ceefb5..cbe2306026 100644
--- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
+++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java
@@ -18,8 +18,8 @@
package org.apache.spark.api.java.function;
import scala.Tuple2;
-import scala.reflect.ClassManifest;
-import scala.reflect.ClassManifest$;
+import scala.reflect.ClassTag;
+import scala.reflect.ClassTag$;
import java.io.Serializable;
@@ -31,11 +31,11 @@ import java.io.Serializable;
public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>>
implements Serializable {
- public ClassManifest<K> keyType() {
- return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<K> keyType() {
+ return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class);
}
- public ClassManifest<V> valueType() {
- return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class);
+ public ClassTag<V> valueType() {
+ return (ClassTag<V>) ClassTag$.MODULE$.apply(Object.class);
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 12b4d94a56..a659cc06c2 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -22,18 +22,17 @@ import java.net._
import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
import scala.collection.JavaConversions._
+import scala.reflect.ClassTag
import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark._
import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.PipedRDD
import org.apache.spark.util.Utils
-
-private[spark] class PythonRDD[T: ClassManifest](
+private[spark] class PythonRDD[T: ClassTag](
parent: RDD[T],
- command: Seq[String],
+ command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
preservePartitoning: Boolean,
@@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest](
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- // Similar to Runtime.exec(), if we are given a single string, split it into words
- // using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, envVars: JMap[String, String],
- pythonIncludes: JList[String],
- preservePartitoning: Boolean, pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
- accumulator: Accumulator[JList[Array[Byte]]]) =
- this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
- broadcastVars, accumulator)
-
override def getPartitions = parent.partitions
override val partitioner = if (preservePartitoning) parent.partitioner else None
-
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
@@ -71,11 +59,10 @@ private[spark] class PythonRDD[T: ClassManifest](
SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
- val printOut = new PrintWriter(stream)
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
- PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+ dataOut.writeUTF(SparkFiles.getRootDirectory)
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
@@ -85,21 +72,16 @@ private[spark] class PythonRDD[T: ClassManifest](
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
- for (f <- pythonIncludes) {
- PythonRDD.writeAsPickle(f, dataOut)
- }
+ pythonIncludes.foreach(dataOut.writeUTF)
dataOut.flush()
- // Serialized user code
- for (elem <- command) {
- printOut.println(elem)
- }
- printOut.flush()
+ // Serialized command:
+ dataOut.writeInt(command.length)
+ dataOut.write(command)
// Data values
for (elem <- parent.iterator(split, context)) {
- PythonRDD.writeAsPickle(elem, dataOut)
+ PythonRDD.writeToStream(elem, dataOut)
}
dataOut.flush()
- printOut.flush()
worker.shutdownOutput()
} catch {
case e: IOException =>
@@ -132,7 +114,7 @@ private[spark] class PythonRDD[T: ClassManifest](
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
- case -3 =>
+ case SpecialLengths.TIMING_DATA =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
@@ -143,30 +125,30 @@ private[spark] class PythonRDD[T: ClassManifest](
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
read
- case -2 =>
+ case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj))
- case -1 =>
+ case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
- // read some accumulator updates; let's do that, breaking when we
- // get a negative length record.
- var len2 = stream.readInt()
- while (len2 >= 0) {
- val update = new Array[Byte](len2)
+ // read some accumulator updates:
+ val numAccumulatorUpdates = stream.readInt()
+ (1 to numAccumulatorUpdates).foreach { _ =>
+ val updateLen = stream.readInt()
+ val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator += Collections.singletonList(update)
- len2 = stream.readInt()
+
}
- new Array[Byte](0)
+ Array.empty[Byte]
}
} catch {
case eof: EOFException => {
throw new SparkException("Python worker exited unexpectedly (crashed)", eof)
}
- case e => throw e
+ case e: Throwable => throw e
}
}
@@ -197,62 +179,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}
-private[spark] object PythonRDD {
-
- /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
- def stripPickle(arr: Array[Byte]) : Array[Byte] = {
- arr.slice(2, arr.length - 1)
- }
+private object SpecialLengths {
+ val END_OF_DATA_SECTION = -1
+ val PYTHON_EXCEPTION_THROWN = -2
+ val TIMING_DATA = -3
+}
- /**
- * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
- * The data format is a 32-bit integer representing the pickled object's length (in bytes),
- * followed by the pickled data.
- *
- * Pickle module:
- *
- * http://docs.python.org/2/library/pickle.html
- *
- * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
- *
- * http://hg.python.org/cpython/file/2.6/Lib/pickle.py
- * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
- *
- * @param elem the object to write
- * @param dOut a data output stream
- */
- def writeAsPickle(elem: Any, dOut: DataOutputStream) {
- if (elem.isInstanceOf[Array[Byte]]) {
- val arr = elem.asInstanceOf[Array[Byte]]
- dOut.writeInt(arr.length)
- dOut.write(arr)
- } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) {
- val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
- val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(PythonRDD.stripPickle(t._1))
- dOut.write(PythonRDD.stripPickle(t._2))
- dOut.writeByte(Pickle.TUPLE2)
- dOut.writeByte(Pickle.STOP)
- } else if (elem.isInstanceOf[String]) {
- // For uniformity, strings are wrapped into Pickles.
- val s = elem.asInstanceOf[String].getBytes("UTF-8")
- val length = 2 + 1 + 4 + s.length + 1
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(Pickle.BINUNICODE)
- dOut.writeInt(Integer.reverseBytes(s.length))
- dOut.write(s)
- dOut.writeByte(Pickle.STOP)
- } else {
- throw new SparkException("Unexpected RDD type")
- }
- }
+private[spark] object PythonRDD {
- def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+ def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -265,41 +200,47 @@ private[spark] object PythonRDD {
}
} catch {
case eof: EOFException => {}
- case e => throw e
+ case e: Throwable => throw e
}
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
- def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ def writeToStream(elem: Any, dataOut: DataOutputStream) {
+ elem match {
+ case bytes: Array[Byte] =>
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ case pair: (Array[Byte], Array[Byte]) =>
+ dataOut.writeInt(pair._1.length)
+ dataOut.write(pair._1)
+ dataOut.writeInt(pair._2.length)
+ dataOut.write(pair._2)
+ case str: String =>
+ dataOut.writeUTF(str)
+ case other =>
+ throw new SparkException("Unexpected element type " + other.getClass)
+ }
+ }
+
+ def writeToFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
- writeIteratorToPickleFile(items.asScala, filename)
+ writeToFile(items.asScala, filename)
}
- def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+ def writeToFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
- writeAsPickle(item, file)
+ writeToStream(item, file)
}
file.close()
}
def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = {
- implicit val cm : ClassManifest[T] = rdd.elementClassManifest
+ implicit val cm : ClassTag[T] = rdd.elementClassTag
rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator
}
}
-private object Pickle {
- val PROTO: Byte = 0x80.toByte
- val TWO: Byte = 0x02.toByte
- val BINUNICODE: Byte = 'X'
- val STOP: Byte = '.'
- val TUPLE2: Byte = 0x86.toByte
- val EMPTY_LIST: Byte = ']'
- val MARK: Byte = '('
- val APPENDS: Byte = 'e'
-}
-
private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
index 67d45723ba..f291266fcf 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala
@@ -64,7 +64,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
startDaemon()
new Socket(daemonHost, daemonPort)
}
- case e => throw e
+ case e: Throwable => throw e
}
}
}
@@ -198,7 +198,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String
}
}.start()
} catch {
- case e => {
+ case e: Throwable => {
stopDaemon()
throw e
}
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 609464e38d..47db720416 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -19,6 +19,7 @@ package org.apache.spark.broadcast
import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream}
import java.net.URL
+import java.util.concurrent.TimeUnit
import it.unimi.dsi.fastutil.io.FastBufferedInputStream
import it.unimi.dsi.fastutil.io.FastBufferedOutputStream
@@ -83,6 +84,8 @@ private object HttpBroadcast extends Logging {
private val files = new TimeStampedHashSet[String]
private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup)
+ private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5,TimeUnit.MINUTES).toInt
+
private lazy val compressionCodec = CompressionCodec.createCodec()
def initialize(isDriver: Boolean) {
@@ -138,10 +141,13 @@ private object HttpBroadcast extends Logging {
def read[T](id: Long): T = {
val url = serverUri + "/" + BroadcastBlockId(id).name
val in = {
+ val httpConnection = new URL(url).openConnection()
+ httpConnection.setReadTimeout(httpReadTimeout)
+ val inputStream = httpConnection.getInputStream()
if (compress) {
- compressionCodec.compressedInputStream(new URL(url).openStream())
+ compressionCodec.compressedInputStream(inputStream)
} else {
- new FastBufferedInputStream(new URL(url).openStream(), bufferSize)
+ new FastBufferedInputStream(inputStream, bufferSize)
}
}
val ser = SparkEnv.get.serializer.newInstance()
diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
index fcfea96ad6..37dfa7fec0 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala
@@ -17,8 +17,7 @@
package org.apache.spark.deploy
-private[spark] object ExecutorState
- extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") {
+private[spark] object ExecutorState extends Enumeration {
val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
index 668032a3a2..0aa8852649 100644
--- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala
@@ -1,19 +1,19 @@
/*
*
- * * Licensed to the Apache Software Foundation (ASF) under one or more
- * * contributor license agreements. See the NOTICE file distributed with
- * * this work for additional information regarding copyright ownership.
- * * The ASF licenses this file to You under the Apache License, Version 2.0
- * * (the "License"); you may not use this file except in compliance with
- * * the License. You may obtain a copy of the License at
- * *
- * * http://www.apache.org/licenses/LICENSE-2.0
- * *
- * * Unless required by applicable law or agreed to in writing, software
- * * distributed under the License is distributed on an "AS IS" BASIS,
- * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * * See the License for the specific language governing permissions and
- * * limitations under the License.
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
*
*/
diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
index a724900943..59d12a3e6f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala
@@ -34,11 +34,11 @@ import scala.collection.mutable.ArrayBuffer
*/
private[spark]
class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging {
-
+
private val localHostname = Utils.localHostName()
private val masterActorSystems = ArrayBuffer[ActorSystem]()
private val workerActorSystems = ArrayBuffer[ActorSystem]()
-
+
def start(): Array[String] = {
logInfo("Starting a local Spark cluster with " + numWorkers + " workers.")
@@ -61,10 +61,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I
def stop() {
logInfo("Shutting down local Spark cluster.")
// Stop the workers before the master so they don't get upset that it disconnected
+ // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors!
+ // This is unfortunate, but for now we just comment it out.
workerActorSystems.foreach(_.shutdown())
- workerActorSystems.foreach(_.awaitTermination())
-
+ //workerActorSystems.foreach(_.awaitTermination())
masterActorSystems.foreach(_.shutdown())
- masterActorSystems.foreach(_.awaitTermination())
+ //masterActorSystems.foreach(_.awaitTermination())
+ masterActorSystems.clear()
+ workerActorSystems.clear()
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
index 77422f61ec..4d95efa73a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala
@@ -19,17 +19,15 @@ package org.apache.spark.deploy.client
import java.util.concurrent.TimeoutException
+import scala.concurrent.duration._
+import scala.concurrent.Await
+
import akka.actor._
-import akka.actor.Terminated
+import akka.pattern.AskTimeoutException
import akka.pattern.ask
-import akka.util.Duration
-import akka.util.duration._
-import akka.remote.RemoteClientDisconnected
-import akka.remote.RemoteClientLifeCycleEvent
-import akka.remote.RemoteClientShutdown
-import akka.dispatch.Await
-
-import org.apache.spark.Logging
+import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent, AssociationErrorEvent}
+
+import org.apache.spark.{SparkException, Logging}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
@@ -51,18 +49,19 @@ private[spark] class Client(
val REGISTRATION_TIMEOUT = 20.seconds
val REGISTRATION_RETRIES = 3
+ var masterAddress: Address = null
var actor: ActorRef = null
var appId: String = null
var registered = false
var activeMasterUrl: String = null
class ClientActor extends Actor with Logging {
- var master: ActorRef = null
- var masterAddress: Address = null
+ var master: ActorSelection = null
var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times
var alreadyDead = false // To avoid calling listener.dead() multiple times
override def preStart() {
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
try {
registerWithMaster()
} catch {
@@ -76,7 +75,7 @@ private[spark] class Client(
def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
- val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
actor ! RegisterApplication(appDescription)
}
}
@@ -84,6 +83,7 @@ private[spark] class Client(
def registerWithMaster() {
tryRegisterAllMasters()
+ import context.dispatcher
var retries = 0
lazy val retryTimer: Cancellable =
context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) {
@@ -102,10 +102,13 @@ private[spark] class Client(
def changeMaster(url: String) {
activeMasterUrl = url
- master = context.actorFor(Master.toAkkaUrl(url))
- masterAddress = master.path.address
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
+ masterAddress = activeMasterUrl match {
+ case Master.sparkUrlRegex(host, port) =>
+ Address("akka.tcp", Master.systemName, host, port.toInt)
+ case x =>
+ throw new SparkException("Invalid spark URL: " + x)
+ }
}
override def receive = {
@@ -135,21 +138,12 @@ private[spark] class Client(
case MasterChanged(masterUrl, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterUrl)
- context.unwatch(master)
changeMaster(masterUrl)
alreadyDisconnected = false
sender ! MasterChangeAcknowledged(appId)
- case Terminated(actor_) if actor_ == master =>
- logWarning("Connection to master failed; waiting for master to reconnect...")
- markDisconnected()
-
- case RemoteClientDisconnected(transport, address) if address == masterAddress =>
- logWarning("Connection to master failed; waiting for master to reconnect...")
- markDisconnected()
-
- case RemoteClientShutdown(transport, address) if address == masterAddress =>
- logWarning("Connection to master failed; waiting for master to reconnect...")
+ case DisassociatedEvent(_, address, _) if address == masterAddress =>
+ logWarning(s"Connection to $address failed; waiting for master to reconnect...")
markDisconnected()
case StopClient =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
index fedf879eff..67e6c5d66a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala
@@ -17,8 +17,7 @@
package org.apache.spark.deploy.master
-private[spark] object ApplicationState
- extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED", "UNKNOWN") {
+private[spark] object ApplicationState extends Enumeration {
type ApplicationState = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
index c0849ef324..043945a211 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala
@@ -65,7 +65,7 @@ private[spark] class FileSystemPersistenceEngine(
(apps, workers)
}
- private def serializeIntoFile(file: File, value: Serializable) {
+ private def serializeIntoFile(file: File, value: AnyRef) {
val created = file.createNewFile()
if (!created) { throw new IllegalStateException("Could not create file: " + file) }
@@ -77,13 +77,13 @@ private[spark] class FileSystemPersistenceEngine(
out.close()
}
- def deserializeFromFile[T <: Serializable](file: File)(implicit m: Manifest[T]): T = {
+ def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = {
val fileData = new Array[Byte](file.length().asInstanceOf[Int])
val dis = new DataInputStream(new FileInputStream(file))
dis.readFully(fileData)
dis.close()
- val clazz = m.erasure.asInstanceOf[Class[T]]
+ val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index cd916672ac..c627dd3806 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -17,19 +17,20 @@
package org.apache.spark.deploy.master
-import java.util.Date
import java.text.SimpleDateFormat
+import java.util.concurrent.TimeUnit
+import java.util.Date
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.concurrent.Await
+import scala.concurrent.duration._
+import scala.concurrent.duration.{Duration, FiniteDuration}
import akka.actor._
-import akka.actor.Terminated
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown}
+import akka.remote._
import akka.serialization.SerializationExtension
-import akka.util.duration._
-import akka.util.{Duration, Timeout}
+import akka.util.Timeout
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.deploy.{ApplicationDescription, ExecutorState}
@@ -37,9 +38,11 @@ import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.MasterMessages._
import org.apache.spark.deploy.master.ui.MasterWebUI
import org.apache.spark.metrics.MetricsSystem
-import org.apache.spark.util.{AkkaUtils, Utils}
+import org.apache.spark.util.{Utils, AkkaUtils}
private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging {
+ import context.dispatcher
+
val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs
val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000
val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt
@@ -93,7 +96,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
override def preStart() {
logInfo("Starting Spark master at " + masterUrl)
// Listen for remote client disconnection events, since they don't go through Akka's watch()
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.start()
masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get
context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut)
@@ -113,13 +116,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
new BlackHolePersistenceEngine()
}
- leaderElectionAgent = context.actorOf(Props(
- RECOVERY_MODE match {
+ leaderElectionAgent = RECOVERY_MODE match {
case "ZOOKEEPER" =>
- new ZooKeeperLeaderElectionAgent(self, masterUrl)
+ context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl))
case _ =>
- new MonarchyLeaderAgent(self)
- }))
+ context.actorOf(Props(classOf[MonarchyLeaderAgent], self))
+ }
}
override def preRestart(reason: Throwable, message: Option[Any]) {
@@ -142,9 +144,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
RecoveryState.ALIVE
else
RecoveryState.RECOVERING
-
logInfo("I have been elected leader! New state: " + state)
-
if (state == RecoveryState.RECOVERING) {
beginRecovery(storedApps, storedWorkers)
context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() }
@@ -156,7 +156,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
System.exit(0)
}
- case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => {
+ case RegisterWorker(id, workerHost, workerPort, cores, memory, workerWebUiPort, publicAddress) => {
logInfo("Registering worker %s:%d with %d cores, %s RAM".format(
host, workerPort, cores, Utils.megabytesToString(memory)))
if (state == RecoveryState.STANDBY) {
@@ -164,9 +164,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
} else if (idToWorker.contains(id)) {
sender ! RegisterWorkerFailed("Duplicate worker ID")
} else {
- val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress)
+ val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory,
+ sender, workerWebUiPort, publicAddress)
registerWorker(worker)
- context.watch(sender) // This doesn't work with remote actors but helps for testing
persistenceEngine.addWorker(worker)
sender ! RegisteredWorker(masterUrl, masterWebUiUrl)
schedule()
@@ -181,7 +181,6 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
val app = createApplication(description, sender)
registerApplication(app)
logInfo("Registered app " + description.name + " with ID " + app.id)
- context.watch(sender) // This doesn't work with remote actors but helps for testing
persistenceEngine.addApplication(app)
sender ! RegisteredApplication(app.id, masterUrl)
schedule()
@@ -257,23 +256,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
if (canCompleteRecovery) { completeRecovery() }
}
- case Terminated(actor) => {
- // The disconnected actor could've been either a worker or an app; remove whichever of
- // those we have an entry for in the corresponding actor hashmap
- actorToWorker.get(actor).foreach(removeWorker)
- actorToApp.get(actor).foreach(finishApplication)
- if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
- }
-
- case RemoteClientDisconnected(transport, address) => {
- // The disconnected client could've been either a worker or an app; remove whichever it was
- addressToWorker.get(address).foreach(removeWorker)
- addressToApp.get(address).foreach(finishApplication)
- if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
- }
-
- case RemoteClientShutdown(transport, address) => {
+ case DisassociatedEvent(_, address, _) => {
// The disconnected client could've been either a worker or an app; remove whichever it was
+ logInfo(s"$address got disassociated, removing it.")
addressToWorker.get(address).foreach(removeWorker)
addressToApp.get(address).foreach(finishApplication)
if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() }
@@ -530,9 +515,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act
}
private[spark] object Master {
- private val systemName = "sparkMaster"
+ val systemName = "sparkMaster"
private val actorName = "Master"
- private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
+ val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r
def main(argStrings: Array[String]) {
val args = new MasterArguments(argStrings)
@@ -540,11 +525,11 @@ private[spark] object Master {
actorSystem.awaitTermination()
}
- /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
+ /** Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */
def toAkkaUrl(sparkUrl: String): String = {
sparkUrl match {
case sparkUrlRegex(host, port) =>
- "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
+ "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName)
case _ =>
throw new SparkException("Invalid master URL: " + sparkUrl)
}
@@ -552,9 +537,9 @@ private[spark] object Master {
def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = {
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
- val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName)
- val timeoutDuration = Duration.create(
- System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds")
+ val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), name = actorName)
+ val timeoutDuration: FiniteDuration = Duration.create(
+ System.getProperty("spark.akka.askTimeout", "10").toLong, TimeUnit.SECONDS)
implicit val timeout = Timeout(timeoutDuration)
val respFuture = actor ? RequestWebUIPort // ask pattern
val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse]
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
index b91be821f0..256a5a7c28 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala
@@ -17,9 +17,7 @@
package org.apache.spark.deploy.master
-private[spark] object RecoveryState
- extends Enumeration("STANDBY", "ALIVE", "RECOVERING", "COMPLETING_RECOVERY") {
-
+private[spark] object RecoveryState extends Enumeration {
type MasterState = Value
val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
index c8d34f25e2..0b36ef6005 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala
@@ -17,9 +17,7 @@
package org.apache.spark.deploy.master
-private[spark] object WorkerState
- extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED", "UNKNOWN") {
-
+private[spark] object WorkerState extends Enumeration {
type WorkerState = Value
val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
index a0233a7271..825344b3bb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala
@@ -70,15 +70,15 @@ class ZooKeeperPersistenceEngine(serialization: Serialization)
(apps, workers)
}
- private def serializeIntoFile(path: String, value: Serializable) {
+ private def serializeIntoFile(path: String, value: AnyRef) {
val serializer = serialization.findSerializerFor(value)
val serialized = serializer.toBinary(value)
zk.create(path, serialized, CreateMode.PERSISTENT)
}
- def deserializeFromFile[T <: Serializable](filename: String)(implicit m: Manifest[T]): T = {
+ def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = {
val fileData = zk.getData("/spark/master_status/" + filename)
- val clazz = m.erasure.asInstanceOf[Class[T]]
+ val clazz = m.runtimeClass.asInstanceOf[Class[T]]
val serializer = serialization.serializerFor(clazz)
serializer.fromBinary(fileData).asInstanceOf[T]
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
index f4e574d15d..3b983c19eb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala
@@ -19,9 +19,10 @@ package org.apache.spark.deploy.master.ui
import scala.xml.Node
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.util.duration._
+
+import scala.concurrent.Await
+import scala.concurrent.duration._
import javax.servlet.http.HttpServletRequest
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
index d7a57229b0..65e7a14e7a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala
@@ -21,9 +21,9 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import akka.dispatch.Await
+import scala.concurrent.Await
import akka.pattern.ask
-import akka.util.duration._
+import scala.concurrent.duration._
import net.liftweb.json.JsonAST.JValue
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index f4df729e87..a211ce2b42 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -17,7 +17,7 @@
package org.apache.spark.deploy.master.ui
-import akka.util.Duration
+import scala.concurrent.duration._
import javax.servlet.http.HttpServletRequest
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index 216d9d44ac..87531b6719 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -17,23 +17,31 @@
package org.apache.spark.deploy.worker
+import java.io.File
import java.text.SimpleDateFormat
import java.util.Date
-import java.io.File
import scala.collection.mutable.HashMap
+import scala.concurrent.duration._
import akka.actor._
-import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
-import akka.util.duration._
+import akka.remote.{ DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.Logging
+import org.apache.spark.{SparkException, Logging}
import org.apache.spark.deploy.{ExecutorDescription, ExecutorState}
import org.apache.spark.deploy.DeployMessages._
import org.apache.spark.deploy.master.Master
import org.apache.spark.deploy.worker.ui.WorkerWebUI
import org.apache.spark.metrics.MetricsSystem
import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.deploy.DeployMessages.WorkerStateResponse
+import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed
+import org.apache.spark.deploy.DeployMessages.KillExecutor
+import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged
+import org.apache.spark.deploy.DeployMessages.Heartbeat
+import org.apache.spark.deploy.DeployMessages.RegisteredWorker
+import org.apache.spark.deploy.DeployMessages.LaunchExecutor
+import org.apache.spark.deploy.DeployMessages.RegisterWorker
/**
* @param masterUrls Each url should look like spark://host:port.
@@ -47,6 +55,7 @@ private[spark] class Worker(
masterUrls: Array[String],
workDirPath: String = null)
extends Actor with Logging {
+ import context.dispatcher
Utils.checkHost(host, "Expected hostname")
assert (port > 0)
@@ -63,7 +72,8 @@ private[spark] class Worker(
var masterIndex = 0
val masterLock: Object = new Object()
- var master: ActorRef = null
+ var master: ActorSelection = null
+ var masterAddress: Address = null
var activeMasterUrl: String = ""
var activeMasterWebUiUrl : String = ""
@volatile var registered = false
@@ -114,7 +124,7 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
webUi = new WorkerWebUI(this, workDir, Some(webUiPort))
-
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
webUi.start()
registerWithMaster()
@@ -126,9 +136,13 @@ private[spark] class Worker(
masterLock.synchronized {
activeMasterUrl = url
activeMasterWebUiUrl = uiUrl
- master = context.actorFor(Master.toAkkaUrl(activeMasterUrl))
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(master) // Doesn't work with remote actors, but useful for testing
+ master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl))
+ masterAddress = activeMasterUrl match {
+ case Master.sparkUrlRegex(_host, _port) =>
+ Address("akka.tcp", Master.systemName, _host, _port.toInt)
+ case x =>
+ throw new SparkException("Invalid spark URL: " + x)
+ }
connected = true
}
}
@@ -136,7 +150,7 @@ private[spark] class Worker(
def tryRegisterAllMasters() {
for (masterUrl <- masterUrls) {
logInfo("Connecting to master " + masterUrl + "...")
- val actor = context.actorFor(Master.toAkkaUrl(masterUrl))
+ val actor = context.actorSelection(Master.toAkkaUrl(masterUrl))
actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get,
publicAddress)
}
@@ -175,7 +189,6 @@ private[spark] class Worker(
case MasterChanged(masterUrl, masterWebUiUrl) =>
logInfo("Master has changed, new master is at " + masterUrl)
- context.unwatch(master)
changeMaster(masterUrl, masterWebUiUrl)
val execs = executors.values.
@@ -234,13 +247,8 @@ private[spark] class Worker(
}
}
- case Terminated(actor_) if actor_ == master =>
- masterDisconnected()
-
- case RemoteClientDisconnected(transport, address) if address == master.path.address =>
- masterDisconnected()
-
- case RemoteClientShutdown(transport, address) if address == master.path.address =>
+ case x: DisassociatedEvent if x.remoteAddress == masterAddress =>
+ logInfo(s"$x Disassociated !")
masterDisconnected()
case RequestWorkerState => {
@@ -280,8 +288,8 @@ private[spark] object Worker {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("")
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port)
- val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory,
- masterUrls, workDir)), name = "Worker")
+ actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory,
+ masterUrls, workDir), name = "Worker")
(actorSystem, boundPort)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
index d2d3617498..1a768d501f 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala
@@ -21,9 +21,10 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.Node
-import akka.dispatch.Await
+import scala.concurrent.duration._
+import scala.concurrent.Await
+
import akka.pattern.ask
-import akka.util.duration._
import net.liftweb.json.JsonAST.JValue
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index 800f1cafcc..6c18a3c245 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -17,20 +17,19 @@
package org.apache.spark.deploy.worker.ui
-import akka.util.{Duration, Timeout}
+import java.io.File
-import java.io.{FileInputStream, File}
+import scala.concurrent.duration._
+import akka.util.Timeout
import javax.servlet.http.HttpServletRequest
-import org.eclipse.jetty.server.{Handler, Server}
-
+import org.apache.spark.Logging
import org.apache.spark.deploy.worker.Worker
-import org.apache.spark.{Logging}
-import org.apache.spark.ui.JettyUtils
+import org.apache.spark.ui.{JettyUtils, UIUtils}
import org.apache.spark.ui.JettyUtils._
-import org.apache.spark.ui.UIUtils
import org.apache.spark.util.Utils
+import org.eclipse.jetty.server.{Handler, Server}
/**
* Web UI server for the standalone worker.
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 8332631838..debbdd4c44 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -19,15 +19,14 @@ package org.apache.spark.executor
import java.nio.ByteBuffer
-import akka.actor.{ActorRef, Actor, Props, Terminated}
-import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected}
+import akka.actor._
+import akka.remote._
import org.apache.spark.Logging
import org.apache.spark.TaskState.TaskState
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{Utils, AkkaUtils}
-
private[spark] class CoarseGrainedExecutorBackend(
driverUrl: String,
executorId: String,
@@ -40,14 +39,13 @@ private[spark] class CoarseGrainedExecutorBackend(
Utils.checkHostPort(hostPort, "Expected hostport")
var executor: Executor = null
- var driver: ActorRef = null
+ var driver: ActorSelection = null
override def preStart() {
logInfo("Connecting to driver: " + driverUrl)
- driver = context.actorFor(driverUrl)
+ driver = context.actorSelection(driverUrl)
driver ! RegisterExecutor(executorId, hostPort, cores)
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
- context.watch(driver) // Doesn't work with remote actors, but useful for testing
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
}
override def receive = {
@@ -77,8 +75,8 @@ private[spark] class CoarseGrainedExecutorBackend(
executor.killTask(taskId)
}
- case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
- logError("Driver terminated or disconnected! Shutting down.")
+ case x: DisassociatedEvent =>
+ logError(s"Driver $x disassociated! Shutting down.")
System.exit(1)
case StopExecutor =>
@@ -99,12 +97,13 @@ private[spark] object CoarseGrainedExecutorBackend {
// Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor
// before getting started with all our system properties, etc
- val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0)
+ val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0,
+ indestructible = true)
// set it
val sparkHostPort = hostname + ":" + boundPort
System.setProperty("spark.hostPort", sparkHostPort)
- val actor = actorSystem.actorOf(
- Props(new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores)),
+ actorSystem.actorOf(
+ Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores),
name = "Executor")
actorSystem.awaitTermination()
}
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index 5c9bb9db1c..0b0a60ee60 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -121,7 +121,7 @@ private[spark] class Executor(
// Akka's message frame size. If task result is bigger than this, we use the block manager
// to send the result back.
private val akkaFrameSize = {
- env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size")
+ env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size")
}
// Start worker thread pool
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 0b4892f98f..c0ce46e379 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -61,50 +61,53 @@ object TaskMetrics {
class ShuffleReadMetrics extends Serializable {
/**
- * Time when shuffle finishs
+ * Absolute time when this task finished reading shuffle data
*/
var shuffleFinishTime: Long = _
/**
- * Total number of blocks fetched in a shuffle (remote or local)
+ * Number of blocks fetched in this shuffle by this task (remote or local)
*/
var totalBlocksFetched: Int = _
/**
- * Number of remote blocks fetched in a shuffle
+ * Number of remote blocks fetched in this shuffle by this task
*/
var remoteBlocksFetched: Int = _
/**
- * Local blocks fetched in a shuffle
+ * Number of local blocks fetched in this shuffle by this task
*/
var localBlocksFetched: Int = _
/**
- * Total time that is spent blocked waiting for shuffle to fetch data
+ * Time the task spent waiting for remote shuffle blocks. This only includes the time
+ * blocking on shuffle input data. For instance if block B is being fetched while the task is
+ * still not finished processing block A, it is not considered to be blocking on block B.
*/
var fetchWaitTime: Long = _
/**
- * The total amount of time for all the shuffle fetches. This adds up time from overlapping
- * shuffles, so can be longer than task time
+ * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all
+ * input blocks. Since block fetches are both pipelined and parallelized, this can
+ * exceed fetchWaitTime and executorRunTime.
*/
var remoteFetchTime: Long = _
/**
- * Total number of remote bytes read from a shuffle
+ * Total number of remote bytes read from the shuffle by this task
*/
var remoteBytesRead: Long = _
}
class ShuffleWriteMetrics extends Serializable {
/**
- * Number of bytes written for a shuffle
+ * Number of bytes written for the shuffle by this task
*/
var shuffleBytesWritten: Long = _
/**
- * Time spent blocking on writes to disk or buffer cache, in nanoseconds.
+ * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
*/
var shuffleWriteTime: Long = _
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
new file mode 100644
index 0000000000..cdcfec8ca7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+import java.net.InetSocketAddress
+
+import com.codahale.metrics.MetricRegistry
+import com.codahale.metrics.graphite.{GraphiteReporter, Graphite}
+
+import org.apache.spark.metrics.MetricsSystem
+
+class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+ val GRAPHITE_DEFAULT_PERIOD = 10
+ val GRAPHITE_DEFAULT_UNIT = "SECONDS"
+ val GRAPHITE_DEFAULT_PREFIX = ""
+
+ val GRAPHITE_KEY_HOST = "host"
+ val GRAPHITE_KEY_PORT = "port"
+ val GRAPHITE_KEY_PERIOD = "period"
+ val GRAPHITE_KEY_UNIT = "unit"
+ val GRAPHITE_KEY_PREFIX = "prefix"
+
+ def propertyToOption(prop: String) = Option(property.getProperty(prop))
+
+ if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) {
+ throw new Exception("Graphite sink requires 'host' property.")
+ }
+
+ if (!propertyToOption(GRAPHITE_KEY_PORT).isDefined) {
+ throw new Exception("Graphite sink requires 'port' property.")
+ }
+
+ val host = propertyToOption(GRAPHITE_KEY_HOST).get
+ val port = propertyToOption(GRAPHITE_KEY_PORT).get.toInt
+
+ val pollPeriod = propertyToOption(GRAPHITE_KEY_PERIOD) match {
+ case Some(s) => s.toInt
+ case None => GRAPHITE_DEFAULT_PERIOD
+ }
+
+ val pollUnit = propertyToOption(GRAPHITE_KEY_UNIT) match {
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT)
+ }
+
+ val prefix = propertyToOption(GRAPHITE_KEY_PREFIX).getOrElse(GRAPHITE_DEFAULT_PREFIX)
+
+ MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
+
+ val graphite: Graphite = new Graphite(new InetSocketAddress(host, port))
+
+ val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry)
+ .convertDurationsTo(TimeUnit.MILLISECONDS)
+ .convertRatesTo(TimeUnit.SECONDS)
+ .prefixedWith(prefix)
+ .build(graphite)
+
+ override def start() {
+ reporter.start(pollPeriod, pollUnit)
+ }
+
+ override def stop() {
+ reporter.stop()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
index 9c2fee4023..703bc6a9ca 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala
@@ -31,11 +31,11 @@ import scala.collection.mutable.SynchronizedMap
import scala.collection.mutable.SynchronizedQueue
import scala.collection.mutable.ArrayBuffer
-import akka.dispatch.{Await, Promise, ExecutionContext, Future}
-import akka.util.Duration
-import akka.util.duration._
-import org.apache.spark.util.Utils
+import scala.concurrent.{Await, Promise, ExecutionContext, Future}
+import scala.concurrent.duration.Duration
+import scala.concurrent.duration._
+import org.apache.spark.util.Utils
private[spark] class ConnectionManager(port: Int) extends Logging {
diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
index 8d9ad9604d..4f5742d29b 100644
--- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
+++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala
@@ -25,8 +25,8 @@ import scala.io.Source
import java.nio.ByteBuffer
import java.net.InetAddress
-import akka.dispatch.Await
-import akka.util.duration._
+import scala.concurrent.Await
+import scala.concurrent.duration._
private[spark] object ConnectionManagerTest extends Logging{
def main(args: Array[String]) {
diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
index 481ff8c3e0..b1e1576dad 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala
@@ -76,7 +76,7 @@ private[spark] object ShuffleCopier extends Logging {
extends FileClientHandler with Logging {
override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) {
- logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)");
+ logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)")
resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
index faaf837be0..d1c74a5063 100644
--- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala
@@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong
import scala.collection.mutable.ArrayBuffer
import scala.concurrent.ExecutionContext.Implicits.global
+import scala.reflect.ClassTag
import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
@@ -28,7 +29,7 @@ import org.apache.spark.{ComplexFutureAction, FutureAction, Logging}
* A set of asynchronous RDD actions available through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
-class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging {
+class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging {
/**
* Returns a future for counting the number of elements in the RDD.
diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
index 44ea573a7c..424354ae16 100644
--- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
import org.apache.spark.storage.{BlockId, BlockManager}
@@ -25,7 +27,7 @@ private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends P
}
private[spark]
-class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[BlockId])
+class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId])
extends RDD[T](sc, Nil) {
@transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
index 9b0c882481..87b950ba43 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala
@@ -18,6 +18,7 @@
package org.apache.spark.rdd
import java.io.{ObjectOutputStream, IOException}
+import scala.reflect.ClassTag
import org.apache.spark._
@@ -43,7 +44,7 @@ class CartesianPartition(
}
private[spark]
-class CartesianRDD[T: ClassManifest, U:ClassManifest](
+class CartesianRDD[T: ClassTag, U: ClassTag](
sc: SparkContext,
var rdd1 : RDD[T],
var rdd2 : RDD[U])
@@ -70,7 +71,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest](
override def compute(split: Partition, context: TaskContext) = {
val currSplit = split.asInstanceOf[CartesianPartition]
for (x <- rdd1.iterator(currSplit.s1, context);
- y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
+ y <- rdd2.iterator(currSplit.s2, context)) yield (x, y)
}
override def getDependencies: Seq[Dependency[_]] = List(
diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
index d3033ea4a6..a712ef1c27 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala
@@ -17,15 +17,13 @@
package org.apache.spark.rdd
+import java.io.IOException
+
+import scala.reflect.ClassTag
+
+import org.apache.hadoop.fs.Path
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
-import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter}
-import org.apache.hadoop.conf.Configuration
-import org.apache.hadoop.io.{NullWritable, BytesWritable}
-import org.apache.hadoop.util.ReflectionUtils
-import org.apache.hadoop.fs.Path
-import java.io.{File, IOException, EOFException}
-import java.text.NumberFormat
private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
@@ -33,7 +31,7 @@ private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {}
* This RDD represents a RDD checkpoint file (similar to HadoopRDD).
*/
private[spark]
-class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String)
+class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String)
extends RDD[T](sc, Nil) {
@transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration)
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index c5de6362a9..98da35763b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -22,6 +22,7 @@ import java.io.{ObjectOutputStream, IOException}
import scala.collection.mutable
import scala.Some
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
/**
* Class that captures a coalesced RDD by essentially keeping track of parent partitions
@@ -68,7 +69,7 @@ case class CoalescedRDDPartition(
* @param maxPartitions number of desired partitions in the coalesced RDD
* @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
*/
-class CoalescedRDD[T: ClassManifest](
+class CoalescedRDD[T: ClassTag](
@transient var prev: RDD[T],
maxPartitions: Int,
balanceSlack: Double = 0.10)
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index a4bec41752..688c310ee9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator
import org.apache.spark.util.StatCounter
import org.apache.spark.{TaskContext, Logging}
+import scala.collection.immutable.NumericRange
+
/**
* Extra functions available on RDDs of Doubles through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
@@ -76,4 +78,129 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
val evaluator = new SumEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
+ // Compute the minimum and the maxium
+ val (max: Double, min: Double) = self.mapPartitions { items =>
+ Iterator(items.foldRight(Double.NegativeInfinity,
+ Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) =>
+ (x._1.max(e), x._2.min(e))))
+ }.reduce { (maxmin1, maxmin2) =>
+ (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
+ }
+ if (min.isNaN || max.isNaN || max.isInfinity || min.isInfinity ) {
+ throw new UnsupportedOperationException(
+ "Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
+ }
+ val increment = (max-min)/bucketCount.toDouble
+ val range = if (increment != 0) {
+ Range.Double.inclusive(min, max, increment)
+ } else {
+ List(min, min)
+ }
+ val buckets = range.toArray
+ (buckets, histogram(buckets, true))
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1, 0, 0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = {
+ if (buckets.length < 2) {
+ throw new IllegalArgumentException("buckets array must have at least two elements")
+ }
+ // The histogramPartition function computes the partail histogram for a given
+ // partition. The provided bucketFunction determines which bucket in the array
+ // to increment or returns None if there is no bucket. This is done so we can
+ // specialize for uniformly distributed buckets and save the O(log n) binary
+ // search cost.
+ def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]):
+ Iterator[Array[Long]] = {
+ val counters = new Array[Long](buckets.length - 1)
+ while (iter.hasNext) {
+ bucketFunction(iter.next()) match {
+ case Some(x: Int) => {counters(x) += 1}
+ case _ => {}
+ }
+ }
+ Iterator(counters)
+ }
+ // Merge the counters.
+ def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = {
+ a1.indices.foreach(i => a1(i) += a2(i))
+ a1
+ }
+ // Basic bucket function. This works using Java's built in Array
+ // binary search. Takes log(size(buckets))
+ def basicBucketFunction(e: Double): Option[Int] = {
+ val location = java.util.Arrays.binarySearch(buckets, e)
+ if (location < 0) {
+ // If the location is less than 0 then the insertion point in the array
+ // to keep it sorted is -location-1
+ val insertionPoint = -location-1
+ // If we have to insert before the first element or after the last one
+ // its out of bounds.
+ // We do this rather than buckets.lengthCompare(insertionPoint)
+ // because Array[Double] fails to override it (for now).
+ if (insertionPoint > 0 && insertionPoint < buckets.length) {
+ Some(insertionPoint-1)
+ } else {
+ None
+ }
+ } else if (location < buckets.length - 1) {
+ // Exact match, just insert here
+ Some(location)
+ } else {
+ // Exact match to the last element
+ Some(location - 1)
+ }
+ }
+ // Determine the bucket function in constant time. Requires that buckets are evenly spaced
+ def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = {
+ // If our input is not a number unless the increment is also NaN then we fail fast
+ if (e.isNaN()) {
+ return None
+ }
+ val bucketNumber = (e - min)/(increment)
+ // We do this rather than buckets.lengthCompare(bucketNumber)
+ // because Array[Double] fails to override it (for now).
+ if (bucketNumber > count || bucketNumber < 0) {
+ None
+ } else {
+ Some(bucketNumber.toInt.min(count - 1))
+ }
+ }
+ // Decide which bucket function to pass to histogramPartition. We decide here
+ // rather than having a general function so that the decission need only be made
+ // once rather than once per shard
+ val bucketFunction = if (evenBuckets) {
+ fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _
+ } else {
+ basicBucketFunction _
+ }
+ self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
index c8900d1a93..a84e5f9fd8 100644
--- a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala
@@ -17,13 +17,14 @@
package org.apache.spark.rdd
-import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext}
+import scala.reflect.ClassTag
+import org.apache.spark.{Partition, SparkContext, TaskContext}
/**
* An RDD that is empty, i.e. has no element in it.
*/
-class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) {
+class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) {
override def getPartitions: Array[Partition] = Array.empty
diff --git a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
index 5312dc0b59..e74c83b90b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala
@@ -18,8 +18,9 @@
package org.apache.spark.rdd
import org.apache.spark.{OneToOneDependency, Partition, TaskContext}
+import scala.reflect.ClassTag
-private[spark] class FilteredRDD[T: ClassManifest](
+private[spark] class FilteredRDD[T: ClassTag](
prev: RDD[T],
f: T => Boolean)
extends RDD[T](prev) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
index cbdf6d84c0..4d1878fc14 100644
--- a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala
@@ -18,10 +18,11 @@
package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
+import scala.reflect.ClassTag
private[spark]
-class FlatMappedRDD[U: ClassManifest, T: ClassManifest](
+class FlatMappedRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
f: T => TraversableOnce[U])
extends RDD[U](prev) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
index 829545d7b0..1a694475f6 100644
--- a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala
@@ -18,8 +18,9 @@
package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
+import scala.reflect.ClassTag
-private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T])
+private[spark] class GlommedRDD[T: ClassTag](prev: RDD[T])
extends RDD[Array[T]](prev) {
override def getPartitions: Array[Partition] = firstParent[T].partitions
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index aca0146884..8df8718f3b 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -19,6 +19,8 @@ package org.apache.spark.rdd
import java.sql.{Connection, ResultSet}
+import scala.reflect.ClassTag
+
import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
import org.apache.spark.util.NextIterator
@@ -45,7 +47,7 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e
* This should only call getInt, getString, etc; the RDD takes care of calling next.
* The default maps a ResultSet to an array of Object.
*/
-class JdbcRDD[T: ClassManifest](
+class JdbcRDD[T: ClassTag](
sc: SparkContext,
getConnection: () => Connection,
sql: String,
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index 203179c4ea..db15baf503 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -18,20 +18,18 @@
package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
+import scala.reflect.ClassTag
-
-private[spark]
-class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
- f: Iterator[T] => Iterator[U],
+ f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
preservesPartitioning: Boolean = false)
extends RDD[U](prev) {
- override val partitioner =
- if (preservesPartitioning) firstParent[T].partitioner else None
+ override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
override def getPartitions: Array[Partition] = firstParent[T].partitions
override def compute(split: Partition, context: TaskContext) =
- f(firstParent[T].iterator(split, context))
+ f(context, split.index, firstParent[T].iterator(split, context))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
deleted file mode 100644
index aea08ff81b..0000000000
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{Partition, TaskContext}
-
-
-/**
- * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
- * TaskContext, the closure can either get access to the interruptible flag or get the index
- * of the partition in the RDD.
- */
-private[spark]
-class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: (TaskContext, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean
- ) extends RDD[U](prev) {
-
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- override val partitioner = if (preservesPartitioning) prev.partitioner else None
-
- override def compute(split: Partition, context: TaskContext) =
- f(context, firstParent[T].iterator(split, context))
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
index e8be1c4816..8d7c288593 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala
@@ -17,10 +17,12 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.spark.{Partition, TaskContext}
private[spark]
-class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U)
+class MappedRDD[U: ClassTag, T: ClassTag](prev: RDD[T], f: T => U)
extends RDD[U](prev) {
override def getPartitions: Array[Partition] = firstParent[T].partitions
diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
index 697be8b997..d5691f2267 100644
--- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala
@@ -17,7 +17,9 @@
package org.apache.spark.rdd
-import org.apache.spark.{RangePartitioner, Logging}
+import scala.reflect.ClassTag
+
+import org.apache.spark.{Logging, RangePartitioner}
/**
* Extra functions available on RDDs of (key, value) pairs where the key is sortable through
@@ -25,9 +27,9 @@ import org.apache.spark.{RangePartitioner, Logging}
* use these functions. They will work with any key type that has a `scala.math.Ordered`
* implementation.
*/
-class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest,
- V: ClassManifest,
- P <: Product2[K, V] : ClassManifest](
+class OrderedRDDFunctions[K <% Ordered[K]: ClassTag,
+ V: ClassTag,
+ P <: Product2[K, V] : ClassTag](
self: RDD[P])
extends Logging with Serializable {
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index 93b78e1232..48168e152e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -25,6 +25,7 @@ import java.util.{HashMap => JHashMap}
import scala.collection.{mutable, Map}
import scala.collection.mutable.ArrayBuffer
import scala.collection.JavaConversions._
+import scala.reflect.{ClassTag, classTag}
import org.apache.hadoop.mapred._
import org.apache.hadoop.io.compress.CompressionCodec
@@ -50,7 +51,7 @@ import org.apache.spark.Partitioner.defaultPartitioner
* Extra functions available on RDDs of (key, value) pairs through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
*/
-class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
+class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)])
extends Logging
with SparkHadoopMapReduceUtil
with Serializable {
@@ -415,7 +416,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other), partitioner)
- val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
+ val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag)
prfs.mapValues { case Seq(vs, ws) =>
(vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]])
}
@@ -431,7 +432,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
throw new SparkException("Default partitioner cannot partition array keys.")
}
val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner)
- val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest)
+ val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag)
prfs.mapValues { case Seq(vs, w1s, w2s) =>
(vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]])
}
@@ -488,15 +489,15 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
* Uses `this` partitioner/partition size, because even if `other` is huge, the resulting
* RDD will be <= us.
*/
- def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] =
+ def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] =
subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size)))
/** Return an RDD with the pairs from `this` whose keys are not in `other`. */
- def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
+ def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] =
subtractByKey(other, new HashPartitioner(numPartitions))
/** Return an RDD with the pairs from `this` whose keys are not in `other`. */
- def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
+ def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] =
new SubtractedRDD[K, V, W](self, other, p)
/**
@@ -525,8 +526,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
* Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class
* supporting the key and value types K and V in this RDD.
*/
- def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
- saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}
/**
@@ -535,16 +536,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
* supplied codec.
*/
def saveAsHadoopFile[F <: OutputFormat[K, V]](
- path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) {
- saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec)
+ path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) {
+ saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]], codec)
}
/**
* Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat`
* (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD.
*/
- def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) {
- saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]])
+ def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) {
+ saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]])
}
/**
@@ -698,11 +699,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)])
*/
def values: RDD[V] = self.map(_._2)
- private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure
+ private[spark] def getKeyClass() = implicitly[ClassTag[K]].runtimeClass
- private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure
+ private[spark] def getValueClass() = implicitly[ClassTag[V]].runtimeClass
}
-private[spark] object Manifests {
- val seqSeqManifest = classManifest[Seq[Seq[_]]]
+private[spark] object ClassTags {
+ val seqSeqClassTag = classTag[Seq[Seq[_]]]
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index cd96250389..09d0a8189d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -20,13 +20,15 @@ package org.apache.spark.rdd
import scala.collection.immutable.NumericRange
import scala.collection.mutable.ArrayBuffer
import scala.collection.Map
+import scala.reflect.ClassTag
+
import org.apache.spark._
import java.io._
import scala.Serializable
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.util.Utils
-private[spark] class ParallelCollectionPartition[T: ClassManifest](
+private[spark] class ParallelCollectionPartition[T: ClassTag](
var rddId: Long,
var slice: Int,
var values: Seq[T])
@@ -78,7 +80,7 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest](
}
}
-private[spark] class ParallelCollectionRDD[T: ClassManifest](
+private[spark] class ParallelCollectionRDD[T: ClassTag](
@transient sc: SparkContext,
@transient data: Seq[T],
numSlices: Int,
@@ -109,7 +111,7 @@ private object ParallelCollectionRDD {
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
* it efficient to run Spark over RDDs representing large sets of numbers.
*/
- def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
+ def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
if (numSlices < 1) {
throw new IllegalArgumentException("Positive number of slices required")
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
index 165cd412fc..ea8885b36e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.spark.{NarrowDependency, SparkEnv, Partition, TaskContext}
@@ -33,11 +35,13 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
extends NarrowDependency[T](rdd) {
@transient
- val partitions: Array[Partition] = rdd.partitions.zipWithIndex
- .filter(s => partitionFilterFunc(s._2))
+ val partitions: Array[Partition] = rdd.partitions
+ .filter(s => partitionFilterFunc(s.index)).zipWithIndex
.map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition }
- override def getParents(partitionId: Int) = List(partitions(partitionId).index)
+ override def getParents(partitionId: Int) = {
+ List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index)
+ }
}
@@ -47,7 +51,7 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo
* and the execution DAG has a filter on the key, we can avoid launching tasks
* on partitions that don't have the range covering the key.
*/
-class PartitionPruningRDD[T: ClassManifest](
+class PartitionPruningRDD[T: ClassTag](
@transient prev: RDD[T],
@transient partitionFilterFunc: Int => Boolean)
extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) {
@@ -67,6 +71,6 @@ object PartitionPruningRDD {
* when its type T is not known at compile time.
*/
def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = {
- new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassManifest)
+ new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index d5304ab0ae..1dbbe39898 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -24,6 +24,7 @@ import scala.collection.Map
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.io.Source
+import scala.reflect.ClassTag
import org.apache.spark.{SparkEnv, Partition, TaskContext}
import org.apache.spark.broadcast.Broadcast
@@ -33,7 +34,7 @@ import org.apache.spark.broadcast.Broadcast
* An RDD that pipes the contents of each parent partition through an external command
* (printing them one per line) and returns the output as a collection of strings.
*/
-class PipedRDD[T: ClassManifest](
+class PipedRDD[T: ClassTag](
prev: RDD[T],
command: Seq[String],
envVars: Map[String, String],
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6e88be6f6a..ea45566ad1 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -23,6 +23,9 @@ import scala.collection.Map
import scala.collection.JavaConversions.mapAsScalaMap
import scala.collection.mutable.ArrayBuffer
+import scala.collection.mutable.HashMap
+import scala.reflect.{classTag, ClassTag}
+
import org.apache.hadoop.io.BytesWritable
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.io.NullWritable
@@ -69,7 +72,7 @@ import org.apache.spark._
* [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details
* on RDD internals.
*/
-abstract class RDD[T: ClassManifest](
+abstract class RDD[T: ClassTag](
@transient private var sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
@@ -101,7 +104,7 @@ abstract class RDD[T: ClassManifest](
protected def getPreferredLocations(split: Partition): Seq[String] = Nil
/** Optionally overridden by subclasses to specify how they are partitioned. */
- val partitioner: Option[Partitioner] = None
+ @transient val partitioner: Option[Partitioner] = None
// =======================================================================
// Methods and fields available on all RDDs
@@ -114,7 +117,7 @@ abstract class RDD[T: ClassManifest](
val id: Int = sc.newRddId()
/** A friendly name for this RDD */
- var name: String = null
+ @transient var name: String = null
/** Assign a name to this RDD */
def setName(_name: String) = {
@@ -123,7 +126,7 @@ abstract class RDD[T: ClassManifest](
}
/** User-defined generator of this RDD*/
- var generator = Utils.getCallSiteInfo.firstUserClass
+ @transient var generator = Utils.getCallSiteInfo.firstUserClass
/** Reset generator*/
def setGenerator(_generator: String) = {
@@ -243,13 +246,13 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD by applying a function to all elements of this RDD.
*/
- def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
+ def map[U: ClassTag](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f))
/**
* Return a new RDD by first applying a function to all elements of this
* RDD, and then flattening the results.
*/
- def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] =
+ def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] =
new FlatMappedRDD(this, sc.clean(f))
/**
@@ -374,25 +377,25 @@ abstract class RDD[T: ClassManifest](
* Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of
* elements (a, b) where a is in `this` and b is in `other`.
*/
- def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
+ def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other)
/**
* Return an RDD of grouped items.
*/
- def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] =
+ def groupBy[K: ClassTag](f: T => K): RDD[(K, Seq[T])] =
groupBy[K](f, defaultPartitioner(this))
/**
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
+ def groupBy[K: ClassTag](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] =
groupBy(f, new HashPartitioner(numPartitions))
/**
* Return an RDD of grouped items.
*/
- def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
+ def groupBy[K: ClassTag](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = {
val cleanF = sc.clean(f)
this.map(t => (cleanF(t), t)).groupByKey(p)
}
@@ -408,7 +411,6 @@ abstract class RDD[T: ClassManifest](
def pipe(command: String, env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
-
/**
* Return an RDD created by piping elements to a forked external process.
* The print behavior can be customized by providing two functions.
@@ -440,29 +442,31 @@ abstract class RDD[T: ClassManifest](
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[U: ClassManifest](
+ def mapPartitions[U: ClassTag](
f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
- new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
* Return a new RDD by applying a function to each partition of this RDD, while tracking the index
* of the original partition.
*/
- def mapPartitionsWithIndex[U: ClassManifest](
+ def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
- val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
- new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
* Return a new RDD by applying a function to each partition of this RDD. This is a variant of
* mapPartitions that also passes the TaskContext into the closure.
*/
- def mapPartitionsWithContext[U: ClassManifest](
+ def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
- new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
@@ -470,7 +474,7 @@ abstract class RDD[T: ClassManifest](
* of the original partition.
*/
@deprecated("use mapPartitionsWithIndex", "0.7.0")
- def mapPartitionsWithSplit[U: ClassManifest](
+ def mapPartitionsWithSplit[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
mapPartitionsWithIndex(f, preservesPartitioning)
}
@@ -480,14 +484,13 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def mapWith[A: ClassManifest, U: ClassManifest]
+ def mapWith[A: ClassTag, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.map(t => f(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }, preservesPartitioning)
}
/**
@@ -495,14 +498,13 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def flatMapWith[A: ClassManifest, U: ClassManifest]
+ def flatMapWith[A: ClassTag, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.flatMap(t => f(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }, preservesPartitioning)
}
/**
@@ -510,12 +512,11 @@ abstract class RDD[T: ClassManifest](
* This additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(context.partitionId)
+ def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
+ mapPartitionsWithIndex { (index, iter) =>
+ val a = constructA(index)
iter.map(t => {f(t, a); t})
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
+ }.foreach(_ => {})
}
/**
@@ -523,12 +524,11 @@ abstract class RDD[T: ClassManifest](
* additional parameter is produced by constructA, which is called in each
* partition with the index of that partition.
*/
- def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(context.partitionId)
+ def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.filter(t => p(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
+ }, preservesPartitioning = true)
}
/**
@@ -537,7 +537,7 @@ abstract class RDD[T: ClassManifest](
* partitions* and the *same number of elements in each partition* (e.g. one was made through
* a map on the other).
*/
- def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
+ def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other)
/**
* Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by
@@ -545,20 +545,30 @@ abstract class RDD[T: ClassManifest](
* *same number of partitions*, but does *not* require them to have the same number
* of elements in each partition.
*/
- def zipPartitions[B: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, V: ClassTag]
(rdd2: RDD[B])
(f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
- new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)
+ new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false)
+
+ def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag]
+ (rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean)
+ (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning)
- def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag]
(rdd2: RDD[B], rdd3: RDD[C])
(f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
- new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)
+ new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false)
+
+ def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]
+ (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean)
+ (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning)
- def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]
+ def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag]
(rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D])
(f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
- new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)
+ new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false)
// Actions (launch a job to return a value to the user program)
@@ -593,7 +603,7 @@ abstract class RDD[T: ClassManifest](
/**
* Return an RDD that contains all matching values by applying `f`.
*/
- def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = {
+ def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = {
filter(f.isDefinedAt).map(f)
}
@@ -683,7 +693,7 @@ abstract class RDD[T: ClassManifest](
* allowed to modify and return their first argument instead of creating a new U to avoid memory
* allocation.
*/
- def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
+ def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = {
// Clone the zero value since we will also be serializing it as part of tasks
var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance())
val cleanSeqOp = sc.clean(seqOp)
@@ -732,7 +742,7 @@ abstract class RDD[T: ClassManifest](
* combine step happens locally on the master, equivalent to running a single reduce task.
*/
def countByValue(): Map[T, Long] = {
- if (elementClassManifest.erasure.isArray) {
+ if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValue() does not support arrays")
}
// TODO: This should perhaps be distributed by default.
@@ -763,7 +773,7 @@ abstract class RDD[T: ClassManifest](
timeout: Long,
confidence: Double = 0.95
): PartialResult[Map[T, BoundedDouble]] = {
- if (elementClassManifest.erasure.isArray) {
+ if (elementClassTag.runtimeClass.isArray) {
throw new SparkException("countByValueApprox() does not support arrays")
}
val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) =>
@@ -928,14 +938,14 @@ abstract class RDD[T: ClassManifest](
private var storageLevel: StorageLevel = StorageLevel.NONE
/** Record user function generating this RDD. */
- private[spark] val origin = Utils.formatSparkCallSite
+ @transient private[spark] val origin = Utils.formatSparkCallSite
- private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T]
+ private[spark] def elementClassTag: ClassTag[T] = classTag[T]
private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None
/** Returns the first parent RDD */
- protected[spark] def firstParent[U: ClassManifest] = {
+ protected[spark] def firstParent[U: ClassTag] = {
dependencies.head.rdd.asInstanceOf[RDD[U]]
}
@@ -943,7 +953,7 @@ abstract class RDD[T: ClassManifest](
def context = sc
// Avoid handling doCheckpoint multiple times to prevent excessive recursion
- private var doCheckpointCalled = false
+ @transient private var doCheckpointCalled = false
/**
* Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler
@@ -997,7 +1007,7 @@ abstract class RDD[T: ClassManifest](
origin)
def toJavaRDD() : JavaRDD[T] = {
- new JavaRDD(this)(elementClassManifest)
+ new JavaRDD(this)(elementClassTag)
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
index 6009a41570..3b56e45aa9 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala
@@ -17,6 +17,8 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
+
import org.apache.hadoop.fs.Path
import org.apache.hadoop.conf.Configuration
@@ -38,7 +40,7 @@ private[spark] object CheckpointState extends Enumeration {
* manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations
* of the checkpointed RDD.
*/
-private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T])
+private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T])
extends Logging with Serializable {
import CheckpointState._
diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
index 2c5253ae30..d433670cc2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala
@@ -17,6 +17,7 @@
package org.apache.spark.rdd
+import scala.reflect.ClassTag
import java.util.Random
import cern.jet.random.Poisson
@@ -29,9 +30,9 @@ class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition
override val index: Int = prev.index
}
-class SampledRDD[T: ClassManifest](
+class SampledRDD[T: ClassTag](
prev: RDD[T],
- withReplacement: Boolean,
+ withReplacement: Boolean,
frac: Double,
seed: Int)
extends RDD[T](prev) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
index 5fe4676029..2d1bd5b481 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala
@@ -14,9 +14,10 @@
* See the License for the specific language governing permissions and
* limitations under the License.
*/
-
package org.apache.spark.rdd
+import scala.reflect.{ ClassTag, classTag}
+
import org.apache.hadoop.mapred.JobConf
import org.apache.hadoop.mapred.SequenceFileOutputFormat
import org.apache.hadoop.io.compress.CompressionCodec
@@ -32,15 +33,15 @@ import org.apache.spark.Logging
*
* Import `org.apache.spark.SparkContext._` at the top of their program to use these functions.
*/
-class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest](
+class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag](
self: RDD[(K, V)])
extends Logging
with Serializable {
- private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = {
+ private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = {
val c = {
- if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) {
- classManifest[T].erasure
+ if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) {
+ classTag[T].runtimeClass
} else {
// We get the type of the Writable class by looking at the apply method which converts
// from T to Writable. Since we have two apply methods we filter out the one which
diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
index a5d751a7bd..3682c84598 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala
@@ -17,8 +17,10 @@
package org.apache.spark.rdd
-import org.apache.spark.{Dependency, Partitioner, SparkEnv, ShuffleDependency, Partition, TaskContext}
+import scala.reflect.ClassTag
+import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency,
+ SparkEnv, TaskContext}
private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
override val index = idx
@@ -32,7 +34,7 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition {
* @tparam K the key class.
* @tparam V the value class.
*/
-class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest](
+class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag](
@transient var prev: RDD[P],
part: Partitioner)
extends RDD[P](prev.context, Nil) {
diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
index 7af4d803e7..aab30b1bb4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala
@@ -18,8 +18,11 @@
package org.apache.spark.rdd
import java.util.{HashMap => JHashMap}
+
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.Partitioner
import org.apache.spark.Dependency
import org.apache.spark.TaskContext
@@ -45,7 +48,7 @@ import org.apache.spark.OneToOneDependency
* you can use `rdd1`'s partitioner/partition size and not worry about running
* out of memory because of the size of `rdd2`.
*/
-private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest](
+private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag](
@transient var rdd1: RDD[_ <: Product2[K, V]],
@transient var rdd2: RDD[_ <: Product2[K, W]],
part: Partitioner)
diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
index ae8a9f36a6..08a41ac558 100644
--- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala
@@ -18,10 +18,13 @@
package org.apache.spark.rdd
import scala.collection.mutable.ArrayBuffer
+import scala.reflect.ClassTag
+
import org.apache.spark.{Dependency, RangeDependency, SparkContext, Partition, TaskContext}
+
import java.io.{ObjectOutputStream, IOException}
-private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int)
+private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int)
extends Partition {
var split: Partition = rdd.partitions(splitIndex)
@@ -40,7 +43,7 @@ private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], spl
}
}
-class UnionRDD[T: ClassManifest](
+class UnionRDD[T: ClassTag](
sc: SparkContext,
@transient var rdds: Seq[RDD[T]])
extends RDD[T](sc, Nil) { // Nil since we implement getDependencies
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index 31e6fd519d..83be3c6eb4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -19,10 +19,12 @@ package org.apache.spark.rdd
import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext}
import java.io.{ObjectOutputStream, IOException}
+import scala.reflect.ClassTag
private[spark] class ZippedPartitionsPartition(
idx: Int,
- @transient rdds: Seq[RDD[_]])
+ @transient rdds: Seq[RDD[_]],
+ @transient val preferredLocations: Seq[String])
extends Partition {
override val index: Int = idx
@@ -37,33 +39,31 @@ private[spark] class ZippedPartitionsPartition(
}
}
-abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
+abstract class ZippedPartitionsBaseRDD[V: ClassTag](
sc: SparkContext,
- var rdds: Seq[RDD[_]])
+ var rdds: Seq[RDD[_]],
+ preservesPartitioning: Boolean = false)
extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {
+ override val partitioner =
+ if (preservesPartitioning) firstParent[Any].partitioner else None
+
override def getPartitions: Array[Partition] = {
- val sizes = rdds.map(x => x.partitions.size)
- if (!sizes.forall(x => x == sizes(0))) {
+ val numParts = rdds.head.partitions.size
+ if (!rdds.forall(rdd => rdd.partitions.size == numParts)) {
throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions")
}
- val array = new Array[Partition](sizes(0))
- for (i <- 0 until sizes(0)) {
- array(i) = new ZippedPartitionsPartition(i, rdds)
+ Array.tabulate[Partition](numParts) { i =>
+ val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i)))
+ // Check whether there are any hosts that match all RDDs; otherwise return the union
+ val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
+ val locs = if (!exactMatchLocations.isEmpty) exactMatchLocations else prefs.flatten.distinct
+ new ZippedPartitionsPartition(i, rdds, locs)
}
- array
}
override def getPreferredLocations(s: Partition): Seq[String] = {
- val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions
- val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) }
- // Check whether there are any hosts that match all RDDs; otherwise return the union
- val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y))
- if (!exactMatchLocations.isEmpty) {
- exactMatchLocations
- } else {
- prefs.flatten.distinct
- }
+ s.asInstanceOf[ZippedPartitionsPartition].preferredLocations
}
override def clearDependencies() {
@@ -72,12 +72,13 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
}
}
-class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest](
+class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B]) => Iterator[V],
var rdd1: RDD[A],
- var rdd2: RDD[B])
- extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {
+ var rdd2: RDD[B],
+ preservesPartitioning: Boolean = false)
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
@@ -92,13 +93,14 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]
}
class ZippedPartitionsRDD3
- [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest](
+ [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
- var rdd3: RDD[C])
- extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) {
+ var rdd3: RDD[C],
+ preservesPartitioning: Boolean = false)
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
@@ -116,14 +118,15 @@ class ZippedPartitionsRDD3
}
class ZippedPartitionsRDD4
- [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest](
+ [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag](
sc: SparkContext,
f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
- var rdd4: RDD[D])
- extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) {
+ var rdd4: RDD[D],
+ preservesPartitioning: Boolean = false)
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
index 567b67dfee..fb5b070c18 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala
@@ -18,10 +18,12 @@
package org.apache.spark.rdd
import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext}
+
import java.io.{ObjectOutputStream, IOException}
+import scala.reflect.ClassTag
-private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest](
+private[spark] class ZippedPartition[T: ClassTag, U: ClassTag](
idx: Int,
@transient rdd1: RDD[T],
@transient rdd2: RDD[U]
@@ -42,7 +44,7 @@ private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest](
}
}
-class ZippedRDD[T: ClassManifest, U: ClassManifest](
+class ZippedRDD[T: ClassTag, U: ClassTag](
sc: SparkContext,
var rdd1: RDD[T],
var rdd2: RDD[U])
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 42bb3884c8..963d15b76d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -21,9 +21,11 @@ import java.io.NotSerializableException
import java.util.Properties
import java.util.concurrent.atomic.AtomicInteger
-import akka.actor._
-import akka.util.duration._
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map}
+import scala.concurrent.duration._
+import scala.reflect.ClassTag
+
+import akka.actor._
import org.apache.spark._
import org.apache.spark.rdd.RDD
@@ -104,36 +106,16 @@ class DAGScheduler(
// The time, in millis, to wait for fetch failure events to stop coming in after one is detected;
// this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one
// as more failure events come in
- val RESUBMIT_TIMEOUT = 50L
+ val RESUBMIT_TIMEOUT = 50.milliseconds
// The time, in millis, to wake up between polls of the completion queue in order to potentially
// resubmit failed stages
val POLL_TIMEOUT = 10L
- private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor {
- override def preStart() {
- context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) {
- if (failed.size > 0) {
- resubmitFailedStages()
- }
- }
- }
-
- /**
- * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure
- * events and responds by launching tasks. This runs in a dedicated thread and receives events
- * via the eventQueue.
- */
- def receive = {
- case event: DAGSchedulerEvent =>
- logDebug("Got event of type " + event.getClass.getName)
+ // Warns the user if a stage contains a task with size greater than this value (in KB)
+ val TASK_SIZE_TO_WARN = 100
- if (!processEvent(event))
- submitWaitingStages()
- else
- context.stop(self)
- }
- }))
+ private var eventProcessActor: ActorRef = _
private[scheduler] val nextJobId = new AtomicInteger(0)
@@ -141,9 +123,13 @@ class DAGScheduler(
private val nextStageId = new AtomicInteger(0)
- private val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+ private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]]
- private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
+ private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]]
+
+ private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage]
+
+ private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage]
private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo]
@@ -174,6 +160,57 @@ class DAGScheduler(
val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup)
+ /**
+ * Starts the event processing actor. The actor has two responsibilities:
+ *
+ * 1. Waits for events like job submission, task finished, task failure etc., and calls
+ * [[org.apache.spark.scheduler.DAGScheduler.processEvent()]] to process them.
+ * 2. Schedules a periodical task to resubmit failed stages.
+ *
+ * NOTE: the actor cannot be started in the constructor, because the periodical task references
+ * some internal states of the enclosing [[org.apache.spark.scheduler.DAGScheduler]] object, thus
+ * cannot be scheduled until the [[org.apache.spark.scheduler.DAGScheduler]] is fully constructed.
+ */
+ def start() {
+ eventProcessActor = env.actorSystem.actorOf(Props(new Actor {
+ /**
+ * A handle to the periodical task, used to cancel the task when the actor is stopped.
+ */
+ var resubmissionTask: Cancellable = _
+
+ override def preStart() {
+ import context.dispatcher
+ /**
+ * A message is sent to the actor itself periodically to remind the actor to resubmit failed
+ * stages. In this way, stage resubmission can be done within the same thread context of
+ * other event processing logic to avoid unnecessary synchronization overhead.
+ */
+ resubmissionTask = context.system.scheduler.schedule(
+ RESUBMIT_TIMEOUT, RESUBMIT_TIMEOUT, self, ResubmitFailedStages)
+ }
+
+ /**
+ * The main event loop of the DAG scheduler.
+ */
+ def receive = {
+ case event: DAGSchedulerEvent =>
+ logDebug("Got event of type " + event.getClass.getName)
+
+ /**
+ * All events are forwarded to `processEvent()`, so that the event processing logic can
+ * easily tested without starting a dedicated actor. Please refer to `DAGSchedulerSuite`
+ * for details.
+ */
+ if (!processEvent(event)) {
+ submitWaitingStages()
+ } else {
+ resubmissionTask.cancel()
+ context.stop(self)
+ }
+ }
+ }))
+ }
+
def addSparkListener(listener: SparkListener) {
listenerBus.addListener(listener)
}
@@ -202,16 +239,16 @@ class DAGScheduler(
shuffleToMapStage.get(shuffleDep.shuffleId) match {
case Some(stage) => stage
case None =>
- val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId)
+ val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId)
shuffleToMapStage(shuffleDep.shuffleId) = stage
stage
}
}
/**
- * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or
- * as a result stage for the final RDD used directly in an action. The stage will also be
- * associated with the provided jobId.
+ * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation
+ * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided
+ * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly.
*/
private def newStage(
rdd: RDD[_],
@@ -221,21 +258,45 @@ class DAGScheduler(
callSite: Option[String] = None)
: Stage =
{
- if (shuffleDep != None) {
- // Kind of ugly: need to register RDDs with the cache and map output tracker here
- // since we can't do it in the RDD constructor because # of partitions is unknown
- logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
- mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size)
- }
val id = nextStageId.getAndIncrement()
val stage =
new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite)
stageIdToStage(id) = stage
+ updateJobIdStageIdMaps(jobId, stage)
stageToInfos(stage) = new StageInfo(stage)
stage
}
/**
+ * Create a shuffle map Stage for the given RDD. The stage will also be associated with the
+ * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is
+ * present in the MapOutputTracker, then the number and location of available outputs are
+ * recovered from the MapOutputTracker
+ */
+ private def newOrUsedStage(
+ rdd: RDD[_],
+ numTasks: Int,
+ shuffleDep: ShuffleDependency[_,_],
+ jobId: Int,
+ callSite: Option[String] = None)
+ : Stage =
+ {
+ val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite)
+ if (mapOutputTracker.has(shuffleDep.shuffleId)) {
+ val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId)
+ val locs = MapOutputTracker.deserializeMapStatuses(serLocs)
+ for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i))
+ stage.numAvailableOutputs = locs.size
+ } else {
+ // Kind of ugly: need to register RDDs with the cache and map output tracker here
+ // since we can't do it in the RDD constructor because # of partitions is unknown
+ logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")")
+ mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size)
+ }
+ stage
+ }
+
+ /**
* Get or create the list of parent stages for a given RDD. The stages will be assigned the
* provided jobId if they haven't already been created with a lower jobId.
*/
@@ -287,6 +348,89 @@ class DAGScheduler(
}
/**
+ * Registers the given jobId among the jobs that need the given stage and
+ * all of that stage's ancestors.
+ */
+ private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) {
+ def updateJobIdStageIdMapsList(stages: List[Stage]) {
+ if (!stages.isEmpty) {
+ val s = stages.head
+ stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId
+ jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id
+ val parents = getParentStages(s.rdd, jobId)
+ val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId)))
+ updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail)
+ }
+ }
+ updateJobIdStageIdMapsList(List(stage))
+ }
+
+ /**
+ * Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that
+ * were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation.
+ */
+ private def removeJobAndIndependentStages(jobId: Int): Set[Int] = {
+ val registeredStages = jobIdToStageIds(jobId)
+ val independentStages = new HashSet[Int]()
+ if (registeredStages.isEmpty) {
+ logError("No stages registered for job " + jobId)
+ } else {
+ stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach {
+ case (stageId, jobSet) =>
+ if (!jobSet.contains(jobId)) {
+ logError("Job %d not registered for stage %d even though that stage was registered for the job"
+ .format(jobId, stageId))
+ } else {
+ def removeStage(stageId: Int) {
+ // data structures based on Stage
+ stageIdToStage.get(stageId).foreach { s =>
+ if (running.contains(s)) {
+ logDebug("Removing running stage %d".format(stageId))
+ running -= s
+ }
+ stageToInfos -= s
+ shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove)
+ if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) {
+ logDebug("Removing pending status for stage %d".format(stageId))
+ }
+ pendingTasks -= s
+ if (waiting.contains(s)) {
+ logDebug("Removing stage %d from waiting set.".format(stageId))
+ waiting -= s
+ }
+ if (failed.contains(s)) {
+ logDebug("Removing stage %d from failed set.".format(stageId))
+ failed -= s
+ }
+ }
+ // data structures based on StageId
+ stageIdToStage -= stageId
+ stageIdToJobIds -= stageId
+
+ logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size))
+ }
+
+ jobSet -= jobId
+ if (jobSet.isEmpty) { // no other job needs this stage
+ independentStages += stageId
+ removeStage(stageId)
+ }
+ }
+ }
+ }
+ independentStages.toSet
+ }
+
+ private def jobIdToStageIdsRemove(jobId: Int) {
+ if (!jobIdToStageIds.contains(jobId)) {
+ logDebug("Trying to remove unregistered job " + jobId)
+ } else {
+ removeJobAndIndependentStages(jobId)
+ jobIdToStageIds -= jobId
+ }
+ }
+
+ /**
* Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object
* can be used to block until the the job finishes executing or can be used to cancel the job.
*/
@@ -319,7 +463,7 @@ class DAGScheduler(
waiter
}
- def runJob[T, U: ClassManifest](
+ def runJob[T, U: ClassTag](
rdd: RDD[T],
func: (TaskContext, Iterator[T]) => U,
partitions: Seq[Int],
@@ -375,13 +519,25 @@ class DAGScheduler(
}
/**
- * Process one event retrieved from the event queue.
- * Returns true if we should stop the event loop.
+ * Process one event retrieved from the event processing actor.
+ *
+ * @param event The event to be processed.
+ * @return `true` if we should stop the event loop.
*/
private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = {
event match {
case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) =>
- val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
+ var finalStage: Stage = null
+ try {
+ // New stage creation at times and if its not protected, the scheduler thread is killed.
+ // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted
+ finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite))
+ } catch {
+ case e: Exception =>
+ logWarning("Creating new stage failed due to exception - job: " + jobId, e)
+ listener.jobFailed(e)
+ return false
+ }
val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties)
clearCacheLocs()
logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length +
@@ -391,37 +547,31 @@ class DAGScheduler(
logInfo("Missing parents: " + getMissingParentStages(finalStage))
if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) {
// Compute very short actions like first() or take() with no parent stages locally.
+ listenerBus.post(SparkListenerJobStart(job, Array(), properties))
runLocally(job)
} else {
- listenerBus.post(SparkListenerJobStart(job, properties))
idToActiveJob(jobId) = job
activeJobs += job
resultStageToJob(finalStage) = job
+ listenerBus.post(SparkListenerJobStart(job, jobIdToStageIds(jobId).toArray, properties))
submitStage(finalStage)
}
case JobCancelled(jobId) =>
- // Cancel a job: find all the running stages that are linked to this job, and cancel them.
- running.filter(_.jobId == jobId).foreach { stage =>
- taskSched.cancelTasks(stage.id)
- }
+ handleJobCancellation(jobId)
case JobGroupCancelled(groupId) =>
// Cancel all jobs belonging to this job group.
// First finds all active jobs with this group id, and then kill stages for them.
- val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
- .map(_.jobId)
- if (!jobIds.isEmpty) {
- running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage =>
- taskSched.cancelTasks(stage.id)
- }
- }
+ val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID))
+ val jobIds = activeInGroup.map(_.jobId)
+ jobIds.foreach { handleJobCancellation }
case AllJobsCancelled =>
// Cancel all running jobs.
- running.foreach { stage =>
- taskSched.cancelTasks(stage.id)
- }
+ running.map(_.jobId).foreach { handleJobCancellation }
+ activeJobs.clear() // These should already be empty by this point,
+ idToActiveJob.clear() // but just in case we lost track of some jobs...
case ExecutorGained(execId, host) =>
handleExecutorGained(execId, host)
@@ -430,6 +580,18 @@ class DAGScheduler(
handleExecutorLost(execId)
case BeginEvent(task, taskInfo) =>
+ for (
+ job <- idToActiveJob.get(task.stageId);
+ stage <- stageIdToStage.get(task.stageId);
+ stageInfo <- stageToInfos.get(stage)
+ ) {
+ if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) {
+ stageInfo.emittedTaskSizeWarning = true
+ logWarning(("Stage %d (%s) contains a task of very large " +
+ "size (%d KB). The maximum recommended task size is %d KB.").format(
+ task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN))
+ }
+ }
listenerBus.post(SparkListenerTaskStart(task, taskInfo))
case GettingResultEvent(task, taskInfo) =>
@@ -440,7 +602,12 @@ class DAGScheduler(
handleTaskCompletion(completion)
case TaskSetFailed(taskSet, reason) =>
- abortStage(stageIdToStage(taskSet.stageId), reason)
+ stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) }
+
+ case ResubmitFailedStages =>
+ if (failed.size > 0) {
+ resubmitFailedStages()
+ }
case StopDAGScheduler =>
// Cancel any active jobs
@@ -502,6 +669,7 @@ class DAGScheduler(
// Broken out for easier testing in DAGSchedulerSuite.
protected def runLocallyWithinThread(job: ActiveJob) {
+ var jobResult: JobResult = JobSucceeded
try {
SparkEnv.set(env)
val rdd = job.finalStage.rdd
@@ -516,31 +684,59 @@ class DAGScheduler(
}
} catch {
case e: Exception =>
+ jobResult = JobFailed(e, Some(job.finalStage))
job.listener.jobFailed(e)
+ } finally {
+ val s = job.finalStage
+ stageIdToJobIds -= s.id // clean up data structures that were populated for a local job,
+ stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through
+ stageToInfos -= s // completion events or stage abort
+ jobIdToStageIds -= job.jobId
+ listenerBus.post(SparkListenerJobEnd(job, jobResult))
+ }
+ }
+
+ /** Finds the earliest-created active job that needs the stage */
+ // TODO: Probably should actually find among the active jobs that need this
+ // stage the one with the highest priority (highest-priority pool, earliest created).
+ // That should take care of at least part of the priority inversion problem with
+ // cross-job dependencies.
+ private def activeJobForStage(stage: Stage): Option[Int] = {
+ if (stageIdToJobIds.contains(stage.id)) {
+ val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted
+ jobsThatUseStage.find(idToActiveJob.contains(_))
+ } else {
+ None
}
}
/** Submits stage, but first recursively submits any missing parents. */
private def submitStage(stage: Stage) {
- logDebug("submitStage(" + stage + ")")
- if (!waiting(stage) && !running(stage) && !failed(stage)) {
- val missing = getMissingParentStages(stage).sortBy(_.id)
- logDebug("missing: " + missing)
- if (missing == Nil) {
- logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
- submitMissingTasks(stage)
- running += stage
- } else {
- for (parent <- missing) {
- submitStage(parent)
+ val jobId = activeJobForStage(stage)
+ if (jobId.isDefined) {
+ logDebug("submitStage(" + stage + ")")
+ if (!waiting(stage) && !running(stage) && !failed(stage)) {
+ val missing = getMissingParentStages(stage).sortBy(_.id)
+ logDebug("missing: " + missing)
+ if (missing == Nil) {
+ logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents")
+ submitMissingTasks(stage, jobId.get)
+ running += stage
+ } else {
+ for (parent <- missing) {
+ submitStage(parent)
+ }
+ waiting += stage
}
- waiting += stage
}
+ } else {
+ abortStage(stage, "No active job for stage " + stage.id)
}
}
+
/** Called when stage's parents are available and we can now do its task. */
- private def submitMissingTasks(stage: Stage) {
+ private def submitMissingTasks(stage: Stage, jobId: Int) {
logDebug("submitMissingTasks(" + stage + ")")
// Get our pending tasks and remember them in our pendingTasks entry
val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet)
@@ -561,7 +757,7 @@ class DAGScheduler(
}
}
- val properties = if (idToActiveJob.contains(stage.jobId)) {
+ val properties = if (idToActiveJob.contains(jobId)) {
idToActiveJob(stage.jobId).properties
} else {
//this stage will be assigned to "default" pool
@@ -643,6 +839,7 @@ class DAGScheduler(
activeJobs -= job
resultStageToJob -= stage
markStageAsFinished(stage)
+ jobIdToStageIdsRemove(job.jobId)
listenerBus.post(SparkListenerJobEnd(job, JobSucceeded))
}
job.listener.taskSucceeded(rt.outputId, event.result)
@@ -679,7 +876,7 @@ class DAGScheduler(
changeEpoch = true)
}
clearCacheLocs()
- if (stage.outputLocs.count(_ == Nil) != 0) {
+ if (stage.outputLocs.exists(_ == Nil)) {
// Some tasks had failed; let's resubmit this stage
// TODO: Lower-level scheduler should also deal with this
logInfo("Resubmitting " + stage + " (" + stage.name +
@@ -696,9 +893,12 @@ class DAGScheduler(
}
waiting --= newlyRunnable
running ++= newlyRunnable
- for (stage <- newlyRunnable.sortBy(_.id)) {
+ for {
+ stage <- newlyRunnable.sortBy(_.id)
+ jobId <- activeJobForStage(stage)
+ } {
logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable")
- submitMissingTasks(stage)
+ submitMissingTasks(stage, jobId)
}
}
}
@@ -782,21 +982,42 @@ class DAGScheduler(
}
}
+ private def handleJobCancellation(jobId: Int) {
+ if (!jobIdToStageIds.contains(jobId)) {
+ logDebug("Trying to cancel unregistered job " + jobId)
+ } else {
+ val independentStages = removeJobAndIndependentStages(jobId)
+ independentStages.foreach { taskSched.cancelTasks }
+ val error = new SparkException("Job %d cancelled".format(jobId))
+ val job = idToActiveJob(jobId)
+ job.listener.jobFailed(error)
+ jobIdToStageIds -= jobId
+ activeJobs -= job
+ idToActiveJob -= jobId
+ listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage))))
+ }
+ }
+
/**
* Aborts all jobs depending on a particular Stage. This is called in response to a task set
* being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside.
*/
private def abortStage(failedStage: Stage, reason: String) {
+ if (!stageIdToStage.contains(failedStage.id)) {
+ // Skip all the actions if the stage has been removed.
+ return
+ }
val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq
stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis())
for (resultStage <- dependentStages) {
val job = resultStageToJob(resultStage)
val error = new SparkException("Job aborted: " + reason)
job.listener.jobFailed(error)
- listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
+ jobIdToStageIdsRemove(job.jobId)
idToActiveJob -= resultStage.jobId
activeJobs -= job
resultStageToJob -= resultStage
+ listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage))))
}
if (dependentStages.isEmpty) {
logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done")
@@ -867,25 +1088,24 @@ class DAGScheduler(
}
private def cleanup(cleanupTime: Long) {
- var sizeBefore = stageIdToStage.size
- stageIdToStage.clearOldValues(cleanupTime)
- logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size)
-
- sizeBefore = shuffleToMapStage.size
- shuffleToMapStage.clearOldValues(cleanupTime)
- logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size)
-
- sizeBefore = pendingTasks.size
- pendingTasks.clearOldValues(cleanupTime)
- logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size)
-
- sizeBefore = stageToInfos.size
- stageToInfos.clearOldValues(cleanupTime)
- logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size)
+ Map(
+ "stageIdToStage" -> stageIdToStage,
+ "shuffleToMapStage" -> shuffleToMapStage,
+ "pendingTasks" -> pendingTasks,
+ "stageToInfos" -> stageToInfos,
+ "jobIdToStageIds" -> jobIdToStageIds,
+ "stageIdToJobIds" -> stageIdToJobIds).
+ foreach { case(s, t) => {
+ val sizeBefore = t.size
+ t.clearOldValues(cleanupTime)
+ logInfo("%s %d --> %d".format(s, sizeBefore, t.size))
+ }}
}
def stop() {
- eventProcessActor ! StopDAGScheduler
+ if (eventProcessActor != null) {
+ eventProcessActor ! StopDAGScheduler
+ }
metadataCleaner.cancel()
taskSched.stop()
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
index 708d221d60..add1187613 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala
@@ -65,12 +65,13 @@ private[scheduler] case class CompletionEvent(
taskMetrics: TaskMetrics)
extends DAGSchedulerEvent
-private[scheduler]
-case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
+private[scheduler] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent
private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent
private[scheduler]
case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent
+private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent
+
private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
index 58f238d8cf..b026f860a8 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala
@@ -31,6 +31,7 @@ private[spark] class JobWaiter[T](
private var finishedTasks = 0
// Is the job as a whole finished (succeeded or failed)?
+ @volatile
private var _jobFinished = totalTasks == 0
def jobFinished = _jobFinished
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
index 0a786deb16..3832ee7ff6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala
@@ -22,7 +22,7 @@ package org.apache.spark.scheduler
* to order tasks amongst a Schedulable's sub-queues
* "NONE" is used when the a Schedulable has no sub-queues.
*/
-object SchedulingMode extends Enumeration("FAIR", "FIFO", "NONE") {
+object SchedulingMode extends Enumeration {
type SchedulingMode = Value
val FAIR,FIFO,NONE = Value
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 1dc71a0428..0f2deb4bcb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -167,6 +167,7 @@ private[spark] class ShuffleMapTask(
var totalTime = 0L
val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
+ writer.close()
val size = writer.fileSegment().length
totalBytes += size
totalTime += writer.timeWriting()
@@ -184,14 +185,16 @@ private[spark] class ShuffleMapTask(
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
- if (shuffle != null) {
- shuffle.writers.foreach(_.revertPartialWrites())
+ if (shuffle != null && shuffle.writers != null) {
+ for (writer <- shuffle.writers) {
+ writer.revertPartialWrites()
+ writer.close()
+ }
}
throw e
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && shuffle.writers != null) {
- shuffle.writers.foreach(_.close())
shuffle.releaseWriters(success)
}
// Execute the callbacks on task completion.
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index a35081f7b1..ee63b3c4a1 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -37,7 +37,7 @@ case class SparkListenerTaskGettingResult(
case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo,
taskMetrics: TaskMetrics) extends SparkListenerEvents
-case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null)
+case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], properties: Properties = null)
extends SparkListenerEvents
case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult)
@@ -63,7 +63,7 @@ trait SparkListener {
* Called when a task begins remotely fetching its result (will not be called for tasks that do
* not need to fetch the result remotely).
*/
- def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
+ def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { }
/**
* Called when a task ends
@@ -131,8 +131,8 @@ object StatsReportListener extends Logging {
def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) {
val stats = d.statCounter
- logInfo(heading + stats)
val quantiles = d.getQuantiles(probabilities).map{formatNumber}
+ logInfo(heading + stats)
logInfo(percentilesHeader)
logInfo("\t" + quantiles.mkString("\t"))
}
@@ -173,8 +173,6 @@ object StatsReportListener extends Logging {
showMillisDistribution(heading, extractLongDistribution(stage, getMetric))
}
-
-
val seconds = 1000L
val minutes = seconds * 60
val hours = minutes * 60
@@ -198,7 +196,6 @@ object StatsReportListener extends Logging {
}
-
case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double)
object RuntimePercentage {
def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
index d5824e7954..85687ea330 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala
@@ -91,4 +91,3 @@ private[spark] class SparkListenerBus() extends Logging {
return true
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 93599dfdc8..e9f2198a00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -33,4 +33,5 @@ class StageInfo(
val name = stage.name
val numPartitions = stage.numPartitions
val numTasks = stage.numTasks
+ var emittedTaskSizeWarning = false
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 4bae26f3a6..3c22edd524 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -46,6 +46,8 @@ class TaskInfo(
var failed = false
+ var serializedSize: Int = 0
+
def markGettingResult(time: Long = System.currentTimeMillis) {
gettingResultTime = time
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
index 47b0f387aa..35de13c385 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala
@@ -18,9 +18,7 @@
package org.apache.spark.scheduler
-private[spark] object TaskLocality
- extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY")
-{
+private[spark] object TaskLocality extends Enumeration {
// process local is expected to be used ONLY within tasksetmanager for now.
val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
index c1e65a3c48..66ab8ea4cd 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala
@@ -24,8 +24,7 @@ import java.util.{TimerTask, Timer}
import scala.collection.mutable.ArrayBuffer
import scala.collection.mutable.HashMap
import scala.collection.mutable.HashSet
-
-import akka.util.duration._
+import scala.concurrent.duration._
import org.apache.spark._
import org.apache.spark.TaskState.TaskState
@@ -100,8 +99,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
this.dagScheduler = dagScheduler
}
- def initialize(context: SchedulerBackend) {
- backend = context
+ def initialize(backend: SchedulerBackend) {
+ this.backend = backend
// temporarily set rootPool name to empty
rootPool = new Pool("", schedulingMode, 0, 0)
schedulableBuilder = {
@@ -122,7 +121,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
if (System.getProperty("spark.speculation", "false").toBoolean) {
logInfo("Starting speculative execution thread")
-
+ import sc.env.actorSystem.dispatcher
sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds,
SPECULATION_INTERVAL milliseconds) {
checkSpeculatableTasks()
@@ -173,7 +172,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext)
backend.killTask(tid, execId)
}
}
- tsm.error("Stage %d was cancelled".format(stageId))
+ logInfo("Stage %d was cancelled".format(stageId))
+ tsm.removeAllRunningTasks()
+ taskSetFinished(tsm)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 4c5eca8537..bf494aa64d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -377,6 +377,7 @@ private[spark] class ClusterTaskSetManager(
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
+ info.serializedSize = serializedTask.limit
if (taskAttempts(index).size == 1)
taskStarted(task,info)
return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
@@ -528,10 +529,10 @@ private[spark] class ClusterTaskSetManager(
addPendingTask(index)
if (state != TaskState.KILLED) {
numFailures(index) += 1
- if (numFailures(index) > MAX_TASK_FAILURES) {
- logError("Task %s:%d failed more than %d times; aborting job".format(
+ if (numFailures(index) >= MAX_TASK_FAILURES) {
+ logError("Task %s:%d failed %d times; aborting job".format(
taskSet.id, index, MAX_TASK_FAILURES))
- abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
+ abort("Task %s:%d failed %d times".format(taskSet.id, index, MAX_TASK_FAILURES))
}
}
} else {
@@ -573,7 +574,7 @@ private[spark] class ClusterTaskSetManager(
runningTasks = runningTasksSet.size
}
- private def removeAllRunningTasks() {
+ private[cluster] def removeAllRunningTasks() {
val numRunningTasks = runningTasksSet.size
runningTasksSet.clear()
if (parent != null) {
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index a45bee536c..f5e8766f6d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -20,13 +20,12 @@ package org.apache.spark.scheduler.cluster
import java.util.concurrent.atomic.AtomicInteger
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet}
+import scala.concurrent.Await
+import scala.concurrent.duration._
import akka.actor._
-import akka.dispatch.Await
import akka.pattern.ask
-import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
-import akka.util.Duration
-import akka.util.duration._
+import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
import org.apache.spark.{SparkException, Logging, TaskState}
import org.apache.spark.scheduler.TaskDescription
@@ -53,15 +52,15 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
private val executorAddress = new HashMap[String, Address]
private val executorHost = new HashMap[String, String]
private val freeCores = new HashMap[String, Int]
- private val actorToExecutorId = new HashMap[ActorRef, String]
private val addressToExecutorId = new HashMap[Address, String]
override def preStart() {
// Listen for remote client disconnection events, since they don't go through Akka's watch()
- context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
// Periodically revive offers to allow delay scheduling to work
val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong
+ import context.dispatcher
context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers)
}
@@ -73,12 +72,10 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
} else {
logInfo("Registered executor: " + sender + " with ID " + executorId)
sender ! RegisteredExecutor(sparkProperties)
- context.watch(sender)
executorActor(executorId) = sender
executorHost(executorId) = Utils.parseHostPort(hostPort)._1
freeCores(executorId) = cores
executorAddress(executorId) = sender.path.address
- actorToExecutorId(sender) = executorId
addressToExecutorId(sender.path.address) = executorId
totalCoreCount.addAndGet(cores)
makeOffers()
@@ -118,14 +115,9 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
removeExecutor(executorId, reason)
sender ! true
- case Terminated(actor) =>
- actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated"))
+ case DisassociatedEvent(_, address, _) =>
+ addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated"))
- case RemoteClientDisconnected(transport, address) =>
- addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected"))
-
- case RemoteClientShutdown(transport, address) =>
- addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown"))
}
// Make fake resource offers on all executors
@@ -153,7 +145,6 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
if (executorActor.contains(executorId)) {
logInfo("Executor " + executorId + " disconnected, so removing it")
val numCores = freeCores(executorId)
- actorToExecutorId -= executorActor(executorId)
addressToExecutorId -= executorAddress(executorId)
executorActor -= executorId
executorHost -= executorId
@@ -199,6 +190,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac
}
override def stop() {
+ stopExecutors()
try {
if (driverActor != null) {
val future = driverActor.ask(StopDriver)(timeout)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
index 0ea35e2b7a..e8fecec4a6 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala
@@ -36,7 +36,7 @@ private[spark] class SimrSchedulerBackend(
override def start() {
super.start()
- val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
@@ -62,7 +62,6 @@ private[spark] class SimrSchedulerBackend(
val conf = new Configuration()
val fs = FileSystem.get(conf)
fs.delete(new Path(driverFilePath), false)
- super.stopExecutors()
super.stop()
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
index cefa970bb9..7127a72d6d 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala
@@ -42,7 +42,7 @@ private[spark] class SparkDeploySchedulerBackend(
super.start()
// The endpoint for executors to talk to us
- val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}")
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
index 2064d97b49..e68c527713 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala
@@ -71,7 +71,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
case cnf: ClassNotFoundException =>
val loader = Thread.currentThread.getContextClassLoader
taskSetManager.abort("ClassNotFound with classloader: " + loader)
- case ex =>
+ case ex: Throwable =>
taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex))
}
}
@@ -95,7 +95,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche
val loader = Thread.currentThread.getContextClassLoader
logError(
"Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader)
- case ex => {}
+ case ex: Throwable => {}
}
scheduler.handleFailedTask(taskSetManager, tid, taskState, reason)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index cd521e0f2b..84fe3094cc 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -120,7 +120,7 @@ private[spark] class CoarseMesosSchedulerBackend(
}
val command = CommandInfo.newBuilder()
.setEnvironment(environment)
- val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format(
System.getProperty("spark.driver.host"),
System.getProperty("spark.driver.port"),
CoarseGrainedSchedulerBackend.ACTOR_NAME)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
index 2699f0b33e..01e95162c0 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala
@@ -74,7 +74,7 @@ class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int)
}
}
-private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext)
+private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val sc: SparkContext)
extends TaskScheduler
with ExecutorBackend
with Logging {
@@ -144,7 +144,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
localActor ! KillTask(tid)
}
}
- tsm.error("Stage %d was cancelled".format(stageId))
+ logInfo("Stage %d was cancelled".format(stageId))
+ taskSetFinished(tsm)
}
}
@@ -192,17 +193,19 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc:
synchronized {
taskIdToTaskSetId.get(taskId) match {
case Some(taskSetId) =>
- val taskSetManager = activeTaskSets(taskSetId)
- taskSetTaskIds(taskSetId) -= taskId
-
- state match {
- case TaskState.FINISHED =>
- taskSetManager.taskEnded(taskId, state, serializedData)
- case TaskState.FAILED =>
- taskSetManager.taskFailed(taskId, state, serializedData)
- case TaskState.KILLED =>
- taskSetManager.error("Task %d was killed".format(taskId))
- case _ => {}
+ val taskSetManager = activeTaskSets.get(taskSetId)
+ taskSetManager.foreach { tsm =>
+ taskSetTaskIds(taskSetId) -= taskId
+
+ state match {
+ case TaskState.FINISHED =>
+ tsm.taskEnded(taskId, state, serializedData)
+ case TaskState.FAILED =>
+ tsm.taskFailed(taskId, state, serializedData)
+ case TaskState.KILLED =>
+ tsm.error("Task %d was killed".format(taskId))
+ case _ => {}
+ }
}
case None =>
logInfo("Ignoring update from TID " + taskId + " because its task set is gone")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 702aca8323..19a025a329 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -24,9 +24,9 @@ import scala.collection.mutable.{HashMap, ArrayBuffer}
import scala.util.Random
import akka.actor.{ActorSystem, Cancellable, Props}
-import akka.dispatch.{Await, Future}
-import akka.util.Duration
-import akka.util.duration._
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration.Duration
+import scala.concurrent.duration._
import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream}
@@ -924,4 +924,3 @@ private[spark] object BlockManager extends Logging {
blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host))
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index 94038649b3..e05b842476 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -17,16 +17,17 @@
package org.apache.spark.storage
-import akka.actor.ActorRef
-import akka.dispatch.{Await, Future}
+import scala.concurrent.{Await, Future}
+import scala.concurrent.duration._
+import scala.concurrent.ExecutionContext.Implicits.global
+
+import akka.actor._
import akka.pattern.ask
-import akka.util.Duration
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.storage.BlockManagerMessages._
-
-private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging {
+private[spark] class BlockManagerMaster(var driverActor : Either[ActorRef, ActorSelection]) extends Logging {
val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt
val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt
@@ -156,7 +157,10 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi
while (attempts < AKKA_RETRY_ATTEMPTS) {
attempts += 1
try {
- val future = driverActor.ask(message)(timeout)
+ val future = driverActor match {
+ case Left(a: ActorRef) => a.ask(message)(timeout)
+ case Right(b: ActorSelection) => b.ask(message)(timeout)
+ }
val result = Await.result(future, timeout)
if (result == null) {
throw new SparkException("BlockManagerMaster returned null")
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index f8cf14b503..154a3980e9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -23,10 +23,10 @@ import scala.collection.mutable
import scala.collection.JavaConversions._
import akka.actor.{Actor, ActorRef, Cancellable}
-import akka.dispatch.Future
import akka.pattern.ask
-import akka.util.Duration
-import akka.util.duration._
+
+import scala.concurrent.duration._
+import scala.concurrent.Future
import org.apache.spark.{Logging, SparkException}
import org.apache.spark.storage.BlockManagerMessages._
@@ -65,6 +65,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging {
override def preStart() {
if (!BlockManager.getDisableHeartBeatsForTesting) {
+ import context.dispatcher
timeoutCheckingTask = context.system.scheduler.schedule(
0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts)
}
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 469e68fed7..b4451fc7b8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -93,6 +93,8 @@ class DiskBlockObjectWriter(
def write(i: Int): Unit = callWithTiming(out.write(i))
override def write(b: Array[Byte]) = callWithTiming(out.write(b))
override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+ override def close() = out.close()
+ override def flush() = out.flush()
}
private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
index 2f1b049ce4..e828e1d1c5 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala
@@ -62,7 +62,7 @@ class ShuffleBlockManager(blockManager: BlockManager) {
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
// TODO: Remove this once the shuffle file consolidation feature is stable.
val consolidateShuffleFiles =
- System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean
+ System.getProperty("spark.shuffle.consolidateFiles", "false").toBoolean
private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024
diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
index 632ff047d1..b5596dffd3 100644
--- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala
@@ -101,7 +101,7 @@ class StorageLevel private(
var result = ""
result += (if (useDisk) "Disk " else "")
result += (if (useMemory) "Memory " else "")
- result += (if (deserialized) "Deserialized " else "Serialized")
+ result += (if (deserialized) "Deserialized " else "Serialized ")
result += "%sx Replicated".format(replication)
result
}
diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
index 1e4db4f66b..d52b3d8284 100644
--- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
+++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.spark.storage
import java.util.concurrent.atomic.AtomicLong
diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
index 860e680576..a8db37ded1 100644
--- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala
@@ -93,7 +93,7 @@ private[spark] object ThreadingTest {
val actorSystem = ActorSystem("test")
val serializer = new KryoSerializer
val blockManagerMaster = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
+ Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))))
val blockManager = new BlockManager(
"<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024)
val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i))
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
index 42e9be6e19..e596690bc3 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala
@@ -76,7 +76,7 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
</tr>
}
- val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b)
+ val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId)
val execTable = UIUtils.listingTable(execHead, execRow, execInfo)
val content =
@@ -99,16 +99,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) {
UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors)
}
- def getExecInfo(a: Int): Seq[String] = {
- val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId
- val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort
- val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString
- val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString
- val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString
- val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString
- val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0)
- val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 0)
- val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0)
+ def getExecInfo(statusId: Int): Seq[String] = {
+ val status = sc.getExecutorStorageStatus(statusId)
+ val execId = status.blockManagerId.executorId
+ val hostPort = status.blockManagerId.hostPort
+ val rddBlocks = status.blocks.size.toString
+ val memUsed = status.memUsed().toString
+ val maxMem = status.maxMem.toString
+ val diskUsed = status.diskUsed().toString
+ val activeTasks = listener.executorToTasksActive.getOrElse(execId, HashSet.empty[Long]).size
+ val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0)
+ val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0)
val totalTasks = activeTasks + failedTasks + completedTasks
Seq(
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
index e7eab374ad..c1ee2f3d00 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ui.jobs
-import akka.util.Duration
+import scala.concurrent.duration._
import java.text.SimpleDateFormat
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index c1c7aa70e6..69f9446bab 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -60,11 +60,13 @@ private[spark] class StagePage(parent: JobProgressUI) {
var activeTime = 0L
listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now))
+ val finishedTasks = listener.stageIdToTaskInfos(stageId).filter(_._1.finished)
+
val summary =
<div>
<ul class="unstyled">
<li>
- <strong>CPU time: </strong>
+ <strong>Total duration across all tasks: </strong>
{parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
</li>
{if (hasShuffleRead)
@@ -104,6 +106,33 @@ private[spark] class StagePage(parent: JobProgressUI) {
val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map(
ms => parent.formatDuration(ms.toLong))
+ val gettingResultTimes = validTasks.map{case (info, metrics, exception) =>
+ if (info.gettingResultTime > 0) {
+ (info.finishTime - info.gettingResultTime).toDouble
+ } else {
+ 0.0
+ }
+ }
+ val gettingResultQuantiles = ("Time spent fetching task results" +:
+ Distribution(gettingResultTimes).get.getQuantiles().map(
+ millis => parent.formatDuration(millis.toLong)))
+ // The scheduler delay includes the network delay to send the task to the worker
+ // machine and to send back the result (but not the time to fetch the task result,
+ // if it needed to be fetched from the block manager on the worker).
+ val schedulerDelays = validTasks.map{case (info, metrics, exception) =>
+ val totalExecutionTime = {
+ if (info.gettingResultTime > 0) {
+ (info.gettingResultTime - info.launchTime).toDouble
+ } else {
+ (info.finishTime - info.launchTime).toDouble
+ }
+ }
+ totalExecutionTime - metrics.get.executorRunTime
+ }
+ val schedulerDelayQuantiles = ("Scheduler delay" +:
+ Distribution(schedulerDelays).get.getQuantiles().map(
+ millis => parent.formatDuration(millis.toLong)))
+
def getQuantileCols(data: Seq[Double]) =
Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong))
@@ -119,7 +148,10 @@ private[spark] class StagePage(parent: JobProgressUI) {
}
val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes)
- val listings: Seq[Seq[String]] = Seq(serviceQuantiles,
+ val listings: Seq[Seq[String]] = Seq(
+ serviceQuantiles,
+ gettingResultQuantiles,
+ schedulerDelayQuantiles,
if (hasShuffleRead) shuffleReadQuantiles else Nil,
if (hasShuffleWrite) shuffleWriteQuantiles else Nil)
@@ -133,7 +165,7 @@ private[spark] class StagePage(parent: JobProgressUI) {
summary ++
<h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++
<div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++
- <h4>Tasks</h4> ++ taskTable;
+ <h4>Tasks</h4> ++ taskTable
headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages)
}
@@ -152,21 +184,18 @@ private[spark] class StagePage(parent: JobProgressUI) {
else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("")
val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L)
- var shuffleReadSortable: String = ""
- var shuffleReadReadable: String = ""
- if (shuffleRead) {
- shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}.toString()
- shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s =>
- Utils.bytesToString(s.remoteBytesRead)}.getOrElse("")
- }
+ val maybeShuffleRead = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}
+ val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("")
+ val shuffleReadReadable = maybeShuffleRead.map{Utils.bytesToString(_)}.getOrElse("")
- var shuffleWriteSortable: String = ""
- var shuffleWriteReadable: String = ""
- if (shuffleWrite) {
- shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}.toString()
- shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("")
- }
+ val maybeShuffleWrite = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}
+ val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("")
+ val shuffleWriteReadable = maybeShuffleWrite.map{Utils.bytesToString(_)}.getOrElse("")
+
+ val maybeWriteTime = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleWriteTime}
+ val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("")
+ val writeTimeReadable = maybeWriteTime.map{ t => t / (1000 * 1000)}.map{ ms =>
+ if (ms == 0) "" else parent.formatDuration(ms)}.getOrElse("")
<tr>
<td>{info.index}</td>
@@ -187,8 +216,8 @@ private[spark] class StagePage(parent: JobProgressUI) {
</td>
}}
{if (shuffleWrite) {
- <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s =>
- parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")}
+ <td sorttable_customkey={writeTimeSortable}>
+ {writeTimeReadable}
</td>
<td sorttable_customkey={shuffleWriteSortable}>
{shuffleWriteReadable}
diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
index 1d633d374a..a5446b3fc3 100644
--- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ui.storage
-import akka.util.Duration
+import scala.concurrent.duration._
import javax.servlet.http.HttpServletRequest
diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
index d4c5065c3f..74133cef6c 100644
--- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala
@@ -17,11 +17,8 @@
package org.apache.spark.util
-import akka.actor.{ActorSystem, ExtendedActorSystem}
+import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem}
import com.typesafe.config.ConfigFactory
-import akka.util.duration._
-import akka.remote.RemoteActorRefProvider
-
/**
* Various utility classes for working with Akka.
@@ -34,39 +31,57 @@ private[spark] object AkkaUtils {
*
* Note: the `name` parameter is important, as even if a client sends a message to right
* host + port, if the system name is incorrect, Akka will drop the message.
+ *
+ * If indestructible is set to true, the Actor System will continue running in the event
+ * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]].
*/
- def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = {
- val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
+ def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false)
+ : (ActorSystem, Int) = {
+
+ val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt
val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt
- val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt
+
+ val akkaTimeout = System.getProperty("spark.akka.timeout", "100").toInt
+
val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt
- val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
- // 10 seconds is the default akka timeout, but in a cluster, we need higher by default.
- val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt
-
- val akkaConf = ConfigFactory.parseString("""
- akka.daemonic = on
- akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"]
- akka.stdout-loglevel = "ERROR"
- akka.actor.provider = "akka.remote.RemoteActorRefProvider"
- akka.remote.transport = "akka.remote.netty.NettyRemoteTransport"
- akka.remote.netty.hostname = "%s"
- akka.remote.netty.port = %d
- akka.remote.netty.connection-timeout = %ds
- akka.remote.netty.message-frame-size = %d MiB
- akka.remote.netty.execution-pool-size = %d
- akka.actor.default-dispatcher.throughput = %d
- akka.remote.log-remote-lifecycle-events = %s
- akka.remote.netty.write-timeout = %ds
- """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize,
- lifecycleEvents, akkaWriteTimeout))
+ val lifecycleEvents =
+ if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off"
+
+ val akkaHeartBeatPauses = System.getProperty("spark.akka.heartbeat.pauses", "600").toInt
+ val akkaFailureDetector =
+ System.getProperty("spark.akka.failure-detector.threshold", "300.0").toDouble
+ val akkaHeartBeatInterval = System.getProperty("spark.akka.heartbeat.interval", "1000").toInt
- val actorSystem = ActorSystem(name, akkaConf)
+ val akkaConf = ConfigFactory.parseString(
+ s"""
+ |akka.daemonic = on
+ |akka.loggers = [""akka.event.slf4j.Slf4jLogger""]
+ |akka.stdout-loglevel = "ERROR"
+ |akka.jvm-exit-on-fatal-error = off
+ |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s
+ |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s
+ |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector
+ |akka.actor.provider = "akka.remote.RemoteActorRefProvider"
+ |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport"
+ |akka.remote.netty.tcp.hostname = "$host"
+ |akka.remote.netty.tcp.port = $port
+ |akka.remote.netty.tcp.tcp-nodelay = on
+ |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s
+ |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB
+ |akka.remote.netty.tcp.execution-pool-size = $akkaThreads
+ |akka.actor.default-dispatcher.throughput = $akkaBatchSize
+ |akka.remote.log-remote-lifecycle-events = $lifecycleEvents
+ """.stripMargin)
+
+ val actorSystem = if (indestructible) {
+ IndestructibleActorSystem(name, akkaConf)
+ } else {
+ ActorSystem(name, akkaConf)
+ }
- // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a
- // hack because Akka doesn't let you figure out the port through the public API yet.
val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider
- val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get
- return (actorSystem, boundPort)
+ val boundPort = provider.getDefaultAddress.port.get
+ (actorSystem, boundPort)
}
+
}
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
index f60deafc6f..8bb4ee3bfa 100644
--- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -35,6 +35,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
private var capacity = nextPowerOf2(initialCapacity)
private var mask = capacity - 1
private var curSize = 0
+ private var growThreshold = LOAD_FACTOR * capacity
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
@@ -56,7 +57,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
var i = 1
while (true) {
val curKey = data(2 * pos)
- if (k.eq(curKey) || k == curKey) {
+ if (k.eq(curKey) || k.equals(curKey)) {
return data(2 * pos + 1).asInstanceOf[V]
} else if (curKey.eq(null)) {
return null.asInstanceOf[V]
@@ -80,9 +81,23 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
haveNullValue = true
return
}
- val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
- if (isNewEntry) {
- incrementSize()
+ var pos = rehash(key.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (curKey.eq(null)) {
+ data(2 * pos) = k
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ incrementSize() // Since we added a new key
+ return
+ } else if (k.eq(curKey) || k.equals(curKey)) {
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
}
}
@@ -104,7 +119,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
var i = 1
while (true) {
val curKey = data(2 * pos)
- if (k.eq(curKey) || k == curKey) {
+ if (k.eq(curKey) || k.equals(curKey)) {
val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
return newValue
@@ -161,45 +176,17 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
/** Increase table size by 1, rehashing if necessary */
private def incrementSize() {
curSize += 1
- if (curSize > LOAD_FACTOR * capacity) {
+ if (curSize > growThreshold) {
growTable()
}
}
/**
- * Re-hash a value to deal better with hash functions that don't differ
- * in the lower bits, similar to java.util.HashMap
+ * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
+ * We use the Murmur Hash 3 finalization step that's also used in fastutil.
*/
private def rehash(h: Int): Int = {
- val r = h ^ (h >>> 20) ^ (h >>> 12)
- r ^ (r >>> 7) ^ (r >>> 4)
- }
-
- /**
- * Put an entry into a table represented by data, returning true if
- * this increases the size of the table or false otherwise. Assumes
- * that "data" has at least one empty slot.
- */
- private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
- val mask = (data.length / 2) - 1
- var pos = rehash(key.hashCode) & mask
- var i = 1
- while (true) {
- val curKey = data(2 * pos)
- if (curKey.eq(null)) {
- data(2 * pos) = key
- data(2 * pos + 1) = value.asInstanceOf[AnyRef]
- return true
- } else if (curKey.eq(key) || curKey == key) {
- data(2 * pos + 1) = value.asInstanceOf[AnyRef]
- return false
- } else {
- val delta = i
- pos = (pos + delta) & mask
- i += 1
- }
- }
- return false // Never reached but needed to keep compiler happy
+ it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
}
/** Double the table's size and re-hash everything */
@@ -211,16 +198,36 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
throw new Exception("Can't make capacity bigger than 2^29 elements")
}
val newData = new Array[AnyRef](2 * newCapacity)
- var pos = 0
- while (pos < capacity) {
- if (!data(2 * pos).eq(null)) {
- putInto(newData, data(2 * pos), data(2 * pos + 1))
+ val newMask = newCapacity - 1
+ // Insert all our old values into the new array. Note that because our old keys are
+ // unique, there's no need to check for equality here when we insert.
+ var oldPos = 0
+ while (oldPos < capacity) {
+ if (!data(2 * oldPos).eq(null)) {
+ val key = data(2 * oldPos)
+ val value = data(2 * oldPos + 1)
+ var newPos = rehash(key.hashCode) & newMask
+ var i = 1
+ var keepGoing = true
+ while (keepGoing) {
+ val curKey = newData(2 * newPos)
+ if (curKey.eq(null)) {
+ newData(2 * newPos) = key
+ newData(2 * newPos + 1) = value
+ keepGoing = false
+ } else {
+ val delta = i
+ newPos = (newPos + delta) & newMask
+ i += 1
+ }
+ }
}
- pos += 1
+ oldPos += 1
}
data = newData
capacity = newCapacity
- mask = newCapacity - 1
+ mask = newMask
+ growThreshold = LOAD_FACTOR * newCapacity
}
private def nextPowerOf2(n: Int): Int = {
diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
index 0b51c23f7b..a38329df03 100644
--- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
+++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala
@@ -34,6 +34,8 @@ class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A])
override def iterator: Iterator[A] = underlying.iterator.asScala
+ override def size: Int = underlying.size
+
override def ++=(xs: TraversableOnce[A]): this.type = {
xs.foreach { this += _ }
this
diff --git a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
new file mode 100644
index 0000000000..bf71882ef7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala
@@ -0,0 +1,68 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+// Must be in akka.actor package as ActorSystemImpl is protected[akka].
+package akka.actor
+
+import scala.util.control.{ControlThrowable, NonFatal}
+
+import com.typesafe.config.Config
+
+/**
+ * An [[akka.actor.ActorSystem]] which refuses to shut down in the event of a fatal exception.
+ * This is necessary as Spark Executors are allowed to recover from fatal exceptions
+ * (see [[org.apache.spark.executor.Executor]]).
+ */
+object IndestructibleActorSystem {
+ def apply(name: String, config: Config): ActorSystem =
+ apply(name, config, ActorSystem.findClassLoader())
+
+ def apply(name: String, config: Config, classLoader: ClassLoader): ActorSystem =
+ new IndestructibleActorSystemImpl(name, config, classLoader).start()
+}
+
+private[akka] class IndestructibleActorSystemImpl(
+ override val name: String,
+ applicationConfig: Config,
+ classLoader: ClassLoader)
+ extends ActorSystemImpl(name, applicationConfig, classLoader) {
+
+ protected override def uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = {
+ val fallbackHandler = super.uncaughtExceptionHandler
+
+ new Thread.UncaughtExceptionHandler() {
+ def uncaughtException(thread: Thread, cause: Throwable): Unit = {
+ if (isFatalError(cause) && !settings.JvmExitOnFatalError) {
+ log.error(cause, "Uncaught fatal error from thread [{}] not shutting down " +
+ "ActorSystem [{}] tolerating and continuing.... ", thread.getName, name)
+ //shutdown() //TODO make it configurable
+ } else {
+ fallbackHandler.uncaughtException(thread, cause)
+ }
+ }
+ }
+ }
+
+ def isFatalError(e: Throwable): Boolean = {
+ e match {
+ case NonFatal(_) | _: InterruptedException | _: NotImplementedError | _: ControlThrowable =>
+ false
+ case _ =>
+ true
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
index 67a7f87a5c..7b41ef89f1 100644
--- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
+++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala
@@ -55,8 +55,7 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea
}
}
-object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask",
- "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") {
+object MetadataCleanerType extends Enumeration {
val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK,
SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value
diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
index 277de2f8a6..dbff571de9 100644
--- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala
@@ -85,7 +85,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging {
}
override def filter(p: ((A, B)) => Boolean): Map[A, B] = {
- JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
+ JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p)
}
override def empty: Map[A, B] = new TimeStampedHashMap[A, B]()
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index fe932d8ede..3f7858d2de 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -22,10 +22,11 @@ import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address}
import java.util.{Locale, Random, UUID}
import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor}
+import scala.collection.JavaConversions._
import scala.collection.Map
import scala.collection.mutable.ArrayBuffer
-import scala.collection.JavaConversions._
import scala.io.Source
+import scala.reflect.ClassTag
import com.google.common.io.Files
import com.google.common.util.concurrent.ThreadFactoryBuilder
@@ -319,7 +320,7 @@ private[spark] object Utils extends Logging {
* result in a new collection. Unlike scala.util.Random.shuffle, this method
* uses a local random number generator, avoiding inter-thread contention.
*/
- def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = {
+ def randomize[T: ClassTag](seq: TraversableOnce[T]): Seq[T] = {
randomizeInPlace(seq.toArray)
}
@@ -823,4 +824,28 @@ private[spark] object Utils extends Logging {
return System.getProperties().clone()
.asInstanceOf[java.util.Properties].toMap[String, String]
}
+
+ /**
+ * Method executed for repeating a task for side effects.
+ * Unlike a for comprehension, it permits JVM JIT optimization
+ */
+ def times(numIters: Int)(f: => Unit): Unit = {
+ var i = 0
+ while (i < numIters) {
+ f
+ i += 1
+ }
+ }
+
+ /**
+ * Timing method based on iterations that permit JVM JIT optimization.
+ * @param numIters number of iterations
+ * @param f function to be executed
+ */
+ def timeIt(numIters: Int)(f: => Unit): Long = {
+ val start = System.currentTimeMillis
+ times(numIters)(f)
+ System.currentTimeMillis - start
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
new file mode 100644
index 0000000000..e9907e6c85
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.{Random => JavaRandom}
+import org.apache.spark.util.Utils.timeIt
+
+/**
+ * This class implements a XORShift random number generator algorithm
+ * Source:
+ * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14.
+ * @see <a href="http://www.jstatsoft.org/v08/i14/paper">Paper</a>
+ * This implementation is approximately 3.5 times faster than
+ * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due
+ * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class
+ * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG
+ * for each thread.
+ */
+private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
+
+ def this() = this(System.nanoTime)
+
+ private var seed = init
+
+ // we need to just override next - this will be called by nextInt, nextDouble,
+ // nextGaussian, nextLong, etc.
+ override protected def next(bits: Int): Int = {
+ var nextSeed = seed ^ (seed << 21)
+ nextSeed ^= (nextSeed >>> 35)
+ nextSeed ^= (nextSeed << 4)
+ seed = nextSeed
+ (nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
+ }
+}
+
+/** Contains benchmark method and main method to run benchmark of the RNG */
+private[spark] object XORShiftRandom {
+
+ /**
+ * Main method for running benchmark
+ * @param args takes one argument - the number of random numbers to generate
+ */
+ def main(args: Array[String]): Unit = {
+ if (args.length != 1) {
+ println("Benchmark of XORShiftRandom vis-a-vis java.util.Random")
+ println("Usage: XORShiftRandom number_of_random_numbers_to_generate")
+ System.exit(1)
+ }
+ println(benchmark(args(0).toInt))
+ }
+
+ /**
+ * @param numIters Number of random numbers to generate while running the benchmark
+ * @return Map of execution times for {@link java.util.Random java.util.Random}
+ * and XORShift
+ */
+ def benchmark(numIters: Int) = {
+
+ val seed = 1L
+ val million = 1e6.toInt
+ val javaRand = new JavaRandom(seed)
+ val xorRand = new XORShiftRandom(seed)
+
+ // this is just to warm up the JIT - we're not timing anything
+ timeIt(1e6.toInt) {
+ javaRand.nextInt()
+ xorRand.nextInt()
+ }
+
+ val iters = timeIt(numIters)(_)
+
+ /* Return results as a map instead of just printing to screen
+ in case the user wants to do something with them */
+ Map("javaTime" -> iters {javaRand.nextInt()},
+ "xorTime" -> iters {xorRand.nextInt()})
+
+ }
+
+} \ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
index 80545c9688..c26f23d500 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection
+import scala.reflect.ClassTag
/**
* A fast hash map implementation for nullable keys. This hash map supports insertions and updates,
@@ -26,7 +27,7 @@ package org.apache.spark.util.collection
* Under the hood, it uses our OpenHashSet implementation.
*/
private[spark]
-class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest](
+class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag](
initialCapacity: Int)
extends Iterable[(K, V)]
with Serializable {
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 4592e4f939..87e009a4de 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection
+import scala.reflect._
/**
* A simple, fast hash set optimized for non-null insertion-only use case, where keys are never
@@ -36,7 +37,7 @@ package org.apache.spark.util.collection
* to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing).
*/
private[spark]
-class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
+class OpenHashSet[@specialized(Long, Int) T: ClassTag](
initialCapacity: Int,
loadFactor: Double)
extends Serializable {
@@ -62,14 +63,14 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
// throws:
// scala.tools.nsc.symtab.Types$TypeError: type mismatch;
// found : scala.reflect.AnyValManifest[Long]
- // required: scala.reflect.ClassManifest[Int]
+ // required: scala.reflect.ClassTag[Int]
// at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298)
// at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207)
// ...
- val mt = classManifest[T]
- if (mt == ClassManifest.Long) {
+ val mt = classTag[T]
+ if (mt == ClassTag.Long) {
(new LongHasher).asInstanceOf[Hasher[T]]
- } else if (mt == ClassManifest.Int) {
+ } else if (mt == ClassTag.Int) {
(new IntHasher).asInstanceOf[Hasher[T]]
} else {
new Hasher[T]
@@ -79,6 +80,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
protected var _capacity = nextPowerOf2(initialCapacity)
protected var _mask = _capacity - 1
protected var _size = 0
+ protected var _growThreshold = (loadFactor * _capacity).toInt
protected var _bitset = new BitSet(_capacity)
@@ -115,7 +117,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
* @return The position where the key is placed, plus the highest order bit is set if the key
* exists previously.
*/
- def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
+ def addWithoutResize(k: T): Int = {
+ var pos = hashcode(hasher.hash(k)) & _mask
+ var i = 1
+ while (true) {
+ if (!_bitset.get(pos)) {
+ // This is a new key.
+ _data(pos) = k
+ _bitset.set(pos)
+ _size += 1
+ return pos | NONEXISTENCE_MASK
+ } else if (_data(pos) == k) {
+ // Found an existing key.
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & _mask
+ i += 1
+ }
+ }
+ // Never reached here
+ assert(INVALID_POS != INVALID_POS)
+ INVALID_POS
+ }
/**
* Rehash the set if it is overloaded.
@@ -126,7 +150,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
* to a new position (in the new data array).
*/
def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
- if (_size > loadFactor * _capacity) {
+ if (_size > _growThreshold) {
rehash(k, allocateFunc, moveFunc)
}
}
@@ -161,37 +185,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
/**
- * Put an entry into the set. Return the position where the key is placed. In addition, the
- * highest bit in the returned position is set if the key exists prior to this put.
- *
- * This function assumes the data array has at least one empty slot.
- */
- private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
- val mask = data.length - 1
- var pos = hashcode(hasher.hash(k)) & mask
- var i = 1
- while (true) {
- if (!bitset.get(pos)) {
- // This is a new key.
- data(pos) = k
- bitset.set(pos)
- _size += 1
- return pos | NONEXISTENCE_MASK
- } else if (data(pos) == k) {
- // Found an existing key.
- return pos
- } else {
- val delta = i
- pos = (pos + delta) & mask
- i += 1
- }
- }
- // Never reached here
- assert(INVALID_POS != INVALID_POS)
- INVALID_POS
- }
-
- /**
* Double the table's size and re-hash everything. We are not really using k, but it is declared
* so Scala compiler can specialize this method (which leads to calling the specialized version
* of putInto).
@@ -204,34 +197,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
*/
private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
val newCapacity = _capacity * 2
- require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
-
allocateFunc(newCapacity)
- val newData = new Array[T](newCapacity)
val newBitset = new BitSet(newCapacity)
- var pos = 0
- _size = 0
- while (pos < _capacity) {
- if (_bitset.get(pos)) {
- val newPos = putInto(newBitset, newData, _data(pos))
- moveFunc(pos, newPos & POSITION_MASK)
+ val newData = new Array[T](newCapacity)
+ val newMask = newCapacity - 1
+
+ var oldPos = 0
+ while (oldPos < capacity) {
+ if (_bitset.get(oldPos)) {
+ val key = _data(oldPos)
+ var newPos = hashcode(hasher.hash(key)) & newMask
+ var i = 1
+ var keepGoing = true
+ // No need to check for equality here when we insert so this has one less if branch than
+ // the similar code path in addWithoutResize.
+ while (keepGoing) {
+ if (!newBitset.get(newPos)) {
+ // Inserting the key at newPos
+ newData(newPos) = key
+ newBitset.set(newPos)
+ moveFunc(oldPos, newPos)
+ keepGoing = false
+ } else {
+ val delta = i
+ newPos = (newPos + delta) & newMask
+ i += 1
+ }
+ }
}
- pos += 1
+ oldPos += 1
}
+
_bitset = newBitset
_data = newData
_capacity = newCapacity
- _mask = newCapacity - 1
+ _mask = newMask
+ _growThreshold = (loadFactor * newCapacity).toInt
}
/**
- * Re-hash a value to deal better with hash functions that don't differ
- * in the lower bits, similar to java.util.HashMap
+ * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
+ * We use the Murmur Hash 3 finalization step that's also used in fastutil.
*/
- private def hashcode(h: Int): Int = {
- val r = h ^ (h >>> 20) ^ (h >>> 12)
- r ^ (r >>> 7) ^ (r >>> 4)
- }
+ private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
private def nextPowerOf2(n: Int): Int = {
val highBit = Integer.highestOneBit(n)
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
index d76143e45a..2e1ef06cbc 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala
@@ -17,6 +17,7 @@
package org.apache.spark.util.collection
+import scala.reflect._
/**
* A fast hash map implementation for primitive, non-null keys. This hash map supports
@@ -26,15 +27,15 @@ package org.apache.spark.util.collection
* Under the hood, it uses our OpenHashSet implementation.
*/
private[spark]
-class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest,
- @specialized(Long, Int, Double) V: ClassManifest](
+class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag,
+ @specialized(Long, Int, Double) V: ClassTag](
initialCapacity: Int)
extends Iterable[(K, V)]
with Serializable {
def this() = this(64)
- require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int])
+ require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int])
// Init in constructor (instead of in declaration) to work around a Scala compiler specialization
// bug that would generate two arrays (one for Object and one for specialized T).
diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
index 20554f0aab..b84eb65c62 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala
@@ -17,11 +17,13 @@
package org.apache.spark.util.collection
+import scala.reflect.ClassTag
+
/**
* An append-only, non-threadsafe, array-backed vector that is optimized for primitive types.
*/
private[spark]
-class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) {
+class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: Int = 64) {
private var _numElements = 0
private var _array: Array[V] = _
diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
index 4434f3b87c..c443c5266e 100644
--- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala
@@ -27,6 +27,21 @@ import org.apache.spark.SparkContext._
class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext {
+
+ implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] {
+ def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = {
+ t1 ++= t2
+ t1
+ }
+ def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = {
+ t1 += t2
+ t1
+ }
+ def zero(t: mutable.Set[A]) : mutable.Set[A] = {
+ new mutable.HashSet[A]()
+ }
+ }
+
test ("basic accumulation"){
sc = new SparkContext("local", "test")
val acc : Accumulator[Int] = sc.accumulator(0)
@@ -51,7 +66,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
}
test ("add value to collection accumulators") {
- import SetAccum._
val maxI = 1000
for (nThreads <- List(1, 10)) { //test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
@@ -68,22 +82,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
}
}
- implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] {
- def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = {
- t1 ++= t2
- t1
- }
- def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = {
- t1 += t2
- t1
- }
- def zero(t: mutable.Set[Any]) : mutable.Set[Any] = {
- new mutable.HashSet[Any]()
- }
- }
-
test ("value not readable in tasks") {
- import SetAccum._
val maxI = 1000
for (nThreads <- List(1, 10)) { //test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
@@ -125,7 +124,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte
}
test ("localValue readable in tasks") {
- import SetAccum._
val maxI = 1000
for (nThreads <- List(1, 10)) { //test single & multi-threaded
sc = new SparkContext("local[" + nThreads + "]", "test")
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index f26c44d3e7..f25d921d3f 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -17,6 +17,7 @@
package org.apache.spark
+import scala.reflect.ClassTag
import org.scalatest.FunSuite
import java.io.File
import org.apache.spark.rdd._
@@ -62,8 +63,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithContextRDD(r,
- (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
testCheckpointing(_.pipe(Seq("cat")))
@@ -207,7 +206,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
* not, but this is not done by default as usually the partitions do not refer to any RDD and
* therefore never store the lineage.
*/
- def testCheckpointing[U: ClassManifest](
+ def testCheckpointing[U: ClassTag](
op: (RDD[Int]) => RDD[U],
testRDDSize: Boolean = true,
testRDDPartitionSize: Boolean = false
@@ -276,7 +275,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
* RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed,
* this RDD will remember the partitions and therefore potentially the whole lineage.
*/
- def testParentCheckpointing[U: ClassManifest](
+ def testParentCheckpointing[U: ClassTag](
op: (RDD[Int]) => RDD[U],
testRDDSize: Boolean,
testRDDPartitionSize: Boolean
diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
index 480bac84f3..d9cb7fead5 100644
--- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala
@@ -122,7 +122,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
sc.parallelize(1 to 10, 10).foreach(x => println(x / 0))
}
assert(thrown.getClass === classOf[SparkException])
- assert(thrown.getMessage.contains("more than 4 times"))
+ assert(thrown.getMessage.contains("failed 4 times"))
}
test("caching") {
@@ -303,12 +303,13 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter
Thread.sleep(200)
}
} catch {
- case _ => { Thread.sleep(10) }
+ case _: Throwable => { Thread.sleep(10) }
// Do nothing. We might see exceptions because block manager
// is racing this thread to remove entries from the driver.
}
}
}
+
}
object DistributedSuite {
diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala
index 01a72d8401..6d1695eae7 100644
--- a/core/src/test/scala/org/apache/spark/DriverSuite.scala
+++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala
@@ -34,7 +34,7 @@ class DriverSuite extends FunSuite with Timeouts {
// Regression test for SPARK-530: "Spark driver process doesn't exit after finishing"
val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]"))
forAll(masters) { (master: String) =>
- failAfter(30 seconds) {
+ failAfter(60 seconds) {
Utils.execute(Seq("./spark-class", "org.apache.spark.DriverWithoutCleanup", master),
new File(System.getenv("SPARK_HOME")))
}
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 352036f182..4234f6eac7 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -365,6 +365,20 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void javaDoubleRDDHistoGram() {
+ JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+ // Test using generated buckets
+ Tuple2<double[], long[]> results = rdd.histogram(2);
+ double[] expected_buckets = {1.0, 2.5, 4.0};
+ long[] expected_counts = {2, 2};
+ Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
+ Assert.assertArrayEquals(expected_counts, results._2);
+ // Test with provided buckets
+ long[] histogram = rdd.histogram(expected_buckets);
+ Assert.assertArrayEquals(expected_counts, histogram);
+ }
+
+ @Test
public void map() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
index d8a0e983b2..1121e06e2e 100644
--- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
+++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala
@@ -114,7 +114,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
// Once A is cancelled, job B should finish fairly quickly.
assert(jobB.get() === 100)
}
-
+/*
test("two jobs sharing the same stage") {
// sem1: make sure cancel is issued after some tasks are launched
// sem2: make sure the first stage is not finished until cancel is issued
@@ -148,7 +148,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf
intercept[SparkException] { f1.get() }
intercept[SparkException] { f2.get() }
}
-
+ */
def testCount() {
// Cancel before launching any tasks
{
diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
index 459e257d79..8dd5786da6 100644
--- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
+++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala
@@ -30,7 +30,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self
@transient var sc: SparkContext = _
override def beforeAll() {
- InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory());
+ InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory())
super.beforeAll()
}
diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
index b7eb268bd5..271dc905bc 100644
--- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala
@@ -49,14 +49,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master start and stop") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+ tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))))
tracker.stop()
}
test("master register and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+ tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -75,7 +75,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
test("master register and unregister and fetch") {
val actorSystem = ActorSystem("test")
val tracker = new MapOutputTrackerMaster()
- tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))
+ tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))))
tracker.registerShuffle(10, 2)
val compressedSize1000 = MapOutputTracker.compressSize(1000L)
val compressedSize10000 = MapOutputTracker.compressSize(10000L)
@@ -101,13 +101,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext {
System.setProperty("spark.hostPort", hostname + ":" + boundPort)
val masterTracker = new MapOutputTrackerMaster()
- masterTracker.trackerActor = actorSystem.actorOf(
- Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")
+ masterTracker.trackerActor = Left(actorSystem.actorOf(
+ Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker"))
val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0)
val slaveTracker = new MapOutputTracker()
- slaveTracker.trackerActor = slaveSystem.actorFor(
- "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker")
+ slaveTracker.trackerActor = Right(slaveSystem.actorSelection(
+ "akka.tcp://spark@localhost:" + boundPort + "/user/MapOutputTracker"))
masterTracker.registerShuffle(10, 1)
masterTracker.incrementEpoch()
diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
deleted file mode 100644
index 21f16ef2c6..0000000000
--- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import org.scalatest.FunSuite
-import org.apache.spark.SparkContext._
-import org.apache.spark.rdd.{RDD, PartitionPruningRDD}
-
-
-class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
-
- test("Pruned Partitions inherit locality prefs correctly") {
- class TestPartition(i: Int) extends Partition {
- def index = i
- }
- val rdd = new RDD[Int](sc, Nil) {
- override protected def getPartitions = {
- Array[Partition](
- new TestPartition(1),
- new TestPartition(2),
- new TestPartition(3))
- }
- def compute(split: Partition, context: TaskContext) = {Iterator()}
- }
- val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false})
- val p = prunedRDD.partitions(0)
- assert(p.index == 2)
- assert(prunedRDD.partitions.length == 1)
- }
-}
diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
index 7d938917f2..1374d01774 100644
--- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
+++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala
@@ -142,11 +142,11 @@ class PartitioningSuite extends FunSuite with SharedSparkContext {
.filter(_ >= 0.0)
// Run the partitions, including the consecutive empty ones, through StatCounter
- val stats: StatCounter = rdd.stats();
- assert(abs(6.0 - stats.sum) < 0.01);
- assert(abs(6.0/2 - rdd.mean) < 0.01);
- assert(abs(1.0 - rdd.variance) < 0.01);
- assert(abs(1.0 - rdd.stdev) < 0.01);
+ val stats: StatCounter = rdd.stats()
+ assert(abs(6.0 - stats.sum) < 0.01)
+ assert(abs(6.0/2 - rdd.mean) < 0.01)
+ assert(abs(1.0 - rdd.variance) < 0.01)
+ assert(abs(1.0 - rdd.stdev) < 0.01)
// Add other tests here for classes that should be able to handle empty partitions correctly
}
diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
new file mode 100644
index 0000000000..151af0d213
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala
@@ -0,0 +1,140 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import org.scalatest.{FunSuite, PrivateMethodTester}
+
+import org.apache.spark.scheduler.TaskScheduler
+import org.apache.spark.scheduler.cluster.{ClusterScheduler, SimrSchedulerBackend, SparkDeploySchedulerBackend}
+import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
+import org.apache.spark.scheduler.local.LocalScheduler
+
+class SparkContextSchedulerCreationSuite
+ extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging {
+
+ def createTaskScheduler(master: String): TaskScheduler = {
+ // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the
+ // real schedulers, so we don't want to create a full SparkContext with the desired scheduler.
+ sc = new SparkContext("local", "test")
+ val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler)
+ SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test")
+ }
+
+ test("bad-master") {
+ val e = intercept[SparkException] {
+ createTaskScheduler("localhost:1234")
+ }
+ assert(e.getMessage.contains("Could not parse Master URL"))
+ }
+
+ test("local") {
+ createTaskScheduler("local") match {
+ case s: LocalScheduler =>
+ assert(s.threads === 1)
+ assert(s.maxFailures === 0)
+ case _ => fail()
+ }
+ }
+
+ test("local-n") {
+ createTaskScheduler("local[5]") match {
+ case s: LocalScheduler =>
+ assert(s.threads === 5)
+ assert(s.maxFailures === 0)
+ case _ => fail()
+ }
+ }
+
+ test("local-n-failures") {
+ createTaskScheduler("local[4, 2]") match {
+ case s: LocalScheduler =>
+ assert(s.threads === 4)
+ assert(s.maxFailures === 2)
+ case _ => fail()
+ }
+ }
+
+ test("simr") {
+ createTaskScheduler("simr://uri") match {
+ case s: ClusterScheduler =>
+ assert(s.backend.isInstanceOf[SimrSchedulerBackend])
+ case _ => fail()
+ }
+ }
+
+ test("local-cluster") {
+ createTaskScheduler("local-cluster[3, 14, 512]") match {
+ case s: ClusterScheduler =>
+ assert(s.backend.isInstanceOf[SparkDeploySchedulerBackend])
+ case _ => fail()
+ }
+ }
+
+ def testYarn(master: String, expectedClassName: String) {
+ try {
+ createTaskScheduler(master) match {
+ case s: ClusterScheduler =>
+ assert(s.getClass === Class.forName(expectedClassName))
+ case _ => fail()
+ }
+ } catch {
+ case e: SparkException =>
+ assert(e.getMessage.contains("YARN mode not available"))
+ logWarning("YARN not available, could not test actual YARN scheduler creation")
+ case e: Throwable => fail(e)
+ }
+ }
+
+ test("yarn-standalone") {
+ testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler")
+ }
+
+ test("yarn-client") {
+ testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+ }
+
+ def testMesos(master: String, expectedClass: Class[_]) {
+ try {
+ createTaskScheduler(master) match {
+ case s: ClusterScheduler =>
+ assert(s.backend.getClass === expectedClass)
+ case _ => fail()
+ }
+ } catch {
+ case e: UnsatisfiedLinkError =>
+ assert(e.getMessage.contains("no mesos in"))
+ logWarning("Mesos not available, could not test actual Mesos scheduler creation")
+ case e: Throwable => fail(e)
+ }
+ }
+
+ test("mesos fine-grained") {
+ System.setProperty("spark.mesos.coarse", "false")
+ testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend])
+ }
+
+ test("mesos coarse-grained") {
+ System.setProperty("spark.mesos.coarse", "true")
+ testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend])
+ }
+
+ test("mesos with zookeeper") {
+ System.setProperty("spark.mesos.coarse", "false")
+ testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend])
+ }
+}
diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
index 46a2da1724..768ca3850e 100644
--- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
+++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala
@@ -37,7 +37,7 @@ class UnpersistSuite extends FunSuite with LocalSparkContext {
Thread.sleep(200)
}
} catch {
- case _ => { Thread.sleep(10) }
+ case _: Throwable => { Thread.sleep(10) }
// Do nothing. We might see exceptions because block manager
// is racing this thread to remove entries from the driver.
}
diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
index 8f0954122b..4cb4ddc9cd 100644
--- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
+++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.spark.deploy.worker
import java.io.File
diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
index da032b17d9..0d4c10db8e 100644
--- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala
@@ -19,6 +19,8 @@ package org.apache.spark.rdd
import java.util.concurrent.Semaphore
+import scala.concurrent.{Await, TimeoutException}
+import scala.concurrent.duration.Duration
import scala.concurrent.ExecutionContext.Implicits.global
import org.scalatest.{BeforeAndAfterAll, FunSuite}
@@ -173,4 +175,28 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts
sem.acquire(2)
}
}
+
+ /**
+ * Awaiting FutureAction results
+ */
+ test("FutureAction result, infinite wait") {
+ val f = sc.parallelize(1 to 100, 4)
+ .countAsync()
+ assert(Await.result(f, Duration.Inf) === 100)
+ }
+
+ test("FutureAction result, finite wait") {
+ val f = sc.parallelize(1 to 100, 4)
+ .countAsync()
+ assert(Await.result(f, Duration(30, "seconds")) === 100)
+ }
+
+ test("FutureAction result, timeout") {
+ val f = sc.parallelize(1 to 100, 4)
+ .mapPartitions(itr => { Thread.sleep(20); itr })
+ .countAsync()
+ intercept[TimeoutException] {
+ Await.result(f, Duration(20, "milliseconds"))
+ }
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
new file mode 100644
index 0000000000..7f50a5a47c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.math.abs
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import org.apache.spark._
+
+class DoubleRDDSuite extends FunSuite with SharedSparkContext {
+ // Verify tests on the histogram functionality. We test with both evenly
+ // and non-evenly spaced buckets as the bucket lookup function changes.
+ test("WorksOnEmpty") {
+ // Make sure that it works on an empty input
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithOneBucket") {
+ // Verify that if all of the elements are out of range the counts are zero
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucket") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucketExactMatch") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(1.0, 4.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoUnEvenBuckets") {
+ // Verify that out of range works with two un even buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 4.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBuckets") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBucketsAndNaN") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithFourUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithUnevenBucketsAndNaN") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket and an inifity
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 4)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithInfiniteBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN))
+ val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(1, 1)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Test the failure mode with an invalid bucket array
+ test("ThrowsExceptionOnInvalidBucketArray") {
+ val rdd = sc.parallelize(Seq(1.0))
+ // Empty array
+ intercept[IllegalArgumentException] {
+ val buckets = Array.empty[Double]
+ val result = rdd.histogram(buckets)
+ }
+ // Single element array
+ intercept[IllegalArgumentException] {
+ val buckets = Array(1.0)
+ val result = rdd.histogram(buckets)
+ }
+ }
+
+ // Test automatic histogram function
+ test("WorksWithoutBucketsBasic") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicSingleElement") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(1)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicNoRange") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 1, 1, 1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsBasicTwo") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(2)
+ val expectedHistogramResults = Array(2, 2)
+ val expectedHistogramBuckets = Array(1.0, 2.5, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsWithMoreRequestedThanElements") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2))
+ val (histogramBuckets, histogramResults) = rdd.histogram(10)
+ val expectedHistogramResults =
+ Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1)
+ val expectedHistogramBuckets =
+ Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ // Test the failure mode with an invalid RDD
+ test("ThrowsExceptionOnInvalidRDDs") {
+ // infinity
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, 1.0/0.0))
+ val result = rdd.histogram(1)
+ }
+ // NaN
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, Double.NaN))
+ val result = rdd.histogram(1)
+ }
+ // Empty
+ intercept[UnsupportedOperationException] {
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val result = rdd.histogram(1)
+ }
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
new file mode 100644
index 0000000000..53a7b7c44d
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import org.scalatest.FunSuite
+import org.apache.spark.{TaskContext, Partition, SharedSparkContext}
+
+
+class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext {
+
+
+ test("Pruned Partitions inherit locality prefs correctly") {
+
+ val rdd = new RDD[Int](sc, Nil) {
+ override protected def getPartitions = {
+ Array[Partition](
+ new TestPartition(0, 1),
+ new TestPartition(1, 1),
+ new TestPartition(2, 1))
+ }
+
+ def compute(split: Partition, context: TaskContext) = {
+ Iterator()
+ }
+ }
+ val prunedRDD = PartitionPruningRDD.create(rdd, {
+ x => if (x == 2) true else false
+ })
+ assert(prunedRDD.partitions.length == 1)
+ val p = prunedRDD.partitions(0)
+ assert(p.index == 0)
+ assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2)
+ }
+
+
+ test("Pruned Partitions can be unioned ") {
+
+ val rdd = new RDD[Int](sc, Nil) {
+ override protected def getPartitions = {
+ Array[Partition](
+ new TestPartition(0, 4),
+ new TestPartition(1, 5),
+ new TestPartition(2, 6))
+ }
+
+ def compute(split: Partition, context: TaskContext) = {
+ List(split.asInstanceOf[TestPartition].testValue).iterator
+ }
+ }
+ val prunedRDD1 = PartitionPruningRDD.create(rdd, {
+ x => if (x == 0) true else false
+ })
+
+ val prunedRDD2 = PartitionPruningRDD.create(rdd, {
+ x => if (x == 2) true else false
+ })
+
+ val merged = prunedRDD1 ++ prunedRDD2
+ assert(merged.count() == 2)
+ val take = merged.take(2)
+ assert(take.apply(0) == 4)
+ assert(take.apply(1) == 6)
+ }
+}
+
+class TestPartition(i: Int, value: Int) extends Partition with Serializable {
+ def index = i
+
+ def testValue = this.value
+
+}
diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
index 88b36a6855..964ca9e91d 100644
--- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
+++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala
@@ -271,8 +271,8 @@ class RDDSuite extends FunSuite with SharedSparkContext {
// test that you get over 90% locality in each group
val minLocality = coalesced2.partitions
.map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction)
- .foldLeft(1.)((perc, loc) => math.min(perc,loc))
- assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%")
+ .foldLeft(1.0)((perc, loc) => math.min(perc,loc))
+ assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.0).toInt + "%")
// test that the groups are load balanced with 100 +/- 20 elements in each
val maxImbalance = coalesced2.partitions
@@ -284,9 +284,9 @@ class RDDSuite extends FunSuite with SharedSparkContext {
val coalesced3 = data3.coalesce(numMachines*2)
val minLocality2 = coalesced3.partitions
.map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction)
- .foldLeft(1.)((perc, loc) => math.min(perc,loc))
+ .foldLeft(1.0)((perc, loc) => math.min(perc,loc))
assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " +
- (minLocality2*100.).toInt + "%")
+ (minLocality2*100.0).toInt + "%")
}
test("zipped RDDs") {
diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
index a4d41ebbff..706d84a58b 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala
@@ -206,6 +206,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
submit(rdd, Array(0))
complete(taskSets(0), List((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("local job") {
@@ -219,6 +220,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
val jobId = scheduler.nextJobId.getAndIncrement()
runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("run trivial job w/ dependency") {
@@ -227,6 +229,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
submit(finalRdd, Array(0))
complete(taskSets(0), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("cache location preferences w/ dependency") {
@@ -239,12 +242,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
assertLocations(taskSet, Seq(Seq("hostA", "hostB")))
complete(taskSet, Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("trivial job failure") {
submit(makeRdd(1, Nil), Array(0))
failed(taskSets(0), "some failure")
assert(failure.getMessage === "Job aborted: some failure")
+ assertDataStructuresEmpty
}
test("run trivial shuffle") {
@@ -260,6 +265,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB")))
complete(taskSets(1), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("run trivial shuffle with fetch failure") {
@@ -285,6 +291,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB"))
complete(taskSets(3), Seq((Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
+ assertDataStructuresEmpty
}
test("ignore late map task completions") {
@@ -313,6 +320,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA")))
complete(taskSets(1), Seq((Success, 42), (Success, 43)))
assert(results === Map(0 -> 42, 1 -> 43))
+ assertDataStructuresEmpty
}
test("run trivial shuffle with out-of-band failure and retry") {
@@ -329,15 +337,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(0), Seq(
(Success, makeMapStatus("hostA", 1)),
(Success, makeMapStatus("hostB", 1))))
- // have hostC complete the resubmitted task
- complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
- assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
- Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
- complete(taskSets(2), Seq((Success, 42)))
- assert(results === Map(0 -> 42))
- }
-
- test("recursive shuffle failures") {
+ // have hostC complete the resubmitted task
+ complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1))))
+ assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) ===
+ Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB")))
+ complete(taskSets(2), Seq((Success, 42)))
+ assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
+ }
+
+ test("recursive shuffle failures") {
val shuffleOneRdd = makeRdd(2, Nil)
val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null)
val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne))
@@ -363,6 +372,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1))))
complete(taskSets(5), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
test("cached post-shuffle") {
@@ -394,6 +404,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1))))
complete(taskSets(4), Seq((Success, 42)))
assert(results === Map(0 -> 42))
+ assertDataStructuresEmpty
}
/**
@@ -413,4 +424,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont
private def makeBlockManagerId(host: String): BlockManagerId =
BlockManagerId("exec-" + host, host, 12345, 0)
+ private def assertDataStructuresEmpty = {
+ assert(scheduler.pendingTasks.isEmpty)
+ assert(scheduler.activeJobs.isEmpty)
+ assert(scheduler.failed.isEmpty)
+ assert(scheduler.idToActiveJob.isEmpty)
+ assert(scheduler.jobIdToStageIds.isEmpty)
+ assert(scheduler.stageIdToJobIds.isEmpty)
+ assert(scheduler.stageIdToStage.isEmpty)
+ assert(scheduler.stageToInfos.isEmpty)
+ assert(scheduler.resultStageToJob.isEmpty)
+ assert(scheduler.running.isEmpty)
+ assert(scheduler.shuffleToMapStage.isEmpty)
+ assert(scheduler.waiting.isEmpty)
+ }
}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index 984881861c..002368ff55 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD
class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+ val WAIT_TIMEOUT_MILLIS = 10000
test("inner method") {
sc = new SparkContext("local", "joblogger")
@@ -92,6 +93,8 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER)
joblogger.getLogDir should be ("/tmp/spark-%s".format(user))
@@ -120,7 +123,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
-
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
joblogger.onJobStartCount should be (1)
joblogger.onJobEndCount should be (1)
joblogger.onTaskEndCount should be (8)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
index 1fd76420ea..2e41438a52 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala
@@ -145,7 +145,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc
// Make a task whose result is larger than the akka frame size
System.setProperty("spark.akka.frameSize", "1")
val akkaFrameSize =
- sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
index b97f2b19b5..29c4cc5d9c 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala
@@ -283,7 +283,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo
// Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted
// after the last failure.
- (0 until manager.MAX_TASK_FAILURES).foreach { index =>
+ (1 to manager.MAX_TASK_FAILURES).foreach { index =>
val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY)
assert(offerResult != None,
"Expect resource offer on iteration %s to return a task".format(index))
diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
index ee150a3107..27c2d53361 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala
@@ -82,7 +82,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
test("handling results larger than Akka frame size") {
val akkaFrameSize =
- sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
@@ -103,7 +103,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA
}
scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler)
val akkaFrameSize =
- sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt
+ sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt
val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x)
assert(result === 1.to(akkaFrameSize).toArray)
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
index cb76275e39..b647e8a672 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala
@@ -39,7 +39,7 @@ class BlockIdSuite extends FunSuite {
fail()
} catch {
case e: IllegalStateException => // OK
- case _ => fail()
+ case _: Throwable => fail()
}
}
diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
index 484a654108..5b4d63b954 100644
--- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala
@@ -56,7 +56,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT
System.setProperty("spark.hostPort", "localhost:" + boundPort)
master = new BlockManagerMaster(
- actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))
+ Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))))
// Set the arch to 64-bit and compressedOops to true to get a deterministic test-case
oldArch = System.setProperty("os.arch", "amd64")
diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
index 0b9056344c..070982e798 100644
--- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala
@@ -1,3 +1,20 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.spark.storage
import java.io.{FileWriter, File}
@@ -5,9 +22,9 @@ import java.io.{FileWriter, File}
import scala.collection.mutable
import com.google.common.io.Files
-import org.scalatest.{BeforeAndAfterEach, FunSuite}
+import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
-class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach {
+class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll {
val rootDir0 = Files.createTempDir()
rootDir0.deleteOnExit()
@@ -16,6 +33,12 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach {
val rootDirs = rootDir0.getName + "," + rootDir1.getName
println("Created root dirs: " + rootDirs)
+ // This suite focuses primarily on consolidation features,
+ // so we coerce consolidation if not already enabled.
+ val consolidateProp = "spark.shuffle.consolidateFiles"
+ val oldConsolidate = Option(System.getProperty(consolidateProp))
+ System.setProperty(consolidateProp, "true")
+
val shuffleBlockManager = new ShuffleBlockManager(null) {
var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]()
override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id)
@@ -23,6 +46,10 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach {
var diskBlockManager: DiskBlockManager = _
+ override def afterAll() {
+ oldConsolidate.map(c => System.setProperty(consolidateProp, c))
+ }
+
override def beforeEach() {
diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs)
shuffleBlockManager.idToSegmentMap.clear()
diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
index 8f0ec6683b..3764f4d1a0 100644
--- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala
+++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala
@@ -34,7 +34,6 @@ class UISuite extends FunSuite {
}
val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq())
val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq())
-
// Allow some wiggle room in case ports on the machine are under contention
assert(boundPort1 > startPort && boundPort1 < startPort + 10)
assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10)
diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
index 4e40dcbdee..5aff26f9fc 100644
--- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala
@@ -63,54 +63,53 @@ class SizeEstimatorSuite
}
test("simple classes") {
- assert(SizeEstimator.estimate(new DummyClass1) === 16)
- assert(SizeEstimator.estimate(new DummyClass2) === 16)
- assert(SizeEstimator.estimate(new DummyClass3) === 24)
- assert(SizeEstimator.estimate(new DummyClass4(null)) === 24)
- assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48)
+ expectResult(16)(SizeEstimator.estimate(new DummyClass1))
+ expectResult(16)(SizeEstimator.estimate(new DummyClass2))
+ expectResult(24)(SizeEstimator.estimate(new DummyClass3))
+ expectResult(24)(SizeEstimator.estimate(new DummyClass4(null)))
+ expectResult(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3)))
}
// NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors
// (Sun vs IBM). Use a DummyString class to make tests deterministic.
test("strings") {
- assert(SizeEstimator.estimate(DummyString("")) === 40)
- assert(SizeEstimator.estimate(DummyString("a")) === 48)
- assert(SizeEstimator.estimate(DummyString("ab")) === 48)
- assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
+ expectResult(40)(SizeEstimator.estimate(DummyString("")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("a")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("ab")))
+ expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
}
test("primitive arrays") {
- assert(SizeEstimator.estimate(new Array[Byte](10)) === 32)
- assert(SizeEstimator.estimate(new Array[Char](10)) === 40)
- assert(SizeEstimator.estimate(new Array[Short](10)) === 40)
- assert(SizeEstimator.estimate(new Array[Int](10)) === 56)
- assert(SizeEstimator.estimate(new Array[Long](10)) === 96)
- assert(SizeEstimator.estimate(new Array[Float](10)) === 56)
- assert(SizeEstimator.estimate(new Array[Double](10)) === 96)
- assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016)
- assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016)
+ expectResult(32)(SizeEstimator.estimate(new Array[Byte](10)))
+ expectResult(40)(SizeEstimator.estimate(new Array[Char](10)))
+ expectResult(40)(SizeEstimator.estimate(new Array[Short](10)))
+ expectResult(56)(SizeEstimator.estimate(new Array[Int](10)))
+ expectResult(96)(SizeEstimator.estimate(new Array[Long](10)))
+ expectResult(56)(SizeEstimator.estimate(new Array[Float](10)))
+ expectResult(96)(SizeEstimator.estimate(new Array[Double](10)))
+ expectResult(4016)(SizeEstimator.estimate(new Array[Int](1000)))
+ expectResult(8016)(SizeEstimator.estimate(new Array[Long](1000)))
}
test("object arrays") {
// Arrays containing nulls should just have one pointer per element
- assert(SizeEstimator.estimate(new Array[String](10)) === 56)
- assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56)
-
+ expectResult(56)(SizeEstimator.estimate(new Array[String](10)))
+ expectResult(56)(SizeEstimator.estimate(new Array[AnyRef](10)))
// For object arrays with non-null elements, each object should take one pointer plus
// however many bytes that class takes. (Note that Array.fill calls the code in its
// second parameter separately for each object, so we get distinct objects.)
- assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216)
- assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216)
- assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296)
- assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56)
+ expectResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)))
+ expectResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)))
+ expectResult(296)(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)))
+ expectResult(56)(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)))
// Past size 100, our samples 100 elements, but we should still get the right size.
- assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016)
+ expectResult(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)))
// If an array contains the *same* element many times, we should only count it once.
val d1 = new DummyClass1
- assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object
- assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object
+ expectResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object
+ expectResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object
// Same thing with huge array containing the same element many times. Note that this won't
// return exactly 4032 because it can't tell that *all* the elements will equal the first
@@ -128,11 +127,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- assert(SizeEstimator.estimate(DummyString("")) === 40)
- assert(SizeEstimator.estimate(DummyString("a")) === 48)
- assert(SizeEstimator.estimate(DummyString("ab")) === 48)
- assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56)
-
+ expectResult(40)(SizeEstimator.estimate(DummyString("")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("a")))
+ expectResult(48)(SizeEstimator.estimate(DummyString("ab")))
+ expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
}
@@ -145,10 +143,10 @@ class SizeEstimatorSuite
val initialize = PrivateMethod[Unit]('initialize)
SizeEstimator invokePrivate initialize()
- assert(SizeEstimator.estimate(DummyString("")) === 56)
- assert(SizeEstimator.estimate(DummyString("a")) === 64)
- assert(SizeEstimator.estimate(DummyString("ab")) === 64)
- assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72)
+ expectResult(56)(SizeEstimator.estimate(DummyString("")))
+ expectResult(64)(SizeEstimator.estimate(DummyString("a")))
+ expectResult(64)(SizeEstimator.estimate(DummyString("ab")))
+ expectResult(72)(SizeEstimator.estimate(DummyString("abcdefgh")))
resetOrClear("os.arch", arch)
resetOrClear("spark.test.useCompressedOops", oops)
diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
new file mode 100644
index 0000000000..b78367b6ca
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.Random
+import org.scalatest.FlatSpec
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.Utils.times
+
+class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
+
+ def fixture = new {
+ val seed = 1L
+ val xorRand = new XORShiftRandom(seed)
+ val hundMil = 1e8.toInt
+ }
+
+ /*
+ * This test is based on a chi-squared test for randomness. The values are hard-coded
+ * so as not to create Spark's dependency on apache.commons.math3 just to call one
+ * method for calculating the exact p-value for a given number of random numbers
+ * and bins. In case one would want to move to a full-fledged test based on
+ * apache.commons.math3, the relevant class is here:
+ * org.apache.commons.math3.stat.inference.ChiSquareTest
+ */
+ test ("XORShift generates valid random numbers") {
+
+ val f = fixture
+
+ val numBins = 10
+ // create 10 bins
+ val bins = Array.fill(numBins)(0)
+
+ // populate bins based on modulus of the random number
+ times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1}
+
+ /* since the seed is deterministic, until the algorithm is changed, we know the result will be
+ * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+ * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%)
+ * significance level. However, should the RNG implementation change, the test should still
+ * pass at the same significance level. The chi-squared test done in R gave the following
+ * results:
+ * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+ * 10000790, 10002286, 9998699))
+ * Chi-squared test for given probabilities
+ * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790,
+ * 10002286, 9998699)
+ * X-squared = 11.975, df = 9, p-value = 0.2147
+ * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million
+ * random numbers
+ * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared
+ * is greater than or equal to that number.
+ */
+ val binSize = f.hundMil/numBins
+ val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum
+ xSquared should be < (16.9196)
+
+ }
+
+} \ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index ca3f684668..e9b62ea70d 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -1,9 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
-
-class OpenHashMapSuite extends FunSuite {
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
+
+class OpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive value (int)") {
+ val capacity = 1024
+ val map = new OpenHashMap[String, Int](capacity)
+ val actualSize = SizeEstimator.estimate(map)
+ // 64 bit for pointers, 32 bit for ints, and 1 bit for the bitset.
+ val expectedSize = capacity * (64 + 32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
test("initialization") {
val goodMap1 = new OpenHashMap[String, Int](1)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
index 4e11e8a628..1b24f8f287 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -1,9 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.spark.util.collection
import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+
+import org.apache.spark.util.SizeEstimator
+
+class OpenHashSetSuite extends FunSuite with ShouldMatchers {
-class OpenHashSetSuite extends FunSuite {
+ test("size for specialized, primitive int") {
+ val loadFactor = 0.7
+ val set = new OpenHashSet[Int](64, loadFactor)
+ for (i <- 0 until 1024) {
+ set.add(i)
+ }
+ assert(set.size === 1024)
+ assert(set.capacity > 1024)
+ val actualSize = SizeEstimator.estimate(set)
+ // 32 bits for the ints + 1 bit for the bitset
+ val expectedSize = set.capacity * (32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
test("primitive int") {
val set = new OpenHashSet[Int]
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
index dfd6aed2c4..3b60decee9 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -1,9 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
+
+class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers {
-class PrimitiveKeyOpenHashSetSuite extends FunSuite {
+ test("size for specialized, primitive key, value (int, int)") {
+ val capacity = 1024
+ val map = new PrimitiveKeyOpenHashMap[Int, Int](capacity)
+ val actualSize = SizeEstimator.estimate(map)
+ // 32 bit for keys, 32 bit for values, and 1 bit for the bitset.
+ val expectedSize = capacity * (32 + 32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
test("initialization") {
val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1)