diff options
author | Tathagata Das <tathagata.das1565@gmail.com> | 2013-12-19 12:05:08 -0800 |
---|---|---|
committer | Tathagata Das <tathagata.das1565@gmail.com> | 2013-12-19 12:05:08 -0800 |
commit | de41c436a0088efc83bc4705dcd279e61b085759 (patch) | |
tree | d725da59145e243c1b1934567dbb61c7ddb2abcb /core | |
parent | 03ef6e889929befa968d35fa3757c687edc3a38b (diff) | |
parent | ec71b445ad0440e84c4b4909e4faf75aba0f13d7 (diff) | |
download | spark-de41c436a0088efc83bc4705dcd279e61b085759.tar.gz spark-de41c436a0088efc83bc4705dcd279e61b085759.tar.bz2 spark-de41c436a0088efc83bc4705dcd279e61b085759.zip |
Merge branch 'scheduler-update' into window-improvement
Conflicts:
streaming/src/main/scala/org/apache/spark/streaming/dstream/WindowedDStream.scala
Diffstat (limited to 'core')
149 files changed, 2721 insertions, 1263 deletions
diff --git a/core/pom.xml b/core/pom.xml index 8621d257e5..043f6cf68d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -26,7 +26,7 @@ </parent> <groupId>org.apache.spark</groupId> - <artifactId>spark-core_2.9.3</artifactId> + <artifactId>spark-core_2.10</artifactId> <packaging>jar</packaging> <name>Spark Project Core</name> <url>http://spark.incubator.apache.org/</url> @@ -81,12 +81,8 @@ <artifactId>asm</artifactId> </dependency> <dependency> - <groupId>com.google.protobuf</groupId> - <artifactId>protobuf-java</artifactId> - </dependency> - <dependency> <groupId>com.twitter</groupId> - <artifactId>chill_2.9.3</artifactId> + <artifactId>chill_${scala.binary.version}</artifactId> <version>0.3.1</version> </dependency> <dependency> @@ -95,20 +91,12 @@ <version>0.3.1</version> </dependency> <dependency> - <groupId>com.typesafe.akka</groupId> - <artifactId>akka-actor</artifactId> - </dependency> - <dependency> - <groupId>com.typesafe.akka</groupId> - <artifactId>akka-remote</artifactId> + <groupId>${akka.group}</groupId> + <artifactId>akka-remote_${scala.binary.version}</artifactId> </dependency> <dependency> - <groupId>com.typesafe.akka</groupId> - <artifactId>akka-slf4j</artifactId> - </dependency> - <dependency> - <groupId>org.scala-lang</groupId> - <artifactId>scalap</artifactId> + <groupId>${akka.group}</groupId> + <artifactId>akka-slf4j_${scala.binary.version}</artifactId> </dependency> <dependency> <groupId>org.scala-lang</groupId> @@ -116,7 +104,7 @@ </dependency> <dependency> <groupId>net.liftweb</groupId> - <artifactId>lift-json_2.9.2</artifactId> + <artifactId>lift-json_${scala.binary.version}</artifactId> </dependency> <dependency> <groupId>it.unimi.dsi</groupId> @@ -127,10 +115,6 @@ <artifactId>colt</artifactId> </dependency> <dependency> - <groupId>com.github.scala-incubator.io</groupId> - <artifactId>scala-io-file_2.9.2</artifactId> - </dependency> - <dependency> <groupId>org.apache.mesos</groupId> <artifactId>mesos</artifactId> </dependency> @@ -159,18 +143,27 @@ <artifactId>metrics-ganglia</artifactId> </dependency> <dependency> + <groupId>com.codahale.metrics</groupId> + <artifactId>metrics-graphite</artifactId> + </dependency> + <dependency> <groupId>org.apache.derby</groupId> <artifactId>derby</artifactId> <scope>test</scope> </dependency> <dependency> + <groupId>commons-io</groupId> + <artifactId>commons-io</artifactId> + <scope>test</scope> + </dependency> + <dependency> <groupId>org.scalatest</groupId> - <artifactId>scalatest_2.9.3</artifactId> + <artifactId>scalatest_${scala.binary.version}</artifactId> <scope>test</scope> </dependency> <dependency> <groupId>org.scalacheck</groupId> - <artifactId>scalacheck_2.9.3</artifactId> + <artifactId>scalacheck_${scala.binary.version}</artifactId> <scope>test</scope> </dependency> <dependency> @@ -190,8 +183,8 @@ </dependency> </dependencies> <build> - <outputDirectory>target/scala-${scala.version}/classes</outputDirectory> - <testOutputDirectory>target/scala-${scala.version}/test-classes</testOutputDirectory> + <outputDirectory>target/scala-${scala.binary.version}/classes</outputDirectory> + <testOutputDirectory>target/scala-${scala.binary.version}/test-classes</testOutputDirectory> <plugins> <plugin> <groupId>org.apache.maven.plugins</groupId> diff --git a/core/src/main/java/org/apache/spark/network/netty/FileClient.java b/core/src/main/java/org/apache/spark/network/netty/FileClient.java index 20a7a3aa8c..edd0fc56f8 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileClient.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileClient.java @@ -19,8 +19,6 @@ package org.apache.spark.network.netty; import io.netty.bootstrap.Bootstrap; import io.netty.channel.Channel; -import io.netty.channel.ChannelFuture; -import io.netty.channel.ChannelFutureListener; import io.netty.channel.ChannelOption; import io.netty.channel.oio.OioEventLoopGroup; import io.netty.channel.socket.oio.OioSocketChannel; diff --git a/core/src/main/java/org/apache/spark/network/netty/FileServer.java b/core/src/main/java/org/apache/spark/network/netty/FileServer.java index 666432474d..a99af348ce 100644 --- a/core/src/main/java/org/apache/spark/network/netty/FileServer.java +++ b/core/src/main/java/org/apache/spark/network/netty/FileServer.java @@ -20,7 +20,6 @@ package org.apache.spark.network.netty; import java.net.InetSocketAddress; import io.netty.bootstrap.ServerBootstrap; -import io.netty.channel.Channel; import io.netty.channel.ChannelFuture; import io.netty.channel.ChannelOption; import io.netty.channel.oio.OioEventLoopGroup; diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 1ad9240cfa..c6b4ac5192 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -99,7 +99,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { if (!atMost.isFinite()) { awaitResult() - } else { + } else jobWaiter.synchronized { val finishTime = System.currentTimeMillis() + atMost.toMillis while (!isCompleted) { val time = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 5e465fa22c..10fae5af9f 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -21,12 +21,11 @@ import java.io._ import java.util.zip.{GZIPInputStream, GZIPOutputStream} import scala.collection.mutable.HashSet +import scala.concurrent.Await +import scala.concurrent.duration._ import akka.actor._ -import akka.dispatch._ import akka.pattern.ask -import akka.util.Duration - import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage.BlockManagerId @@ -55,9 +54,9 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster private[spark] class MapOutputTracker extends Logging { private val timeout = Duration.create(System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") - + // Set to the MapOutputTrackerActor living on the driver - var trackerActor: ActorRef = _ + var trackerActor: Either[ActorRef, ActorSelection] = _ protected val mapStatuses = new TimeStampedHashMap[Int, Array[MapStatus]] @@ -73,8 +72,18 @@ private[spark] class MapOutputTracker extends Logging { // throw a SparkException if this fails. private def askTracker(message: Any): Any = { try { - val future = trackerActor.ask(message)(timeout) - return Await.result(future, timeout) + /* + The difference between ActorRef and ActorSelection is well explained here: + http://doc.akka.io/docs/akka/2.2.3/project/migration-guide-2.1.x-2.2.x.html#Use_actorSelection_instead_of_actorFor + In spark a map output tracker can be either started on Driver where it is created which + is an ActorRef or it can be on executor from where it is looked up which is an + actorSelection. + */ + val future = trackerActor match { + case Left(a: ActorRef) => a.ask(message)(timeout) + case Right(b: ActorSelection) => b.ask(message)(timeout) + } + Await.result(future, timeout) } catch { case e: Exception => throw new SparkException("Error communicating with MapOutputTracker", e) @@ -117,7 +126,7 @@ private[spark] class MapOutputTracker extends Logging { fetching += shuffleId } } - + if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) @@ -144,7 +153,7 @@ private[spark] class MapOutputTracker extends Logging { else{ throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing all output locations for shuffle " + shuffleId)) - } + } } else { statuses.synchronized { return MapOutputTracker.convertMapStatuses(shuffleId, reduceId, statuses) @@ -244,12 +253,12 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker { case Some(bytes) => return bytes case None => - statuses = mapStatuses(shuffleId) + statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) epochGotten = epoch } } // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "locs"; let's serialize and return that + // out a snapshot of the locations as "statuses"; let's serialize and return that val bytes = MapOutputTracker.serializeMapStatuses(statuses) logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) // Add them into the table only if the epoch hasn't changed while we were working @@ -274,6 +283,10 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker { override def updateEpoch(newEpoch: Long) { // This might be called on the MapOutputTrackerMaster if we're running in local mode. } + + def has(shuffleId: Int): Boolean = { + cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId) + } } private[spark] object MapOutputTracker { @@ -308,7 +321,7 @@ private[spark] object MapOutputTracker { statuses: Array[MapStatus]): Array[(BlockManagerId, Long)] = { assert (statuses != null) statuses.map { - status => + status => if (status == null) { throw new FetchFailedException(null, shuffleId, -1, reduceId, new Exception("Missing an output location for shuffle " + shuffleId)) diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 0e2c987a59..bcec41c439 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -17,8 +17,10 @@ package org.apache.spark -import org.apache.spark.util.Utils +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils /** * An object that defines how the elements in a key-value pair RDD are partitioned by key. @@ -72,7 +74,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { case null => 0 case _ => Utils.nonNegativeMod(key.hashCode, numPartitions) } - + override def equals(other: Any): Boolean = other match { case h: HashPartitioner => h.numPartitions == numPartitions @@ -85,7 +87,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { * A [[org.apache.spark.Partitioner]] that partitions sortable records by range into roughly equal ranges. * Determines the ranges by sampling the RDD passed in. */ -class RangePartitioner[K <% Ordered[K]: ClassManifest, V]( +class RangePartitioner[K <% Ordered[K]: ClassTag, V]( partitions: Int, @transient rdd: RDD[_ <: Product2[K,V]], private val ascending: Boolean = true) diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 42b2985b50..a0f794edfd 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -26,6 +26,7 @@ import scala.collection.Map import scala.collection.generic.Growable import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap +import scala.reflect.{ClassTag, classTag} import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.Path @@ -81,7 +82,7 @@ class SparkContext( val sparkHome: String = null, val jars: Seq[String] = Nil, val environment: Map[String, String] = Map(), - // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) + // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc) // too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set // of data-local splits on host val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = @@ -153,98 +154,11 @@ class SparkContext( executorEnvs("SPARK_USER") = sparkUser // Create and start the scheduler - private[spark] var taskScheduler: TaskScheduler = { - // Regular expression used for local[N] master format - val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r - // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r - // Regular expression for simulating a Spark cluster of [N, cores, memory] locally - val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r - // Regular expression for connecting to Spark deploy clusters - val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster - val MESOS_REGEX = """mesos://(.*)""".r - // Regular expression for connection to Simr cluster - val SIMR_REGEX = """simr://(.*)""".r - - master match { - case "local" => - new LocalScheduler(1, 0, this) - - case LOCAL_N_REGEX(threads) => - new LocalScheduler(threads.toInt, 0, this) - - case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - new LocalScheduler(threads.toInt, maxFailures.toInt, this) - - case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(this) - val masterUrls = sparkUrl.split(",").map("spark://" + _) - val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) - scheduler.initialize(backend) - scheduler - - case SIMR_REGEX(simrUrl) => - val scheduler = new ClusterScheduler(this) - val backend = new SimrSchedulerBackend(scheduler, this, simrUrl) - scheduler.initialize(backend) - scheduler - - case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. - val memoryPerSlaveInt = memoryPerSlave.toInt - if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { - throw new SparkException( - "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( - memoryPerSlaveInt, SparkContext.executorMemoryRequested)) - } - - val scheduler = new ClusterScheduler(this) - val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) - val masterUrls = localCluster.start() - val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) - scheduler.initialize(backend) - backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { - localCluster.stop() - } - scheduler - - case "yarn-standalone" => - val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") - val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(this).asInstanceOf[ClusterScheduler] - } catch { - // TODO: Enumerate the exact reasons why it can fail - // But irrespective of it, it means we cannot proceed ! - case th: Throwable => { - throw new SparkException("YARN mode not available ?", th) - } - } - val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem) - scheduler.initialize(backend) - scheduler - - case MESOS_REGEX(mesosUrl) => - MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(this) - val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean - val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName) - } else { - new MesosSchedulerBackend(scheduler, this, mesosUrl, appName) - } - scheduler.initialize(backend) - scheduler - - case _ => - throw new SparkException("Could not parse Master URL: '" + master + "'") - } - } + private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master, appName) taskScheduler.start() @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) + dagScheduler.start() ui.start() @@ -354,19 +268,19 @@ class SparkContext( // Methods for creating RDDs /** Distribute a local Scala collection to form an RDD. */ - def parallelize[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]()) } /** Distribute a local Scala collection to form an RDD. */ - def makeRDD[T: ClassManifest](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { + def makeRDD[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = { parallelize(seq, numSlices) } /** Distribute a local Scala collection to form an RDD, with one or more * location preferences (hostnames of Spark nodes) for each object. * Create a new partition for each collection item. */ - def makeRDD[T: ClassManifest](seq: Seq[(T, Seq[String])]): RDD[T] = { + def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = { val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs) } @@ -419,7 +333,7 @@ class SparkContext( } /** - * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, + * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys, * values and the InputFormat so that users don't need to pass them directly. Instead, callers * can just write, for example, * {{{ @@ -427,17 +341,17 @@ class SparkContext( * }}} */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String, minSplits: Int) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]) + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]) : RDD[(K, V)] = { hadoopFile(path, - fm.erasure.asInstanceOf[Class[F]], - km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]], + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]], minSplits) } /** - * Smarter version of hadoopFile() that uses class manifests to figure out the classes of keys, + * Smarter version of hadoopFile() that uses class tags to figure out the classes of keys, * values and the InputFormat so that users don't need to pass them directly. Instead, callers * can just write, for example, * {{{ @@ -445,17 +359,17 @@ class SparkContext( * }}} */ def hadoopFile[K, V, F <: InputFormat[K, V]](path: String) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = hadoopFile[K, V, F](path, defaultMinSplits) /** Get an RDD for a Hadoop file with an arbitrary new API InputFormat. */ def newAPIHadoopFile[K, V, F <: NewInputFormat[K, V]](path: String) - (implicit km: ClassManifest[K], vm: ClassManifest[V], fm: ClassManifest[F]): RDD[(K, V)] = { + (implicit km: ClassTag[K], vm: ClassTag[V], fm: ClassTag[F]): RDD[(K, V)] = { newAPIHadoopFile( path, - fm.erasure.asInstanceOf[Class[F]], - km.erasure.asInstanceOf[Class[K]], - vm.erasure.asInstanceOf[Class[V]]) + fm.runtimeClass.asInstanceOf[Class[F]], + km.runtimeClass.asInstanceOf[Class[K]], + vm.runtimeClass.asInstanceOf[Class[V]]) } /** @@ -513,11 +427,11 @@ class SparkContext( * IntWritable). The most natural thing would've been to have implicit objects for the * converters, but then we couldn't have an object for every subclass of Writable (you can't * have a parameterized singleton object). We use functions instead to create a new converter - * for the appropriate type. In addition, we pass the converter a ClassManifest of its type to + * for the appropriate type. In addition, we pass the converter a ClassTag of its type to * allow it to figure out the Writable class to use in the subclass case. */ def sequenceFile[K, V](path: String, minSplits: Int = defaultMinSplits) - (implicit km: ClassManifest[K], vm: ClassManifest[V], + (implicit km: ClassTag[K], vm: ClassTag[V], kcf: () => WritableConverter[K], vcf: () => WritableConverter[V]) : RDD[(K, V)] = { val kc = kcf() @@ -536,7 +450,7 @@ class SparkContext( * slow if you use the default serializer (Java serialization), though the nice thing about it is * that there's very little effort required to save arbitrary objects. */ - def objectFile[T: ClassManifest]( + def objectFile[T: ClassTag]( path: String, minSplits: Int = defaultMinSplits ): RDD[T] = { @@ -545,17 +459,17 @@ class SparkContext( } - protected[spark] def checkpointFile[T: ClassManifest]( + protected[spark] def checkpointFile[T: ClassTag]( path: String ): RDD[T] = { new CheckpointRDD[T](this, path) } /** Build the union of a list of RDDs. */ - def union[T: ClassManifest](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) + def union[T: ClassTag](rdds: Seq[RDD[T]]): RDD[T] = new UnionRDD(this, rdds) /** Build the union of a list of RDDs passed as variable-length arguments. */ - def union[T: ClassManifest](first: RDD[T], rest: RDD[T]*): RDD[T] = + def union[T: ClassTag](first: RDD[T], rest: RDD[T]*): RDD[T] = new UnionRDD(this, Seq(first) ++ rest) // Methods for creating shared variables @@ -798,7 +712,7 @@ class SparkContext( * flag specifies whether the scheduler can run the computation on the driver rather than * shipping it out to the cluster, for short actions like first(). */ - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], @@ -819,7 +733,7 @@ class SparkContext( * allowLocal flag specifies whether the scheduler can run the computation on the driver rather * than shipping it out to the cluster, for short actions like first(). */ - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], @@ -834,7 +748,7 @@ class SparkContext( * Run a job on a given set of partitions of an RDD, but take a function of type * `Iterator[T] => U` instead of `(TaskContext, Iterator[T]) => U`. */ - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], func: Iterator[T] => U, partitions: Seq[Int], @@ -846,21 +760,21 @@ class SparkContext( /** * Run a job on all partitions in an RDD and return the results in an array. */ - def runJob[T, U: ClassManifest](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { + def runJob[T, U: ClassTag](rdd: RDD[T], func: (TaskContext, Iterator[T]) => U): Array[U] = { runJob(rdd, func, 0 until rdd.partitions.size, false) } /** * Run a job on all partitions in an RDD and return the results in an array. */ - def runJob[T, U: ClassManifest](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { + def runJob[T, U: ClassTag](rdd: RDD[T], func: Iterator[T] => U): Array[U] = { runJob(rdd, func, 0 until rdd.partitions.size, false) } /** * Run a job on all partitions in an RDD and pass the results to a handler function. */ - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], processPartition: (TaskContext, Iterator[T]) => U, resultHandler: (Int, U) => Unit) @@ -871,7 +785,7 @@ class SparkContext( /** * Run a job on all partitions in an RDD and pass the results to a handler function. */ - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], processPartition: Iterator[T] => U, resultHandler: (Int, U) => Unit) @@ -1017,16 +931,16 @@ object SparkContext { // TODO: Add AccumulatorParams for other types, e.g. lists and strings - implicit def rddToPairRDDFunctions[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]) = + implicit def rddToPairRDDFunctions[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]) = new PairRDDFunctions(rdd) - implicit def rddToAsyncRDDActions[T: ClassManifest](rdd: RDD[T]) = new AsyncRDDActions(rdd) + implicit def rddToAsyncRDDActions[T: ClassTag](rdd: RDD[T]) = new AsyncRDDActions(rdd) - implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable: ClassManifest]( + implicit def rddToSequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable: ClassTag]( rdd: RDD[(K, V)]) = new SequenceFileRDDFunctions(rdd) - implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassManifest, V: ClassManifest]( + implicit def rddToOrderedRDDFunctions[K <% Ordered[K]: ClassTag, V: ClassTag]( rdd: RDD[(K, V)]) = new OrderedRDDFunctions[K, V, (K, V)](rdd) @@ -1051,16 +965,16 @@ object SparkContext { implicit def stringToText(s: String) = new Text(s) - private implicit def arrayToArrayWritable[T <% Writable: ClassManifest](arr: Traversable[T]): ArrayWritable = { + private implicit def arrayToArrayWritable[T <% Writable: ClassTag](arr: Traversable[T]): ArrayWritable = { def anyToWritable[U <% Writable](u: U): Writable = u - new ArrayWritable(classManifest[T].erasure.asInstanceOf[Class[Writable]], + new ArrayWritable(classTag[T].runtimeClass.asInstanceOf[Class[Writable]], arr.map(x => anyToWritable(x)).toArray) } // Helper objects for converting common types to Writable - private def simpleWritableConverter[T, W <: Writable: ClassManifest](convert: W => T) = { - val wClass = classManifest[W].erasure.asInstanceOf[Class[W]] + private def simpleWritableConverter[T, W <: Writable: ClassTag](convert: W => T) = { + val wClass = classTag[W].runtimeClass.asInstanceOf[Class[W]] new WritableConverter[T](_ => wClass, x => convert(x.asInstanceOf[W])) } @@ -1079,7 +993,7 @@ object SparkContext { implicit def stringWritableConverter() = simpleWritableConverter[String, Text](_.toString) implicit def writableWritableConverter[T <: Writable]() = - new WritableConverter[T](_.erasure.asInstanceOf[Class[T]], _.asInstanceOf[T]) + new WritableConverter[T](_.runtimeClass.asInstanceOf[Class[T]], _.asInstanceOf[T]) /** * Find the JAR from which a given class was loaded, to make it easy for users to pass @@ -1111,17 +1025,135 @@ object SparkContext { .map(Utils.memoryStringToMb) .getOrElse(512) } + + // Creates a task scheduler based on a given master URL. Extracted for testing. + private + def createTaskScheduler(sc: SparkContext, master: String, appName: String): TaskScheduler = { + // Regular expression used for local[N] master format + val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r + // Regular expression for local[N, maxRetries], used in tests with failing tasks + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + // Regular expression for simulating a Spark cluster of [N, cores, memory] locally + val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r + // Regular expression for connecting to Spark deploy clusters + val SPARK_REGEX = """spark://(.*)""".r + // Regular expression for connection to Mesos cluster by mesos:// or zk:// url + val MESOS_REGEX = """(mesos|zk)://.*""".r + // Regular expression for connection to Simr cluster + val SIMR_REGEX = """simr://(.*)""".r + + master match { + case "local" => + new LocalScheduler(1, 0, sc) + + case LOCAL_N_REGEX(threads) => + new LocalScheduler(threads.toInt, 0, sc) + + case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + new LocalScheduler(threads.toInt, maxFailures.toInt, sc) + + case SPARK_REGEX(sparkUrl) => + val scheduler = new ClusterScheduler(sc) + val masterUrls = sparkUrl.split(",").map("spark://" + _) + val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) + scheduler.initialize(backend) + scheduler + + case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => + // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. + val memoryPerSlaveInt = memoryPerSlave.toInt + if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { + throw new SparkException( + "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( + memoryPerSlaveInt, SparkContext.executorMemoryRequested)) + } + + val scheduler = new ClusterScheduler(sc) + val localCluster = new LocalSparkCluster( + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + val masterUrls = localCluster.start() + val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) + scheduler.initialize(backend) + backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { + localCluster.stop() + } + scheduler + + case "yarn-standalone" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(sc).asInstanceOf[ClusterScheduler] + } catch { + // TODO: Enumerate the exact reasons why it can fail + // But irrespective of it, it means we cannot proceed ! + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + val backend = new CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + scheduler.initialize(backend) + scheduler + + case "yarn-client" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(sc).asInstanceOf[ClusterScheduler] + + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + val backend = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext]) + cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + scheduler.initialize(backend) + scheduler + + case mesosUrl @ MESOS_REGEX(_) => + MesosNativeLibrary.load() + val scheduler = new ClusterScheduler(sc) + val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean + val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs + val backend = if (coarseGrained) { + new CoarseMesosSchedulerBackend(scheduler, sc, url, appName) + } else { + new MesosSchedulerBackend(scheduler, sc, url, appName) + } + scheduler.initialize(backend) + scheduler + + case SIMR_REGEX(simrUrl) => + val scheduler = new ClusterScheduler(sc) + val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl) + scheduler.initialize(backend) + scheduler + + case _ => + throw new SparkException("Could not parse Master URL: '" + master + "'") + } + } } /** * A class encapsulating how to convert some type T to Writable. It stores both the Writable class * corresponding to T (e.g. IntWritable for Int) and a function for doing the conversion. - * The getter for the writable class takes a ClassManifest[T] in case this is a generic object + * The getter for the writable class takes a ClassTag[T] in case this is a generic object * that doesn't know the type of T when it is created. This sounds strange but is necessary to * support converting subclasses of Writable to themselves (writableWritableConverter). */ private[spark] class WritableConverter[T]( - val writableClass: ClassManifest[T] => Class[_ <: Writable], + val writableClass: ClassTag[T] => Class[_ <: Writable], val convert: Writable => T) extends Serializable diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index ff2df8fb6a..826f5c2d8c 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -20,7 +20,7 @@ package org.apache.spark import collection.mutable import serializer.Serializer -import akka.actor.{Actor, ActorRef, Props, ActorSystemImpl, ActorSystem} +import akka.actor._ import akka.remote.RemoteActorRefProvider import org.apache.spark.broadcast.BroadcastManager @@ -74,7 +74,8 @@ class SparkEnv ( actorSystem.shutdown() // Unfortunately Akka's awaitTermination doesn't actually wait for the Netty server to shut // down, but let's call it anyway in case it gets fixed in a later release - actorSystem.awaitTermination() + // UPDATE: In Akka 2.1.x, this hangs if there are remote actors, so we can't call it. + //actorSystem.awaitTermination() } def createPythonWorker(pythonExec: String, envVars: Map[String, String]): java.net.Socket = { @@ -151,17 +152,17 @@ object SparkEnv extends Logging { val closureSerializer = serializerManager.get( System.getProperty("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")) - def registerOrLookup(name: String, newActor: => Actor): ActorRef = { + def registerOrLookup(name: String, newActor: => Actor): Either[ActorRef, ActorSelection] = { if (isDriver) { logInfo("Registering " + name) - actorSystem.actorOf(Props(newActor), name = name) + Left(actorSystem.actorOf(Props(newActor), name = name)) } else { val driverHost: String = System.getProperty("spark.driver.host", "localhost") val driverPort: Int = System.getProperty("spark.driver.port", "7077").toInt Utils.checkHost(driverHost, "Expected hostname") - val url = "akka://spark@%s:%s/user/%s".format(driverHost, driverPort, name) + val url = "akka.tcp://spark@%s:%s/user/%s".format(driverHost, driverPort, name) logInfo("Connecting to " + name + ": " + url) - actorSystem.actorFor(url) + Right(actorSystem.actorSelection(url)) } } diff --git a/core/src/main/scala/org/apache/spark/TaskState.scala b/core/src/main/scala/org/apache/spark/TaskState.scala index 19ce8369d9..0bf1e4a5e2 100644 --- a/core/src/main/scala/org/apache/spark/TaskState.scala +++ b/core/src/main/scala/org/apache/spark/TaskState.scala @@ -19,8 +19,7 @@ package org.apache.spark import org.apache.mesos.Protos.{TaskState => MesosTaskState} -private[spark] object TaskState - extends Enumeration("LAUNCHING", "RUNNING", "FINISHED", "FAILED", "KILLED", "LOST") { +private[spark] object TaskState extends Enumeration { val LAUNCHING, RUNNING, FINISHED, FAILED, KILLED, LOST = Value diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 043cb183ba..da30cf619a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -17,18 +17,23 @@ package org.apache.spark.api.java +import scala.reflect.ClassTag + import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext.doubleRDDToDoubleRDDFunctions import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.util.StatCounter import org.apache.spark.partial.{BoundedDouble, PartialResult} import org.apache.spark.storage.StorageLevel + import java.lang.Double import org.apache.spark.Partitioner +import scala.collection.JavaConverters._ + class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] { - override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]] + override val classTag: ClassTag[Double] = implicitly[ClassTag[Double]] override val rdd: RDD[Double] = srdd.map(x => Double.valueOf(x)) @@ -42,7 +47,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): JavaDoubleRDD = fromRDD(srdd.cache()) - /** + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -106,7 +111,7 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav /** * Return an RDD with the elements from `this` that are not in `other`. - * + * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ @@ -182,6 +187,44 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav /** (Experimental) Approximate operation to return the sum within a timeout. */ def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) + + /** + * Compute a histogram of the data using bucketCount number of buckets evenly + * spaced between the minimum and maximum of the RDD. For example if the min + * value is 0 and the max is 100 and there are two buckets the resulting + * buckets will be [0,50) [50,100]. bucketCount must be at least 1 + * If the RDD contains infinity, NaN throws an exception + * If the elements in RDD do not vary (max == min) always returns a single bucket. + */ + def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = { + val result = srdd.histogram(bucketCount) + (result._1, result._2) + } + + /** + * Compute a histogram using the provided buckets. The buckets are all open + * to the left except for the last which is closed + * e.g. for the array + * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * + * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * to true. + * buckets must be sorted and not contain any duplicates. + * buckets array must be at least two elements + * All NaN entries are treated the same. If you have a NaN bucket it must be + * the maximum value of the last position and all NaN entries will be counted + * in that bucket. + */ + def histogram(buckets: Array[scala.Double]): Array[Long] = { + srdd.histogram(buckets, false) + } + + def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = { + srdd.histogram(buckets.map(_.toDouble), evenBuckets) + } } object JavaDoubleRDD { diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala index 2142fd7327..363667fa86 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaPairRDD.scala @@ -22,6 +22,7 @@ import java.util.Comparator import scala.Tuple2 import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec @@ -43,13 +44,13 @@ import org.apache.spark.rdd.OrderedRDDFunctions import org.apache.spark.storage.StorageLevel -class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManifest[K], - implicit val vManifest: ClassManifest[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { +class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kClassTag: ClassTag[K], + implicit val vClassTag: ClassTag[V]) extends JavaRDDLike[(K, V), JavaPairRDD[K, V]] { override def wrapRDD(rdd: RDD[(K, V)]): JavaPairRDD[K, V] = JavaPairRDD.fromRDD(rdd) - override val classManifest: ClassManifest[(K, V)] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K, V]]] + override val classTag: ClassTag[(K, V)] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K, V]]] import JavaPairRDD._ @@ -58,7 +59,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif /** Persist this RDD with the default storage level (`MEMORY_ONLY`). */ def cache(): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd.cache()) - /** + /** * Set this RDD's storage level to persist its values across operations after the first time * it is computed. Can only be called once on each RDD. */ @@ -138,14 +139,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif override def first(): (K, V) = rdd.first() // Pair RDD functions - + /** - * Generic function to combine the elements for each key using a custom set of aggregation - * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a - * "combined type" C * Note that V and C can be different -- for example, one might group an - * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three + * Generic function to combine the elements for each key using a custom set of aggregation + * functions. Turns a JavaPairRDD[(K, V)] into a result of type JavaPairRDD[(K, C)], for a + * "combined type" C * Note that V and C can be different -- for example, one might group an + * RDD of type (Int, Int) into an RDD of type (Int, List[Int]). Users provide three * functions: - * + * * - `createCombiner`, which turns a V into a C (e.g., creates a one-element list) * - `mergeValue`, to merge a V into a C (e.g., adds it to the end of a list) * - `mergeCombiners`, to combine two C's into a single one. @@ -157,8 +158,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C], partitioner: Partitioner): JavaPairRDD[K, C] = { - implicit val cm: ClassManifest[C] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] + implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]] fromRDD(rdd.combineByKey( createCombiner, mergeValue, @@ -195,14 +195,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif /** Count the number of elements for each key, and return the result to the master as a Map. */ def countByKey(): java.util.Map[K, Long] = mapAsJavaMap(rdd.countByKey()) - /** + /** * (Experimental) Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ def countByKeyApprox(timeout: Long): PartialResult[java.util.Map[K, BoundedDouble]] = rdd.countByKeyApprox(timeout).map(mapAsJavaMap) - /** + /** * (Experimental) Approximate version of countByKey that can return a partial result if it does * not finish within a timeout. */ @@ -258,7 +258,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif /** * Return an RDD with the elements from `this` that are not in `other`. - * + * * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ @@ -315,15 +315,14 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif fromRDD(joinResult.mapValues{case (v, w) => (JavaUtils.optionToOptional(v), w)}) } - /** + /** * Simplified version of combineByKey that hash-partitions the resulting RDD using the existing * partitioner/parallelism level. */ def combineByKey[C](createCombiner: JFunction[V, C], mergeValue: JFunction2[C, V, C], mergeCombiners: JFunction2[C, C, C]): JavaPairRDD[K, C] = { - implicit val cm: ClassManifest[C] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[C]] + implicit val cm: ClassTag[C] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[C]] fromRDD(combineByKey(createCombiner, mergeValue, mergeCombiners, defaultPartitioner(rdd))) } @@ -414,8 +413,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif * this also retains the original RDD's partitioning. */ def mapValues[U](f: JFunction[V, U]): JavaPairRDD[K, U] = { - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] fromRDD(rdd.mapValues(f)) } @@ -426,8 +424,7 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif def flatMapValues[U](f: JFunction[V, java.lang.Iterable[U]]): JavaPairRDD[K, U] = { import scala.collection.JavaConverters._ def fn = (x: V) => f.apply(x).asScala - implicit val cm: ClassManifest[U] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[U]] + implicit val cm: ClassTag[U] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[U]] fromRDD(rdd.flatMapValues(fn)) } @@ -592,6 +589,20 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif } /** + * Sort the RDD by key, so that each partition contains a sorted range of the elements. Calling + * `collect` or `save` on the resulting RDD will return or output an ordered list of records + * (in the `save` case, they will be written to multiple `part-X` files in the filesystem, in + * order of the keys). + */ + def sortByKey(comp: Comparator[K], ascending: Boolean, numPartitions: Int): JavaPairRDD[K, V] = { + class KeyOrdering(val a: K) extends Ordered[K] { + override def compare(b: K) = comp.compare(a, b) + } + implicit def toOrdered(x: K): Ordered[K] = new KeyOrdering(x) + fromRDD(new OrderedRDDFunctions[K, V, (K, V)](rdd).sortByKey(ascending, numPartitions)) + } + + /** * Return an RDD with the keys of each tuple. */ def keys(): JavaRDD[K] = JavaRDD.fromRDD[K](rdd.map(_._1)) @@ -603,22 +614,22 @@ class JavaPairRDD[K, V](val rdd: RDD[(K, V)])(implicit val kManifest: ClassManif } object JavaPairRDD { - def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassManifest[K], - vcm: ClassManifest[T]): RDD[(K, JList[T])] = + def groupByResultToJava[K, T](rdd: RDD[(K, Seq[T])])(implicit kcm: ClassTag[K], + vcm: ClassTag[T]): RDD[(K, JList[T])] = rddToPairRDDFunctions(rdd).mapValues(seqAsJavaList _) - def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassManifest[K], - vcm: ClassManifest[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd).mapValues((x: (Seq[V], - Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2))) + def cogroupResultToJava[W, K, V](rdd: RDD[(K, (Seq[V], Seq[W]))])(implicit kcm: ClassTag[K], + vcm: ClassTag[V]): RDD[(K, (JList[V], JList[W]))] = rddToPairRDDFunctions(rdd) + .mapValues((x: (Seq[V], Seq[W])) => (seqAsJavaList(x._1), seqAsJavaList(x._2))) def cogroupResult2ToJava[W1, W2, K, V](rdd: RDD[(K, (Seq[V], Seq[W1], - Seq[W2]))])(implicit kcm: ClassManifest[K]) : RDD[(K, (JList[V], JList[W1], + Seq[W2]))])(implicit kcm: ClassTag[K]) : RDD[(K, (JList[V], JList[W1], JList[W2]))] = rddToPairRDDFunctions(rdd).mapValues( (x: (Seq[V], Seq[W1], Seq[W2])) => (seqAsJavaList(x._1), seqAsJavaList(x._2), seqAsJavaList(x._3))) - def fromRDD[K: ClassManifest, V: ClassManifest](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = + def fromRDD[K: ClassTag, V: ClassTag](rdd: RDD[(K, V)]): JavaPairRDD[K, V] = new JavaPairRDD[K, V](rdd) implicit def toRDD[K, V](rdd: JavaPairRDD[K, V]): RDD[(K, V)] = rdd.rdd @@ -626,10 +637,8 @@ object JavaPairRDD { /** Convert a JavaRDD of key-value pairs to JavaPairRDD. */ def fromJavaRDD[K, V](rdd: JavaRDD[(K, V)]): JavaPairRDD[K, V] = { - implicit val cmk: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val cmv: ClassManifest[V] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + implicit val cmk: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val cmv: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] new JavaPairRDD[K, V](rdd.rdd) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala index 3b359a8fd6..c47657f512 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDD.scala @@ -17,12 +17,14 @@ package org.apache.spark.api.java +import scala.reflect.ClassTag + import org.apache.spark._ import org.apache.spark.rdd.RDD import org.apache.spark.api.java.function.{Function => JFunction} import org.apache.spark.storage.StorageLevel -class JavaRDD[T](val rdd: RDD[T])(implicit val classManifest: ClassManifest[T]) extends +class JavaRDD[T](val rdd: RDD[T])(implicit val classTag: ClassTag[T]) extends JavaRDDLike[T, JavaRDD[T]] { override def wrapRDD(rdd: RDD[T]): JavaRDD[T] = JavaRDD.fromRDD(rdd) @@ -127,8 +129,7 @@ JavaRDDLike[T, JavaRDD[T]] { object JavaRDD { - implicit def fromRDD[T: ClassManifest](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd) + implicit def fromRDD[T: ClassTag](rdd: RDD[T]): JavaRDD[T] = new JavaRDD[T](rdd) implicit def toRDD[T](rdd: JavaRDD[T]): RDD[T] = rdd.rdd } - diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala index 7a3568c5ef..9e912d3adb 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala @@ -20,6 +20,7 @@ package org.apache.spark.api.java import java.util.{List => JList, Comparator} import scala.Tuple2 import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import com.google.common.base.Optional import org.apache.hadoop.io.compress.CompressionCodec @@ -35,7 +36,7 @@ import org.apache.spark.storage.StorageLevel trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def wrapRDD(rdd: RDD[T]): This - implicit val classManifest: ClassManifest[T] + implicit val classTag: ClassTag[T] def rdd: RDD[T] @@ -71,7 +72,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. */ - def mapPartitionsWithIndex[R: ClassManifest]( + def mapPartitionsWithIndex[R: ClassTag]( f: JFunction2[Int, java.util.Iterator[T], java.util.Iterator[R]], preservesPartitioning: Boolean = false): JavaRDD[R] = new JavaRDD(rdd.mapPartitionsWithIndex(((a,b) => f(a,asJavaIterator(b))), @@ -87,7 +88,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Return a new RDD by applying a function to all elements of this RDD. */ def map[K2, V2](f: PairFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] + def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]] new JavaPairRDD(rdd.map(f)(cm))(f.keyType(), f.valueType()) } @@ -118,7 +119,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def flatMap[K2, V2](f: PairFlatMapFunction[T, K2, V2]): JavaPairRDD[K2, V2] = { import scala.collection.JavaConverters._ def fn = (x: T) => f.apply(x).asScala - def cm = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[Tuple2[K2, V2]]] + def cm = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[Tuple2[K2, V2]]] JavaPairRDD.fromRDD(rdd.flatMap(fn)(cm))(f.keyType(), f.valueType()) } @@ -158,18 +159,16 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * elements (a, b) where a is in `this` and b is in `other`. */ def cartesian[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = - JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classManifest))(classManifest, - other.classManifest) + JavaPairRDD.fromRDD(rdd.cartesian(other.rdd)(other.classTag))(classTag, other.classTag) /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JList[T]] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[JList[T]] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]] + implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val vcm: ClassTag[JList[T]] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]] JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(f.returnType)))(kcm, vcm) } @@ -178,10 +177,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * mapping to that key. */ def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JList[T]] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[JList[T]] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[JList[T]]] + implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val vcm: ClassTag[JList[T]] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[JList[T]]] JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(f.returnType)))(kcm, vcm) } @@ -209,7 +207,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * a map on the other). */ def zip[U](other: JavaRDDLike[U, _]): JavaPairRDD[T, U] = { - JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classManifest))(classManifest, other.classManifest) + JavaPairRDD.fromRDD(rdd.zip(other.rdd)(other.classTag))(classTag, other.classTag) } /** @@ -224,7 +222,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { def fn = (x: Iterator[T], y: Iterator[U]) => asScalaIterator( f.apply(asJavaIterator(x), asJavaIterator(y)).iterator()) JavaRDD.fromRDD( - rdd.zipPartitions(other.rdd)(fn)(other.classManifest, f.elementType()))(f.elementType()) + rdd.zipPartitions(other.rdd)(fn)(other.classTag, f.elementType()))(f.elementType()) } // Actions (launch a job to return a value to the user program) @@ -356,7 +354,7 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable { * Creates tuples of the elements in this RDD by applying `f`. */ def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = { - implicit val kcm: ClassManifest[K] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] + implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] JavaPairRDD.fromRDD(rdd.keyBy(f)) } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index 8869e072bf..acf328aa6a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -21,6 +21,7 @@ import java.util.{Map => JMap} import scala.collection.JavaConversions import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import org.apache.hadoop.conf.Configuration import org.apache.hadoop.mapred.InputFormat @@ -82,8 +83,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** Distribute a local Scala collection to form an RDD. */ def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices) } @@ -94,10 +94,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork /** Distribute a local Scala collection to form an RDD. */ def parallelizePairs[K, V](list: java.util.List[Tuple2[K, V]], numSlices: Int) : JavaPairRDD[K, V] = { - implicit val kcm: ClassManifest[K] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[K]] - implicit val vcm: ClassManifest[V] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] + implicit val kcm: ClassTag[K] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[K]] + implicit val vcm: ClassTag[V] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[V]] JavaPairRDD.fromRDD(sc.parallelize(JavaConversions.asScalaBuffer(list), numSlices)) } @@ -132,16 +130,16 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork valueClass: Class[V], minSplits: Int ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass, minSplits)) } /**Get an RDD for a Hadoop SequenceFile. */ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.sequenceFile(path, keyClass, valueClass)) } @@ -153,8 +151,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * that there's very little effort required to save arbitrary objects. */ def objectFile[T](path: String, minSplits: Int): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] sc.objectFile(path, minSplits)(cm) } @@ -166,8 +163,7 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork * that there's very little effort required to save arbitrary objects. */ def objectFile[T](path: String): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cm: ClassTag[T] = implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] sc.objectFile(path)(cm) } @@ -183,8 +179,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork valueClass: Class[V], minSplits: Int ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass, minSplits)) } @@ -199,8 +195,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork keyClass: Class[K], valueClass: Class[V] ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.hadoopRDD(conf, inputFormatClass, keyClass, valueClass)) } @@ -212,8 +208,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork valueClass: Class[V], minSplits: Int ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass, minSplits)) } @@ -224,8 +220,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork keyClass: Class[K], valueClass: Class[V] ): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(keyClass) - implicit val vcm = ClassManifest.fromClass(valueClass) + implicit val kcm: ClassTag[K] = ClassTag(keyClass) + implicit val vcm: ClassTag[V] = ClassTag(valueClass) new JavaPairRDD(sc.hadoopFile(path, inputFormatClass, keyClass, valueClass)) } @@ -240,8 +236,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork kClass: Class[K], vClass: Class[V], conf: Configuration): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(kClass) - implicit val vcm = ClassManifest.fromClass(vClass) + implicit val kcm: ClassTag[K] = ClassTag(kClass) + implicit val vcm: ClassTag[V] = ClassTag(vClass) new JavaPairRDD(sc.newAPIHadoopFile(path, fClass, kClass, vClass, conf)) } @@ -254,15 +250,15 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork fClass: Class[F], kClass: Class[K], vClass: Class[V]): JavaPairRDD[K, V] = { - implicit val kcm = ClassManifest.fromClass(kClass) - implicit val vcm = ClassManifest.fromClass(vClass) + implicit val kcm: ClassTag[K] = ClassTag(kClass) + implicit val vcm: ClassTag[V] = ClassTag(vClass) new JavaPairRDD(sc.newAPIHadoopRDD(conf, fClass, kClass, vClass)) } /** Build the union of two or more RDDs. */ override def union[T](first: JavaRDD[T], rest: java.util.List[JavaRDD[T]]): JavaRDD[T] = { val rdds: Seq[RDD[T]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) - implicit val cm: ClassManifest[T] = first.classManifest + implicit val cm: ClassTag[T] = first.classTag sc.union(rdds)(cm) } @@ -270,9 +266,9 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork override def union[K, V](first: JavaPairRDD[K, V], rest: java.util.List[JavaPairRDD[K, V]]) : JavaPairRDD[K, V] = { val rdds: Seq[RDD[(K, V)]] = (Seq(first) ++ asScalaBuffer(rest)).map(_.rdd) - implicit val cm: ClassManifest[(K, V)] = first.classManifest - implicit val kcm: ClassManifest[K] = first.kManifest - implicit val vcm: ClassManifest[V] = first.vManifest + implicit val cm: ClassTag[(K, V)] = first.classTag + implicit val kcm: ClassTag[K] = first.kClassTag + implicit val vcm: ClassTag[V] = first.vClassTag new JavaPairRDD(sc.union(rdds)(cm))(kcm, vcm) } @@ -405,8 +401,8 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork } protected def checkpointFile[T](path: String): JavaRDD[T] = { - implicit val cm: ClassManifest[T] = - implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[T]] + implicit val cm: ClassTag[T] = + implicitly[ClassTag[AnyRef]].asInstanceOf[ClassTag[T]] new JavaRDD(sc.checkpointFile(path)) } } diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java index c9cbce5624..2090efd3b9 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContextVarargsWorkaround.java @@ -17,7 +17,6 @@ package org.apache.spark.api.java; -import java.util.Arrays; import java.util.ArrayList; import java.util.List; diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala index 2dfda8b09a..bdb01f7670 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala +++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction.scala @@ -17,9 +17,11 @@ package org.apache.spark.api.java.function +import scala.reflect.ClassTag + /** * A function that returns zero or more output records from each input record. */ abstract class FlatMapFunction[T, R] extends Function[T, java.lang.Iterable[R]] { - def elementType() : ClassManifest[R] = ClassManifest.Any.asInstanceOf[ClassManifest[R]] + def elementType(): ClassTag[R] = ClassTag.Any.asInstanceOf[ClassTag[R]] } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala index 528e1c0a7c..aae1349c5e 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala +++ b/core/src/main/scala/org/apache/spark/api/java/function/FlatMapFunction2.scala @@ -17,9 +17,11 @@ package org.apache.spark.api.java.function +import scala.reflect.ClassTag + /** * A function that takes two inputs and returns zero or more output records. */ abstract class FlatMapFunction2[A, B, C] extends Function2[A, B, java.lang.Iterable[C]] { - def elementType() : ClassManifest[C] = ClassManifest.Any.asInstanceOf[ClassManifest[C]] + def elementType() : ClassTag[C] = ClassTag.Any.asInstanceOf[ClassTag[C]] } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function.java b/core/src/main/scala/org/apache/spark/api/java/function/Function.java index ce368ee01b..537439ef53 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function.java @@ -17,8 +17,8 @@ package org.apache.spark.api.java.function; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import java.io.Serializable; @@ -29,8 +29,8 @@ import java.io.Serializable; * when mapping RDDs of other types. */ public abstract class Function<T, R> extends WrappedFunction1<T, R> implements Serializable { - public ClassManifest<R> returnType() { - return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag<R> returnType() { + return ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java index 44ad559d48..a2d1214fb4 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function2.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function2.java @@ -17,8 +17,8 @@ package org.apache.spark.api.java.function; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import java.io.Serializable; @@ -28,8 +28,8 @@ import java.io.Serializable; public abstract class Function2<T1, T2, R> extends WrappedFunction2<T1, T2, R> implements Serializable { - public ClassManifest<R> returnType() { - return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag<R> returnType() { + return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java index ac6178924a..fb1deceab5 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/Function3.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/Function3.java @@ -17,8 +17,8 @@ package org.apache.spark.api.java.function; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import scala.runtime.AbstractFunction2; import java.io.Serializable; @@ -29,8 +29,8 @@ import java.io.Serializable; public abstract class Function3<T1, T2, T3, R> extends WrappedFunction3<T1, T2, T3, R> implements Serializable { - public ClassManifest<R> returnType() { - return (ClassManifest<R>) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag<R> returnType() { + return (ClassTag<R>) ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java index 6d76a8f970..ca485b3cc2 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFlatMapFunction.java @@ -18,8 +18,8 @@ package org.apache.spark.api.java.function; import scala.Tuple2; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import java.io.Serializable; @@ -33,11 +33,11 @@ public abstract class PairFlatMapFunction<T, K, V> extends WrappedFunction1<T, Iterable<Tuple2<K, V>>> implements Serializable { - public ClassManifest<K> keyType() { - return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag<K> keyType() { + return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class); } - public ClassManifest<V> valueType() { - return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag<V> valueType() { + return (ClassTag<V>) ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java index ede7ceefb5..cbe2306026 100644 --- a/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java +++ b/core/src/main/scala/org/apache/spark/api/java/function/PairFunction.java @@ -18,8 +18,8 @@ package org.apache.spark.api.java.function; import scala.Tuple2; -import scala.reflect.ClassManifest; -import scala.reflect.ClassManifest$; +import scala.reflect.ClassTag; +import scala.reflect.ClassTag$; import java.io.Serializable; @@ -31,11 +31,11 @@ import java.io.Serializable; public abstract class PairFunction<T, K, V> extends WrappedFunction1<T, Tuple2<K, V>> implements Serializable { - public ClassManifest<K> keyType() { - return (ClassManifest<K>) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag<K> keyType() { + return (ClassTag<K>) ClassTag$.MODULE$.apply(Object.class); } - public ClassManifest<V> valueType() { - return (ClassManifest<V>) ClassManifest$.MODULE$.fromClass(Object.class); + public ClassTag<V> valueType() { + return (ClassTag<V>) ClassTag$.MODULE$.apply(Object.class); } } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 12b4d94a56..a659cc06c2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -22,18 +22,17 @@ import java.net._ import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections} import scala.collection.JavaConversions._ +import scala.reflect.ClassTag import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PipedRDD import org.apache.spark.util.Utils - -private[spark] class PythonRDD[T: ClassManifest]( +private[spark] class PythonRDD[T: ClassTag]( parent: RDD[T], - command: Seq[String], + command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], preservePartitoning: Boolean, @@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest]( val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec, - broadcastVars, accumulator) - override def getPartitions = parent.partitions override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get @@ -71,11 +59,10 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) - val printOut = new PrintWriter(stream) // Partition index dataOut.writeInt(split.index) // sparkFilesDir - PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) + dataOut.writeUTF(SparkFiles.getRootDirectory) // Broadcast variables dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { @@ -85,21 +72,16 @@ private[spark] class PythonRDD[T: ClassManifest]( } // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.length) - for (f <- pythonIncludes) { - PythonRDD.writeAsPickle(f, dataOut) - } + pythonIncludes.foreach(dataOut.writeUTF) dataOut.flush() - // Serialized user code - for (elem <- command) { - printOut.println(elem) - } - printOut.flush() + // Serialized command: + dataOut.writeInt(command.length) + dataOut.write(command) // Data values for (elem <- parent.iterator(split, context)) { - PythonRDD.writeAsPickle(elem, dataOut) + PythonRDD.writeToStream(elem, dataOut) } dataOut.flush() - printOut.flush() worker.shutdownOutput() } catch { case e: IOException => @@ -132,7 +114,7 @@ private[spark] class PythonRDD[T: ClassManifest]( val obj = new Array[Byte](length) stream.readFully(obj) obj - case -3 => + case SpecialLengths.TIMING_DATA => // Timing data from worker val bootTime = stream.readLong() val initTime = stream.readLong() @@ -143,30 +125,30 @@ private[spark] class PythonRDD[T: ClassManifest]( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) read - case -2 => + case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) throw new PythonException(new String(obj)) - case -1 => + case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still - // read some accumulator updates; let's do that, breaking when we - // get a negative length record. - var len2 = stream.readInt() - while (len2 >= 0) { - val update = new Array[Byte](len2) + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) stream.readFully(update) accumulator += Collections.singletonList(update) - len2 = stream.readInt() + } - new Array[Byte](0) + Array.empty[Byte] } } catch { case eof: EOFException => { throw new SparkException("Python worker exited unexpectedly (crashed)", eof) } - case e => throw e + case e: Throwable => throw e } } @@ -197,62 +179,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } -private[spark] object PythonRDD { - - /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ - def stripPickle(arr: Array[Byte]) : Array[Byte] = { - arr.slice(2, arr.length - 1) - } +private object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 +} - /** - * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. - * The data format is a 32-bit integer representing the pickled object's length (in bytes), - * followed by the pickled data. - * - * Pickle module: - * - * http://docs.python.org/2/library/pickle.html - * - * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: - * - * http://hg.python.org/cpython/file/2.6/Lib/pickle.py - * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py - * - * @param elem the object to write - * @param dOut a data output stream - */ - def writeAsPickle(elem: Any, dOut: DataOutputStream) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { - val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] - val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t._1)) - dOut.write(PythonRDD.stripPickle(t._2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. - val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new SparkException("Unexpected RDD type") - } - } +private[spark] object PythonRDD { - def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) val objs = new collection.mutable.ArrayBuffer[Array[Byte]] @@ -265,41 +200,47 @@ private[spark] object PythonRDD { } } catch { case eof: EOFException => {} - case e => throw e + case e: Throwable => throw e } JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + def writeToStream(elem: Any, dataOut: DataOutputStream) { + elem match { + case bytes: Array[Byte] => + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + case pair: (Array[Byte], Array[Byte]) => + dataOut.writeInt(pair._1.length) + dataOut.write(pair._1) + dataOut.writeInt(pair._2.length) + dataOut.write(pair._2) + case str: String => + dataOut.writeUTF(str) + case other => + throw new SparkException("Unexpected element type " + other.getClass) + } + } + + def writeToFile[T](items: java.util.Iterator[T], filename: String) { import scala.collection.JavaConverters._ - writeIteratorToPickleFile(items.asScala, filename) + writeToFile(items.asScala, filename) } - def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { + def writeToFile[T](items: Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { - writeAsPickle(item, file) + writeToStream(item, file) } file.close() } def takePartition[T](rdd: RDD[T], partition: Int): Iterator[T] = { - implicit val cm : ClassManifest[T] = rdd.elementClassManifest + implicit val cm : ClassTag[T] = rdd.elementClassTag rdd.context.runJob(rdd, ((x: Iterator[T]) => x.toArray), Seq(partition), true).head.iterator } } -private object Pickle { - val PROTO: Byte = 0x80.toByte - val TWO: Byte = 0x02.toByte - val BINUNICODE: Byte = 'X' - val STOP: Byte = '.' - val TUPLE2: Byte = 0x86.toByte - val EMPTY_LIST: Byte = ']' - val MARK: Byte = '(' - val APPENDS: Byte = 'e' -} - private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala index 67d45723ba..f291266fcf 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonWorkerFactory.scala @@ -64,7 +64,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String startDaemon() new Socket(daemonHost, daemonPort) } - case e => throw e + case e: Throwable => throw e } } } @@ -198,7 +198,7 @@ private[spark] class PythonWorkerFactory(pythonExec: String, envVars: Map[String } }.start() } catch { - case e => { + case e: Throwable => { stopDaemon() throw e } diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 609464e38d..47db720416 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -19,6 +19,7 @@ package org.apache.spark.broadcast import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} import java.net.URL +import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedOutputStream @@ -83,6 +84,8 @@ private object HttpBroadcast extends Logging { private val files = new TimeStampedHashSet[String] private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup) + private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5,TimeUnit.MINUTES).toInt + private lazy val compressionCodec = CompressionCodec.createCodec() def initialize(isDriver: Boolean) { @@ -138,10 +141,13 @@ private object HttpBroadcast extends Logging { def read[T](id: Long): T = { val url = serverUri + "/" + BroadcastBlockId(id).name val in = { + val httpConnection = new URL(url).openConnection() + httpConnection.setReadTimeout(httpReadTimeout) + val inputStream = httpConnection.getInputStream() if (compress) { - compressionCodec.compressedInputStream(new URL(url).openStream()) + compressionCodec.compressedInputStream(inputStream) } else { - new FastBufferedInputStream(new URL(url).openStream(), bufferSize) + new FastBufferedInputStream(inputStream, bufferSize) } } val ser = SparkEnv.get.serializer.newInstance() diff --git a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala index fcfea96ad6..37dfa7fec0 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ExecutorState.scala @@ -17,8 +17,7 @@ package org.apache.spark.deploy -private[spark] object ExecutorState - extends Enumeration("LAUNCHING", "LOADING", "RUNNING", "KILLED", "FAILED", "LOST") { +private[spark] object ExecutorState extends Enumeration { val LAUNCHING, LOADING, RUNNING, KILLED, FAILED, LOST = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 668032a3a2..0aa8852649 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -1,19 +1,19 @@ /* * - * * Licensed to the Apache Software Foundation (ASF) under one or more - * * contributor license agreements. See the NOTICE file distributed with - * * this work for additional information regarding copyright ownership. - * * The ASF licenses this file to You under the Apache License, Version 2.0 - * * (the "License"); you may not use this file except in compliance with - * * the License. You may obtain a copy of the License at - * * - * * http://www.apache.org/licenses/LICENSE-2.0 - * * - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, - * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * * See the License for the specific language governing permissions and - * * limitations under the License. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. * */ diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index a724900943..59d12a3e6f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -34,11 +34,11 @@ import scala.collection.mutable.ArrayBuffer */ private[spark] class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: Int) extends Logging { - + private val localHostname = Utils.localHostName() private val masterActorSystems = ArrayBuffer[ActorSystem]() private val workerActorSystems = ArrayBuffer[ActorSystem]() - + def start(): Array[String] = { logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") @@ -61,10 +61,13 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I def stop() { logInfo("Shutting down local Spark cluster.") // Stop the workers before the master so they don't get upset that it disconnected + // TODO: In Akka 2.1.x, ActorSystem.awaitTermination hangs when you have remote actors! + // This is unfortunate, but for now we just comment it out. workerActorSystems.foreach(_.shutdown()) - workerActorSystems.foreach(_.awaitTermination()) - + //workerActorSystems.foreach(_.awaitTermination()) masterActorSystems.foreach(_.shutdown()) - masterActorSystems.foreach(_.awaitTermination()) + //masterActorSystems.foreach(_.awaitTermination()) + masterActorSystems.clear() + workerActorSystems.clear() } } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala index 77422f61ec..4d95efa73a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/Client.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/Client.scala @@ -19,17 +19,15 @@ package org.apache.spark.deploy.client import java.util.concurrent.TimeoutException +import scala.concurrent.duration._ +import scala.concurrent.Await + import akka.actor._ -import akka.actor.Terminated +import akka.pattern.AskTimeoutException import akka.pattern.ask -import akka.util.Duration -import akka.util.duration._ -import akka.remote.RemoteClientDisconnected -import akka.remote.RemoteClientLifeCycleEvent -import akka.remote.RemoteClientShutdown -import akka.dispatch.Await - -import org.apache.spark.Logging +import akka.remote.{RemotingLifecycleEvent, DisassociatedEvent, AssociationErrorEvent} + +import org.apache.spark.{SparkException, Logging} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master @@ -51,18 +49,19 @@ private[spark] class Client( val REGISTRATION_TIMEOUT = 20.seconds val REGISTRATION_RETRIES = 3 + var masterAddress: Address = null var actor: ActorRef = null var appId: String = null var registered = false var activeMasterUrl: String = null class ClientActor extends Actor with Logging { - var master: ActorRef = null - var masterAddress: Address = null + var master: ActorSelection = null var alreadyDisconnected = false // To avoid calling listener.disconnected() multiple times var alreadyDead = false // To avoid calling listener.dead() multiple times override def preStart() { + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) try { registerWithMaster() } catch { @@ -76,7 +75,7 @@ private[spark] class Client( def tryRegisterAllMasters() { for (masterUrl <- masterUrls) { logInfo("Connecting to master " + masterUrl + "...") - val actor = context.actorFor(Master.toAkkaUrl(masterUrl)) + val actor = context.actorSelection(Master.toAkkaUrl(masterUrl)) actor ! RegisterApplication(appDescription) } } @@ -84,6 +83,7 @@ private[spark] class Client( def registerWithMaster() { tryRegisterAllMasters() + import context.dispatcher var retries = 0 lazy val retryTimer: Cancellable = context.system.scheduler.schedule(REGISTRATION_TIMEOUT, REGISTRATION_TIMEOUT) { @@ -102,10 +102,13 @@ private[spark] class Client( def changeMaster(url: String) { activeMasterUrl = url - master = context.actorFor(Master.toAkkaUrl(url)) - masterAddress = master.path.address - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing + master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) + masterAddress = activeMasterUrl match { + case Master.sparkUrlRegex(host, port) => + Address("akka.tcp", Master.systemName, host, port.toInt) + case x => + throw new SparkException("Invalid spark URL: " + x) + } } override def receive = { @@ -135,21 +138,12 @@ private[spark] class Client( case MasterChanged(masterUrl, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterUrl) - context.unwatch(master) changeMaster(masterUrl) alreadyDisconnected = false sender ! MasterChangeAcknowledged(appId) - case Terminated(actor_) if actor_ == master => - logWarning("Connection to master failed; waiting for master to reconnect...") - markDisconnected() - - case RemoteClientDisconnected(transport, address) if address == masterAddress => - logWarning("Connection to master failed; waiting for master to reconnect...") - markDisconnected() - - case RemoteClientShutdown(transport, address) if address == masterAddress => - logWarning("Connection to master failed; waiting for master to reconnect...") + case DisassociatedEvent(_, address, _) if address == masterAddress => + logWarning(s"Connection to $address failed; waiting for master to reconnect...") markDisconnected() case StopClient => diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala index fedf879eff..67e6c5d66a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ApplicationState.scala @@ -17,8 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object ApplicationState - extends Enumeration("WAITING", "RUNNING", "FINISHED", "FAILED", "UNKNOWN") { +private[spark] object ApplicationState extends Enumeration { type ApplicationState = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala index c0849ef324..043945a211 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/FileSystemPersistenceEngine.scala @@ -65,7 +65,7 @@ private[spark] class FileSystemPersistenceEngine( (apps, workers) } - private def serializeIntoFile(file: File, value: Serializable) { + private def serializeIntoFile(file: File, value: AnyRef) { val created = file.createNewFile() if (!created) { throw new IllegalStateException("Could not create file: " + file) } @@ -77,13 +77,13 @@ private[spark] class FileSystemPersistenceEngine( out.close() } - def deserializeFromFile[T <: Serializable](file: File)(implicit m: Manifest[T]): T = { + def deserializeFromFile[T](file: File)(implicit m: Manifest[T]): T = { val fileData = new Array[Byte](file.length().asInstanceOf[Int]) val dis = new DataInputStream(new FileInputStream(file)) dis.readFully(fileData) dis.close() - val clazz = m.erasure.asInstanceOf[Class[T]] + val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) serializer.fromBinary(fileData).asInstanceOf[T] } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index cd916672ac..c627dd3806 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -17,19 +17,20 @@ package org.apache.spark.deploy.master -import java.util.Date import java.text.SimpleDateFormat +import java.util.concurrent.TimeUnit +import java.util.Date import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.Await +import scala.concurrent.duration._ +import scala.concurrent.duration.{Duration, FiniteDuration} import akka.actor._ -import akka.actor.Terminated -import akka.dispatch.Await import akka.pattern.ask -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientDisconnected, RemoteClientShutdown} +import akka.remote._ import akka.serialization.SerializationExtension -import akka.util.duration._ -import akka.util.{Duration, Timeout} +import akka.util.Timeout import org.apache.spark.{Logging, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} @@ -37,9 +38,11 @@ import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.MasterMessages._ import org.apache.spark.deploy.master.ui.MasterWebUI import org.apache.spark.metrics.MetricsSystem -import org.apache.spark.util.{AkkaUtils, Utils} +import org.apache.spark.util.{Utils, AkkaUtils} private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { + import context.dispatcher + val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs val WORKER_TIMEOUT = System.getProperty("spark.worker.timeout", "60").toLong * 1000 val RETAINED_APPLICATIONS = System.getProperty("spark.deploy.retainedApplications", "200").toInt @@ -93,7 +96,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act override def preStart() { logInfo("Starting Spark master at " + masterUrl) // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) webUi.start() masterWebUiUrl = "http://" + masterPublicAddress + ":" + webUi.boundPort.get context.system.scheduler.schedule(0 millis, WORKER_TIMEOUT millis, self, CheckForWorkerTimeOut) @@ -113,13 +116,12 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act new BlackHolePersistenceEngine() } - leaderElectionAgent = context.actorOf(Props( - RECOVERY_MODE match { + leaderElectionAgent = RECOVERY_MODE match { case "ZOOKEEPER" => - new ZooKeeperLeaderElectionAgent(self, masterUrl) + context.actorOf(Props(classOf[ZooKeeperLeaderElectionAgent], self, masterUrl)) case _ => - new MonarchyLeaderAgent(self) - })) + context.actorOf(Props(classOf[MonarchyLeaderAgent], self)) + } } override def preRestart(reason: Throwable, message: Option[Any]) { @@ -142,9 +144,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act RecoveryState.ALIVE else RecoveryState.RECOVERING - logInfo("I have been elected leader! New state: " + state) - if (state == RecoveryState.RECOVERING) { beginRecovery(storedApps, storedWorkers) context.system.scheduler.scheduleOnce(WORKER_TIMEOUT millis) { completeRecovery() } @@ -156,7 +156,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act System.exit(0) } - case RegisterWorker(id, host, workerPort, cores, memory, webUiPort, publicAddress) => { + case RegisterWorker(id, workerHost, workerPort, cores, memory, workerWebUiPort, publicAddress) => { logInfo("Registering worker %s:%d with %d cores, %s RAM".format( host, workerPort, cores, Utils.megabytesToString(memory))) if (state == RecoveryState.STANDBY) { @@ -164,9 +164,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } else if (idToWorker.contains(id)) { sender ! RegisterWorkerFailed("Duplicate worker ID") } else { - val worker = new WorkerInfo(id, host, port, cores, memory, sender, webUiPort, publicAddress) + val worker = new WorkerInfo(id, workerHost, workerPort, cores, memory, + sender, workerWebUiPort, publicAddress) registerWorker(worker) - context.watch(sender) // This doesn't work with remote actors but helps for testing persistenceEngine.addWorker(worker) sender ! RegisteredWorker(masterUrl, masterWebUiUrl) schedule() @@ -181,7 +181,6 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act val app = createApplication(description, sender) registerApplication(app) logInfo("Registered app " + description.name + " with ID " + app.id) - context.watch(sender) // This doesn't work with remote actors but helps for testing persistenceEngine.addApplication(app) sender ! RegisteredApplication(app.id, masterUrl) schedule() @@ -257,23 +256,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act if (canCompleteRecovery) { completeRecovery() } } - case Terminated(actor) => { - // The disconnected actor could've been either a worker or an app; remove whichever of - // those we have an entry for in the corresponding actor hashmap - actorToWorker.get(actor).foreach(removeWorker) - actorToApp.get(actor).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } - } - - case RemoteClientDisconnected(transport, address) => { - // The disconnected client could've been either a worker or an app; remove whichever it was - addressToWorker.get(address).foreach(removeWorker) - addressToApp.get(address).foreach(finishApplication) - if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } - } - - case RemoteClientShutdown(transport, address) => { + case DisassociatedEvent(_, address, _) => { // The disconnected client could've been either a worker or an app; remove whichever it was + logInfo(s"$address got disassociated, removing it.") addressToWorker.get(address).foreach(removeWorker) addressToApp.get(address).foreach(finishApplication) if (state == RecoveryState.RECOVERING && canCompleteRecovery) { completeRecovery() } @@ -530,9 +515,9 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act } private[spark] object Master { - private val systemName = "sparkMaster" + val systemName = "sparkMaster" private val actorName = "Master" - private val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r + val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r def main(argStrings: Array[String]) { val args = new MasterArguments(argStrings) @@ -540,11 +525,11 @@ private[spark] object Master { actorSystem.awaitTermination() } - /** Returns an `akka://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */ + /** Returns an `akka.tcp://...` URL for the Master actor given a sparkUrl `spark://host:ip`. */ def toAkkaUrl(sparkUrl: String): String = { sparkUrl match { case sparkUrlRegex(host, port) => - "akka://%s@%s:%s/user/%s".format(systemName, host, port, actorName) + "akka.tcp://%s@%s:%s/user/%s".format(systemName, host, port, actorName) case _ => throw new SparkException("Invalid master URL: " + sparkUrl) } @@ -552,9 +537,9 @@ private[spark] object Master { def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = { val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) - val actor = actorSystem.actorOf(Props(new Master(host, boundPort, webUiPort)), name = actorName) - val timeoutDuration = Duration.create( - System.getProperty("spark.akka.askTimeout", "10").toLong, "seconds") + val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), name = actorName) + val timeoutDuration: FiniteDuration = Duration.create( + System.getProperty("spark.akka.askTimeout", "10").toLong, TimeUnit.SECONDS) implicit val timeout = Timeout(timeoutDuration) val respFuture = actor ? RequestWebUIPort // ask pattern val resp = Await.result(respFuture, timeoutDuration).asInstanceOf[WebUIPortResponse] diff --git a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala index b91be821f0..256a5a7c28 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/RecoveryState.scala @@ -17,9 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object RecoveryState - extends Enumeration("STANDBY", "ALIVE", "RECOVERING", "COMPLETING_RECOVERY") { - +private[spark] object RecoveryState extends Enumeration { type MasterState = Value val STANDBY, ALIVE, RECOVERING, COMPLETING_RECOVERY = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala index c8d34f25e2..0b36ef6005 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/WorkerState.scala @@ -17,9 +17,7 @@ package org.apache.spark.deploy.master -private[spark] object WorkerState - extends Enumeration("ALIVE", "DEAD", "DECOMMISSIONED", "UNKNOWN") { - +private[spark] object WorkerState extends Enumeration { type WorkerState = Value val ALIVE, DEAD, DECOMMISSIONED, UNKNOWN = Value diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index a0233a7271..825344b3bb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -70,15 +70,15 @@ class ZooKeeperPersistenceEngine(serialization: Serialization) (apps, workers) } - private def serializeIntoFile(path: String, value: Serializable) { + private def serializeIntoFile(path: String, value: AnyRef) { val serializer = serialization.findSerializerFor(value) val serialized = serializer.toBinary(value) zk.create(path, serialized, CreateMode.PERSISTENT) } - def deserializeFromFile[T <: Serializable](filename: String)(implicit m: Manifest[T]): T = { + def deserializeFromFile[T](filename: String)(implicit m: Manifest[T]): T = { val fileData = zk.getData("/spark/master_status/" + filename) - val clazz = m.erasure.asInstanceOf[Class[T]] + val clazz = m.runtimeClass.asInstanceOf[Class[T]] val serializer = serialization.serializerFor(clazz) serializer.fromBinary(fileData).asInstanceOf[T] } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala index f4e574d15d..3b983c19eb 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/ApplicationPage.scala @@ -19,9 +19,10 @@ package org.apache.spark.deploy.master.ui import scala.xml.Node -import akka.dispatch.Await import akka.pattern.ask -import akka.util.duration._ + +import scala.concurrent.Await +import scala.concurrent.duration._ import javax.servlet.http.HttpServletRequest diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala index d7a57229b0..65e7a14e7a 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/IndexPage.scala @@ -21,9 +21,9 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import akka.dispatch.Await +import scala.concurrent.Await import akka.pattern.ask -import akka.util.duration._ +import scala.concurrent.duration._ import net.liftweb.json.JsonAST.JValue diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala index f4df729e87..a211ce2b42 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.master.ui -import akka.util.Duration +import scala.concurrent.duration._ import javax.servlet.http.HttpServletRequest diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 216d9d44ac..87531b6719 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -17,23 +17,31 @@ package org.apache.spark.deploy.worker +import java.io.File import java.text.SimpleDateFormat import java.util.Date -import java.io.File import scala.collection.mutable.HashMap +import scala.concurrent.duration._ import akka.actor._ -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} -import akka.util.duration._ +import akka.remote.{ DisassociatedEvent, RemotingLifecycleEvent} -import org.apache.spark.Logging +import org.apache.spark.{SparkException, Logging} import org.apache.spark.deploy.{ExecutorDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.Master import org.apache.spark.deploy.worker.ui.WorkerWebUI import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.deploy.DeployMessages.WorkerStateResponse +import org.apache.spark.deploy.DeployMessages.RegisterWorkerFailed +import org.apache.spark.deploy.DeployMessages.KillExecutor +import org.apache.spark.deploy.DeployMessages.ExecutorStateChanged +import org.apache.spark.deploy.DeployMessages.Heartbeat +import org.apache.spark.deploy.DeployMessages.RegisteredWorker +import org.apache.spark.deploy.DeployMessages.LaunchExecutor +import org.apache.spark.deploy.DeployMessages.RegisterWorker /** * @param masterUrls Each url should look like spark://host:port. @@ -47,6 +55,7 @@ private[spark] class Worker( masterUrls: Array[String], workDirPath: String = null) extends Actor with Logging { + import context.dispatcher Utils.checkHost(host, "Expected hostname") assert (port > 0) @@ -63,7 +72,8 @@ private[spark] class Worker( var masterIndex = 0 val masterLock: Object = new Object() - var master: ActorRef = null + var master: ActorSelection = null + var masterAddress: Address = null var activeMasterUrl: String = "" var activeMasterWebUiUrl : String = "" @volatile var registered = false @@ -114,7 +124,7 @@ private[spark] class Worker( logInfo("Spark home: " + sparkHome) createWorkDir() webUi = new WorkerWebUI(this, workDir, Some(webUiPort)) - + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) webUi.start() registerWithMaster() @@ -126,9 +136,13 @@ private[spark] class Worker( masterLock.synchronized { activeMasterUrl = url activeMasterWebUiUrl = uiUrl - master = context.actorFor(Master.toAkkaUrl(activeMasterUrl)) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(master) // Doesn't work with remote actors, but useful for testing + master = context.actorSelection(Master.toAkkaUrl(activeMasterUrl)) + masterAddress = activeMasterUrl match { + case Master.sparkUrlRegex(_host, _port) => + Address("akka.tcp", Master.systemName, _host, _port.toInt) + case x => + throw new SparkException("Invalid spark URL: " + x) + } connected = true } } @@ -136,7 +150,7 @@ private[spark] class Worker( def tryRegisterAllMasters() { for (masterUrl <- masterUrls) { logInfo("Connecting to master " + masterUrl + "...") - val actor = context.actorFor(Master.toAkkaUrl(masterUrl)) + val actor = context.actorSelection(Master.toAkkaUrl(masterUrl)) actor ! RegisterWorker(workerId, host, port, cores, memory, webUi.boundPort.get, publicAddress) } @@ -175,7 +189,6 @@ private[spark] class Worker( case MasterChanged(masterUrl, masterWebUiUrl) => logInfo("Master has changed, new master is at " + masterUrl) - context.unwatch(master) changeMaster(masterUrl, masterWebUiUrl) val execs = executors.values. @@ -234,13 +247,8 @@ private[spark] class Worker( } } - case Terminated(actor_) if actor_ == master => - masterDisconnected() - - case RemoteClientDisconnected(transport, address) if address == master.path.address => - masterDisconnected() - - case RemoteClientShutdown(transport, address) if address == master.path.address => + case x: DisassociatedEvent if x.remoteAddress == masterAddress => + logInfo(s"$x Disassociated !") masterDisconnected() case RequestWorkerState => { @@ -280,8 +288,8 @@ private[spark] object Worker { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port) - val actor = actorSystem.actorOf(Props(new Worker(host, boundPort, webUiPort, cores, memory, - masterUrls, workDir)), name = "Worker") + actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, + masterUrls, workDir), name = "Worker") (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala index d2d3617498..1a768d501f 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/IndexPage.scala @@ -21,9 +21,10 @@ import javax.servlet.http.HttpServletRequest import scala.xml.Node -import akka.dispatch.Await +import scala.concurrent.duration._ +import scala.concurrent.Await + import akka.pattern.ask -import akka.util.duration._ import net.liftweb.json.JsonAST.JValue diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala index 800f1cafcc..6c18a3c245 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala @@ -17,20 +17,19 @@ package org.apache.spark.deploy.worker.ui -import akka.util.{Duration, Timeout} +import java.io.File -import java.io.{FileInputStream, File} +import scala.concurrent.duration._ +import akka.util.Timeout import javax.servlet.http.HttpServletRequest -import org.eclipse.jetty.server.{Handler, Server} - +import org.apache.spark.Logging import org.apache.spark.deploy.worker.Worker -import org.apache.spark.{Logging} -import org.apache.spark.ui.JettyUtils +import org.apache.spark.ui.{JettyUtils, UIUtils} import org.apache.spark.ui.JettyUtils._ -import org.apache.spark.ui.UIUtils import org.apache.spark.util.Utils +import org.eclipse.jetty.server.{Handler, Server} /** * Web UI server for the standalone worker. diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index 8332631838..debbdd4c44 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -19,15 +19,14 @@ package org.apache.spark.executor import java.nio.ByteBuffer -import akka.actor.{ActorRef, Actor, Props, Terminated} -import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} +import akka.actor._ +import akka.remote._ import org.apache.spark.Logging import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{Utils, AkkaUtils} - private[spark] class CoarseGrainedExecutorBackend( driverUrl: String, executorId: String, @@ -40,14 +39,13 @@ private[spark] class CoarseGrainedExecutorBackend( Utils.checkHostPort(hostPort, "Expected hostport") var executor: Executor = null - var driver: ActorRef = null + var driver: ActorSelection = null override def preStart() { logInfo("Connecting to driver: " + driverUrl) - driver = context.actorFor(driverUrl) + driver = context.actorSelection(driverUrl) driver ! RegisterExecutor(executorId, hostPort, cores) - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) - context.watch(driver) // Doesn't work with remote actors, but useful for testing + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) } override def receive = { @@ -77,8 +75,8 @@ private[spark] class CoarseGrainedExecutorBackend( executor.killTask(taskId) } - case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => - logError("Driver terminated or disconnected! Shutting down.") + case x: DisassociatedEvent => + logError(s"Driver $x disassociated! Shutting down.") System.exit(1) case StopExecutor => @@ -99,12 +97,13 @@ private[spark] object CoarseGrainedExecutorBackend { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc - val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0) + val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, + indestructible = true) // set it val sparkHostPort = hostname + ":" + boundPort System.setProperty("spark.hostPort", sparkHostPort) - val actor = actorSystem.actorOf( - Props(new CoarseGrainedExecutorBackend(driverUrl, executorId, sparkHostPort, cores)), + actorSystem.actorOf( + Props(classOf[CoarseGrainedExecutorBackend], driverUrl, executorId, sparkHostPort, cores), name = "Executor") actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 5c9bb9db1c..0b0a60ee60 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -121,7 +121,7 @@ private[spark] class Executor( // Akka's message frame size. If task result is bigger than this, we use the block manager // to send the result back. private val akkaFrameSize = { - env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size") + env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size") } // Start worker thread pool diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 0b4892f98f..c0ce46e379 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -61,50 +61,53 @@ object TaskMetrics { class ShuffleReadMetrics extends Serializable { /** - * Time when shuffle finishs + * Absolute time when this task finished reading shuffle data */ var shuffleFinishTime: Long = _ /** - * Total number of blocks fetched in a shuffle (remote or local) + * Number of blocks fetched in this shuffle by this task (remote or local) */ var totalBlocksFetched: Int = _ /** - * Number of remote blocks fetched in a shuffle + * Number of remote blocks fetched in this shuffle by this task */ var remoteBlocksFetched: Int = _ /** - * Local blocks fetched in a shuffle + * Number of local blocks fetched in this shuffle by this task */ var localBlocksFetched: Int = _ /** - * Total time that is spent blocked waiting for shuffle to fetch data + * Time the task spent waiting for remote shuffle blocks. This only includes the time + * blocking on shuffle input data. For instance if block B is being fetched while the task is + * still not finished processing block A, it is not considered to be blocking on block B. */ var fetchWaitTime: Long = _ /** - * The total amount of time for all the shuffle fetches. This adds up time from overlapping - * shuffles, so can be longer than task time + * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all + * input blocks. Since block fetches are both pipelined and parallelized, this can + * exceed fetchWaitTime and executorRunTime. */ var remoteFetchTime: Long = _ /** - * Total number of remote bytes read from a shuffle + * Total number of remote bytes read from the shuffle by this task */ var remoteBytesRead: Long = _ } class ShuffleWriteMetrics extends Serializable { /** - * Number of bytes written for a shuffle + * Number of bytes written for the shuffle by this task */ var shuffleBytesWritten: Long = _ /** - * Time spent blocking on writes to disk or buffer cache, in nanoseconds. + * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ var shuffleWriteTime: Long = _ } diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala new file mode 100644 index 0000000000..cdcfec8ca7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import java.util.Properties +import java.util.concurrent.TimeUnit +import java.net.InetSocketAddress + +import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.graphite.{GraphiteReporter, Graphite} + +import org.apache.spark.metrics.MetricsSystem + +class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink { + val GRAPHITE_DEFAULT_PERIOD = 10 + val GRAPHITE_DEFAULT_UNIT = "SECONDS" + val GRAPHITE_DEFAULT_PREFIX = "" + + val GRAPHITE_KEY_HOST = "host" + val GRAPHITE_KEY_PORT = "port" + val GRAPHITE_KEY_PERIOD = "period" + val GRAPHITE_KEY_UNIT = "unit" + val GRAPHITE_KEY_PREFIX = "prefix" + + def propertyToOption(prop: String) = Option(property.getProperty(prop)) + + if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) { + throw new Exception("Graphite sink requires 'host' property.") + } + + if (!propertyToOption(GRAPHITE_KEY_PORT).isDefined) { + throw new Exception("Graphite sink requires 'port' property.") + } + + val host = propertyToOption(GRAPHITE_KEY_HOST).get + val port = propertyToOption(GRAPHITE_KEY_PORT).get.toInt + + val pollPeriod = propertyToOption(GRAPHITE_KEY_PERIOD) match { + case Some(s) => s.toInt + case None => GRAPHITE_DEFAULT_PERIOD + } + + val pollUnit = propertyToOption(GRAPHITE_KEY_UNIT) match { + case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT) + } + + val prefix = propertyToOption(GRAPHITE_KEY_PREFIX).getOrElse(GRAPHITE_DEFAULT_PREFIX) + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val graphite: Graphite = new Graphite(new InetSocketAddress(host, port)) + + val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .convertRatesTo(TimeUnit.SECONDS) + .prefixedWith(prefix) + .build(graphite) + + override def start() { + reporter.start(pollPeriod, pollUnit) + } + + override def stop() { + reporter.stop() + } +} diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 9c2fee4023..703bc6a9ca 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -31,11 +31,11 @@ import scala.collection.mutable.SynchronizedMap import scala.collection.mutable.SynchronizedQueue import scala.collection.mutable.ArrayBuffer -import akka.dispatch.{Await, Promise, ExecutionContext, Future} -import akka.util.Duration -import akka.util.duration._ -import org.apache.spark.util.Utils +import scala.concurrent.{Await, Promise, ExecutionContext, Future} +import scala.concurrent.duration.Duration +import scala.concurrent.duration._ +import org.apache.spark.util.Utils private[spark] class ConnectionManager(port: Int) extends Logging { diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala index 8d9ad9604d..4f5742d29b 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManagerTest.scala @@ -25,8 +25,8 @@ import scala.io.Source import java.nio.ByteBuffer import java.net.InetAddress -import akka.dispatch.Await -import akka.util.duration._ +import scala.concurrent.Await +import scala.concurrent.duration._ private[spark] object ConnectionManagerTest extends Logging{ def main(args: Array[String]) { diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index 481ff8c3e0..b1e1576dad 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -76,7 +76,7 @@ private[spark] object ShuffleCopier extends Logging { extends FileClientHandler with Logging { override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)") resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala index faaf837be0..d1c74a5063 100644 --- a/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/AsyncRDDActions.scala @@ -21,6 +21,7 @@ import java.util.concurrent.atomic.AtomicLong import scala.collection.mutable.ArrayBuffer import scala.concurrent.ExecutionContext.Implicits.global +import scala.reflect.ClassTag import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} @@ -28,7 +29,7 @@ import org.apache.spark.{ComplexFutureAction, FutureAction, Logging} * A set of asynchronous RDD actions available through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. */ -class AsyncRDDActions[T: ClassManifest](self: RDD[T]) extends Serializable with Logging { +class AsyncRDDActions[T: ClassTag](self: RDD[T]) extends Serializable with Logging { /** * Returns a future for counting the number of elements in the RDD. diff --git a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala index 44ea573a7c..424354ae16 100644 --- a/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/BlockRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} import org.apache.spark.storage.{BlockId, BlockManager} @@ -25,7 +27,7 @@ private[spark] class BlockRDDPartition(val blockId: BlockId, idx: Int) extends P } private[spark] -class BlockRDD[T: ClassManifest](sc: SparkContext, @transient blockIds: Array[BlockId]) +class BlockRDD[T: ClassTag](sc: SparkContext, @transient blockIds: Array[BlockId]) extends RDD[T](sc, Nil) { @transient lazy val locations_ = BlockManager.blockIdsToHosts(blockIds, SparkEnv.get) diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 9b0c882481..87b950ba43 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -18,6 +18,7 @@ package org.apache.spark.rdd import java.io.{ObjectOutputStream, IOException} +import scala.reflect.ClassTag import org.apache.spark._ @@ -43,7 +44,7 @@ class CartesianPartition( } private[spark] -class CartesianRDD[T: ClassManifest, U:ClassManifest]( +class CartesianRDD[T: ClassTag, U: ClassTag]( sc: SparkContext, var rdd1 : RDD[T], var rdd2 : RDD[U]) @@ -70,7 +71,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def compute(split: Partition, context: TaskContext) = { val currSplit = split.asInstanceOf[CartesianPartition] for (x <- rdd1.iterator(currSplit.s1, context); - y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) + y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) } override def getDependencies: Seq[Dependency[_]] = List( diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index d3033ea4a6..a712ef1c27 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -17,15 +17,13 @@ package org.apache.spark.rdd +import java.io.IOException + +import scala.reflect.ClassTag + +import org.apache.hadoop.fs.Path import org.apache.spark._ import org.apache.spark.deploy.SparkHadoopUtil -import org.apache.hadoop.mapred.{FileInputFormat, SequenceFileInputFormat, JobConf, Reporter} -import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.io.{NullWritable, BytesWritable} -import org.apache.hadoop.util.ReflectionUtils -import org.apache.hadoop.fs.Path -import java.io.{File, IOException, EOFException} -import java.text.NumberFormat private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} @@ -33,7 +31,7 @@ private[spark] class CheckpointRDDPartition(val index: Int) extends Partition {} * This RDD represents a RDD checkpoint file (similar to HadoopRDD). */ private[spark] -class CheckpointRDD[T: ClassManifest](sc: SparkContext, val checkpointPath: String) +class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) extends RDD[T](sc, Nil) { @transient val fs = new Path(checkpointPath).getFileSystem(sc.hadoopConfiguration) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala index c5de6362a9..98da35763b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala @@ -22,6 +22,7 @@ import java.io.{ObjectOutputStream, IOException} import scala.collection.mutable import scala.Some import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag /** * Class that captures a coalesced RDD by essentially keeping track of parent partitions @@ -68,7 +69,7 @@ case class CoalescedRDDPartition( * @param maxPartitions number of desired partitions in the coalesced RDD * @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance */ -class CoalescedRDD[T: ClassManifest]( +class CoalescedRDD[T: ClassTag]( @transient var prev: RDD[T], maxPartitions: Int, balanceSlack: Double = 0.10) diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a4bec41752..688c310ee9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator import org.apache.spark.util.StatCounter import org.apache.spark.{TaskContext, Logging} +import scala.collection.immutable.NumericRange + /** * Extra functions available on RDDs of Doubles through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. @@ -76,4 +78,129 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { val evaluator = new SumEvaluator(self.partitions.size, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } + + /** + * Compute a histogram of the data using bucketCount number of buckets evenly + * spaced between the minimum and maximum of the RDD. For example if the min + * value is 0 and the max is 100 and there are two buckets the resulting + * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1 + * If the RDD contains infinity, NaN throws an exception + * If the elements in RDD do not vary (max == min) always returns a single bucket. + */ + def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + // Compute the minimum and the maxium + val (max: Double, min: Double) = self.mapPartitions { items => + Iterator(items.foldRight(Double.NegativeInfinity, + Double.PositiveInfinity)((e: Double, x: Pair[Double, Double]) => + (x._1.max(e), x._2.min(e)))) + }.reduce { (maxmin1, maxmin2) => + (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) + } + if (min.isNaN || max.isNaN || max.isInfinity || min.isInfinity ) { + throw new UnsupportedOperationException( + "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") + } + val increment = (max-min)/bucketCount.toDouble + val range = if (increment != 0) { + Range.Double.inclusive(min, max, increment) + } else { + List(min, min) + } + val buckets = range.toArray + (buckets, histogram(buckets, true)) + } + + /** + * Compute a histogram using the provided buckets. The buckets are all open + * to the left except for the last which is closed + * e.g. for the array + * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 + * + * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * to true. + * buckets must be sorted and not contain any duplicates. + * buckets array must be at least two elements + * All NaN entries are treated the same. If you have a NaN bucket it must be + * the maximum value of the last position and all NaN entries will be counted + * in that bucket. + */ + def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = { + if (buckets.length < 2) { + throw new IllegalArgumentException("buckets array must have at least two elements") + } + // The histogramPartition function computes the partail histogram for a given + // partition. The provided bucketFunction determines which bucket in the array + // to increment or returns None if there is no bucket. This is done so we can + // specialize for uniformly distributed buckets and save the O(log n) binary + // search cost. + def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]): + Iterator[Array[Long]] = { + val counters = new Array[Long](buckets.length - 1) + while (iter.hasNext) { + bucketFunction(iter.next()) match { + case Some(x: Int) => {counters(x) += 1} + case _ => {} + } + } + Iterator(counters) + } + // Merge the counters. + def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = { + a1.indices.foreach(i => a1(i) += a2(i)) + a1 + } + // Basic bucket function. This works using Java's built in Array + // binary search. Takes log(size(buckets)) + def basicBucketFunction(e: Double): Option[Int] = { + val location = java.util.Arrays.binarySearch(buckets, e) + if (location < 0) { + // If the location is less than 0 then the insertion point in the array + // to keep it sorted is -location-1 + val insertionPoint = -location-1 + // If we have to insert before the first element or after the last one + // its out of bounds. + // We do this rather than buckets.lengthCompare(insertionPoint) + // because Array[Double] fails to override it (for now). + if (insertionPoint > 0 && insertionPoint < buckets.length) { + Some(insertionPoint-1) + } else { + None + } + } else if (location < buckets.length - 1) { + // Exact match, just insert here + Some(location) + } else { + // Exact match to the last element + Some(location - 1) + } + } + // Determine the bucket function in constant time. Requires that buckets are evenly spaced + def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + // If our input is not a number unless the increment is also NaN then we fail fast + if (e.isNaN()) { + return None + } + val bucketNumber = (e - min)/(increment) + // We do this rather than buckets.lengthCompare(bucketNumber) + // because Array[Double] fails to override it (for now). + if (bucketNumber > count || bucketNumber < 0) { + None + } else { + Some(bucketNumber.toInt.min(count - 1)) + } + } + // Decide which bucket function to pass to histogramPartition. We decide here + // rather than having a general function so that the decission need only be made + // once rather than once per shard + val bucketFunction = if (evenBuckets) { + fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + } else { + basicBucketFunction _ + } + self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + } + } diff --git a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala index c8900d1a93..a84e5f9fd8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/EmptyRDD.scala @@ -17,13 +17,14 @@ package org.apache.spark.rdd -import org.apache.spark.{SparkContext, SparkEnv, Partition, TaskContext} +import scala.reflect.ClassTag +import org.apache.spark.{Partition, SparkContext, TaskContext} /** * An RDD that is empty, i.e. has no element in it. */ -class EmptyRDD[T: ClassManifest](sc: SparkContext) extends RDD[T](sc, Nil) { +class EmptyRDD[T: ClassTag](sc: SparkContext) extends RDD[T](sc, Nil) { override def getPartitions: Array[Partition] = Array.empty diff --git a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala index 5312dc0b59..e74c83b90b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/FilteredRDD.scala @@ -18,8 +18,9 @@ package org.apache.spark.rdd import org.apache.spark.{OneToOneDependency, Partition, TaskContext} +import scala.reflect.ClassTag -private[spark] class FilteredRDD[T: ClassManifest]( +private[spark] class FilteredRDD[T: ClassTag]( prev: RDD[T], f: T => Boolean) extends RDD[T](prev) { diff --git a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala index cbdf6d84c0..4d1878fc14 100644 --- a/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/FlatMappedRDD.scala @@ -18,10 +18,11 @@ package org.apache.spark.rdd import org.apache.spark.{Partition, TaskContext} +import scala.reflect.ClassTag private[spark] -class FlatMappedRDD[U: ClassManifest, T: ClassManifest]( +class FlatMappedRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], f: T => TraversableOnce[U]) extends RDD[U](prev) { diff --git a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala index 829545d7b0..1a694475f6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/GlommedRDD.scala @@ -18,8 +18,9 @@ package org.apache.spark.rdd import org.apache.spark.{Partition, TaskContext} +import scala.reflect.ClassTag -private[spark] class GlommedRDD[T: ClassManifest](prev: RDD[T]) +private[spark] class GlommedRDD[T: ClassTag](prev: RDD[T]) extends RDD[Array[T]](prev) { override def getPartitions: Array[Partition] = firstParent[T].partitions diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala index aca0146884..8df8718f3b 100644 --- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala @@ -19,6 +19,8 @@ package org.apache.spark.rdd import java.sql.{Connection, ResultSet} +import scala.reflect.ClassTag + import org.apache.spark.{Logging, Partition, SparkContext, TaskContext} import org.apache.spark.util.NextIterator @@ -45,7 +47,7 @@ private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) e * This should only call getInt, getString, etc; the RDD takes care of calling next. * The default maps a ResultSet to an array of Object. */ -class JdbcRDD[T: ClassManifest]( +class JdbcRDD[T: ClassTag]( sc: SparkContext, getConnection: () => Connection, sql: String, diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 203179c4ea..db15baf503 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -18,20 +18,18 @@ package org.apache.spark.rdd import org.apache.spark.{Partition, TaskContext} +import scala.reflect.ClassTag - -private[spark] -class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( +private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag]( prev: RDD[T], - f: Iterator[T] => Iterator[U], + f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) preservesPartitioning: Boolean = false) extends RDD[U](prev) { - override val partitioner = - if (preservesPartitioning) firstParent[T].partitioner else None + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None override def getPartitions: Array[Partition] = firstParent[T].partitions override def compute(split: Partition, context: TaskContext) = - f(firstParent[T].iterator(split, context)) + f(context, split.index, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala deleted file mode 100644 index aea08ff81b..0000000000 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import org.apache.spark.{Partition, TaskContext} - - -/** - * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the - * TaskContext, the closure can either get access to the interruptible flag or get the index - * of the partition in the RDD. - */ -private[spark] -class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: (TaskContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean - ) extends RDD[U](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override val partitioner = if (preservesPartitioning) prev.partitioner else None - - override def compute(split: Partition, context: TaskContext) = - f(context, firstParent[T].iterator(split, context)) -} diff --git a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala index e8be1c4816..8d7c288593 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MappedRDD.scala @@ -17,10 +17,12 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark.{Partition, TaskContext} private[spark] -class MappedRDD[U: ClassManifest, T: ClassManifest](prev: RDD[T], f: T => U) +class MappedRDD[U: ClassTag, T: ClassTag](prev: RDD[T], f: T => U) extends RDD[U](prev) { override def getPartitions: Array[Partition] = firstParent[T].partitions diff --git a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala index 697be8b997..d5691f2267 100644 --- a/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/OrderedRDDFunctions.scala @@ -17,7 +17,9 @@ package org.apache.spark.rdd -import org.apache.spark.{RangePartitioner, Logging} +import scala.reflect.ClassTag + +import org.apache.spark.{Logging, RangePartitioner} /** * Extra functions available on RDDs of (key, value) pairs where the key is sortable through @@ -25,9 +27,9 @@ import org.apache.spark.{RangePartitioner, Logging} * use these functions. They will work with any key type that has a `scala.math.Ordered` * implementation. */ -class OrderedRDDFunctions[K <% Ordered[K]: ClassManifest, - V: ClassManifest, - P <: Product2[K, V] : ClassManifest]( +class OrderedRDDFunctions[K <% Ordered[K]: ClassTag, + V: ClassTag, + P <: Product2[K, V] : ClassTag]( self: RDD[P]) extends Logging with Serializable { diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala index 93b78e1232..48168e152e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala @@ -25,6 +25,7 @@ import java.util.{HashMap => JHashMap} import scala.collection.{mutable, Map} import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConversions._ +import scala.reflect.{ClassTag, classTag} import org.apache.hadoop.mapred._ import org.apache.hadoop.io.compress.CompressionCodec @@ -50,7 +51,7 @@ import org.apache.spark.Partitioner.defaultPartitioner * Extra functions available on RDDs of (key, value) pairs through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. */ -class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) +class PairRDDFunctions[K: ClassTag, V: ClassTag](self: RDD[(K, V)]) extends Logging with SparkHadoopMapReduceUtil with Serializable { @@ -415,7 +416,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other), partitioner) - val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) + val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag) prfs.mapValues { case Seq(vs, ws) => (vs.asInstanceOf[Seq[V]], ws.asInstanceOf[Seq[W]]) } @@ -431,7 +432,7 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) throw new SparkException("Default partitioner cannot partition array keys.") } val cg = new CoGroupedRDD[K](Seq(self, other1, other2), partitioner) - val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classManifest[K], Manifests.seqSeqManifest) + val prfs = new PairRDDFunctions[K, Seq[Seq[_]]](cg)(classTag[K], ClassTags.seqSeqClassTag) prfs.mapValues { case Seq(vs, w1s, w2s) => (vs.asInstanceOf[Seq[V]], w1s.asInstanceOf[Seq[W1]], w2s.asInstanceOf[Seq[W2]]) } @@ -488,15 +489,15 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) * Uses `this` partitioner/partition size, because even if `other` is huge, the resulting * RDD will be <= us. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)]): RDD[(K, V)] = + def subtractByKey[W: ClassTag](other: RDD[(K, W)]): RDD[(K, V)] = subtractByKey(other, self.partitioner.getOrElse(new HashPartitioner(self.partitions.size))) /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = + def subtractByKey[W: ClassTag](other: RDD[(K, W)], numPartitions: Int): RDD[(K, V)] = subtractByKey(other, new HashPartitioner(numPartitions)) /** Return an RDD with the pairs from `this` whose keys are not in `other`. */ - def subtractByKey[W: ClassManifest](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = + def subtractByKey[W: ClassTag](other: RDD[(K, W)], p: Partitioner): RDD[(K, V)] = new SubtractedRDD[K, V, W](self, other, p) /** @@ -525,8 +526,8 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) * Output the RDD to any Hadoop-supported file system, using a Hadoop `OutputFormat` class * supporting the key and value types K and V in this RDD. */ - def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) + def saveAsHadoopFile[F <: OutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } /** @@ -535,16 +536,16 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) * supplied codec. */ def saveAsHadoopFile[F <: OutputFormat[K, V]]( - path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassManifest[F]) { - saveAsHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]], codec) + path: String, codec: Class[_ <: CompressionCodec]) (implicit fm: ClassTag[F]) { + saveAsHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]], codec) } /** * Output the RDD to any Hadoop-supported file system, using a new Hadoop API `OutputFormat` * (mapreduce.OutputFormat) object supporting the key and value types K and V in this RDD. */ - def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassManifest[F]) { - saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.erasure.asInstanceOf[Class[F]]) + def saveAsNewAPIHadoopFile[F <: NewOutputFormat[K, V]](path: String)(implicit fm: ClassTag[F]) { + saveAsNewAPIHadoopFile(path, getKeyClass, getValueClass, fm.runtimeClass.asInstanceOf[Class[F]]) } /** @@ -698,11 +699,11 @@ class PairRDDFunctions[K: ClassManifest, V: ClassManifest](self: RDD[(K, V)]) */ def values: RDD[V] = self.map(_._2) - private[spark] def getKeyClass() = implicitly[ClassManifest[K]].erasure + private[spark] def getKeyClass() = implicitly[ClassTag[K]].runtimeClass - private[spark] def getValueClass() = implicitly[ClassManifest[V]].erasure + private[spark] def getValueClass() = implicitly[ClassTag[V]].runtimeClass } -private[spark] object Manifests { - val seqSeqManifest = classManifest[Seq[Seq[_]]] +private[spark] object ClassTags { + val seqSeqClassTag = classTag[Seq[Seq[_]]] } diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala index cd96250389..09d0a8189d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala @@ -20,13 +20,15 @@ package org.apache.spark.rdd import scala.collection.immutable.NumericRange import scala.collection.mutable.ArrayBuffer import scala.collection.Map +import scala.reflect.ClassTag + import org.apache.spark._ import java.io._ import scala.Serializable import org.apache.spark.serializer.JavaSerializer import org.apache.spark.util.Utils -private[spark] class ParallelCollectionPartition[T: ClassManifest]( +private[spark] class ParallelCollectionPartition[T: ClassTag]( var rddId: Long, var slice: Int, var values: Seq[T]) @@ -78,7 +80,7 @@ private[spark] class ParallelCollectionPartition[T: ClassManifest]( } } -private[spark] class ParallelCollectionRDD[T: ClassManifest]( +private[spark] class ParallelCollectionRDD[T: ClassTag]( @transient sc: SparkContext, @transient data: Seq[T], numSlices: Int, @@ -109,7 +111,7 @@ private object ParallelCollectionRDD { * collections specially, encoding the slices as other Ranges to minimize memory cost. This makes * it efficient to run Spark over RDDs representing large sets of numbers. */ - def slice[T: ClassManifest](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { + def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = { if (numSlices < 1) { throw new IllegalArgumentException("Positive number of slices required") } diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 165cd412fc..ea8885b36e 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.spark.{NarrowDependency, SparkEnv, Partition, TaskContext} @@ -33,11 +35,13 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo extends NarrowDependency[T](rdd) { @transient - val partitions: Array[Partition] = rdd.partitions.zipWithIndex - .filter(s => partitionFilterFunc(s._2)) + val partitions: Array[Partition] = rdd.partitions + .filter(s => partitionFilterFunc(s.index)).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - override def getParents(partitionId: Int) = List(partitions(partitionId).index) + override def getParents(partitionId: Int) = { + List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index) + } } @@ -47,7 +51,7 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo * and the execution DAG has a filter on the key, we can avoid launching tasks * on partitions that don't have the range covering the key. */ -class PartitionPruningRDD[T: ClassManifest]( +class PartitionPruningRDD[T: ClassTag]( @transient prev: RDD[T], @transient partitionFilterFunc: Int => Boolean) extends RDD[T](prev.context, List(new PruneDependency(prev, partitionFilterFunc))) { @@ -67,6 +71,6 @@ object PartitionPruningRDD { * when its type T is not known at compile time. */ def create[T](rdd: RDD[T], partitionFilterFunc: Int => Boolean) = { - new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassManifest) + new PartitionPruningRDD[T](rdd, partitionFilterFunc)(rdd.elementClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala index d5304ab0ae..1dbbe39898 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala @@ -24,6 +24,7 @@ import scala.collection.Map import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer import scala.io.Source +import scala.reflect.ClassTag import org.apache.spark.{SparkEnv, Partition, TaskContext} import org.apache.spark.broadcast.Broadcast @@ -33,7 +34,7 @@ import org.apache.spark.broadcast.Broadcast * An RDD that pipes the contents of each parent partition through an external command * (printing them one per line) and returns the output as a collection of strings. */ -class PipedRDD[T: ClassManifest]( +class PipedRDD[T: ClassTag]( prev: RDD[T], command: Seq[String], envVars: Map[String, String], diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6e88be6f6a..ea45566ad1 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -23,6 +23,9 @@ import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer +import scala.collection.mutable.HashMap +import scala.reflect.{classTag, ClassTag} + import org.apache.hadoop.io.BytesWritable import org.apache.hadoop.io.compress.CompressionCodec import org.apache.hadoop.io.NullWritable @@ -69,7 +72,7 @@ import org.apache.spark._ * [[http://www.cs.berkeley.edu/~matei/papers/2012/nsdi_spark.pdf Spark paper]] for more details * on RDD internals. */ -abstract class RDD[T: ClassManifest]( +abstract class RDD[T: ClassTag]( @transient private var sc: SparkContext, @transient private var deps: Seq[Dependency[_]] ) extends Serializable with Logging { @@ -101,7 +104,7 @@ abstract class RDD[T: ClassManifest]( protected def getPreferredLocations(split: Partition): Seq[String] = Nil /** Optionally overridden by subclasses to specify how they are partitioned. */ - val partitioner: Option[Partitioner] = None + @transient val partitioner: Option[Partitioner] = None // ======================================================================= // Methods and fields available on all RDDs @@ -114,7 +117,7 @@ abstract class RDD[T: ClassManifest]( val id: Int = sc.newRddId() /** A friendly name for this RDD */ - var name: String = null + @transient var name: String = null /** Assign a name to this RDD */ def setName(_name: String) = { @@ -123,7 +126,7 @@ abstract class RDD[T: ClassManifest]( } /** User-defined generator of this RDD*/ - var generator = Utils.getCallSiteInfo.firstUserClass + @transient var generator = Utils.getCallSiteInfo.firstUserClass /** Reset generator*/ def setGenerator(_generator: String) = { @@ -243,13 +246,13 @@ abstract class RDD[T: ClassManifest]( /** * Return a new RDD by applying a function to all elements of this RDD. */ - def map[U: ClassManifest](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f)) + def map[U: ClassTag](f: T => U): RDD[U] = new MappedRDD(this, sc.clean(f)) /** * Return a new RDD by first applying a function to all elements of this * RDD, and then flattening the results. */ - def flatMap[U: ClassManifest](f: T => TraversableOnce[U]): RDD[U] = + def flatMap[U: ClassTag](f: T => TraversableOnce[U]): RDD[U] = new FlatMappedRDD(this, sc.clean(f)) /** @@ -374,25 +377,25 @@ abstract class RDD[T: ClassManifest]( * Return the Cartesian product of this RDD and another one, that is, the RDD of all pairs of * elements (a, b) where a is in `this` and b is in `other`. */ - def cartesian[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) + def cartesian[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new CartesianRDD(sc, this, other) /** * Return an RDD of grouped items. */ - def groupBy[K: ClassManifest](f: T => K): RDD[(K, Seq[T])] = + def groupBy[K: ClassTag](f: T => K): RDD[(K, Seq[T])] = groupBy[K](f, defaultPartitioner(this)) /** * Return an RDD of grouped elements. Each group consists of a key and a sequence of elements * mapping to that key. */ - def groupBy[K: ClassManifest](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = + def groupBy[K: ClassTag](f: T => K, numPartitions: Int): RDD[(K, Seq[T])] = groupBy(f, new HashPartitioner(numPartitions)) /** * Return an RDD of grouped items. */ - def groupBy[K: ClassManifest](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = { + def groupBy[K: ClassTag](f: T => K, p: Partitioner): RDD[(K, Seq[T])] = { val cleanF = sc.clean(f) this.map(t => (cleanF(t), t)).groupByKey(p) } @@ -408,7 +411,6 @@ abstract class RDD[T: ClassManifest]( def pipe(command: String, env: Map[String, String]): RDD[String] = new PipedRDD(this, command, env) - /** * Return an RDD created by piping elements to a forked external process. * The print behavior can be customized by providing two functions. @@ -440,29 +442,31 @@ abstract class RDD[T: ClassManifest]( /** * Return a new RDD by applying a function to each partition of this RDD. */ - def mapPartitions[U: ClassManifest]( + def mapPartitions[U: ClassTag]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** * Return a new RDD by applying a function to each partition of this RDD, while tracking the index * of the original partition. */ - def mapPartitionsWithIndex[U: ClassManifest]( + def mapPartitionsWithIndex[U: ClassTag]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter) - new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** * Return a new RDD by applying a function to each partition of this RDD. This is a variant of * mapPartitions that also passes the TaskContext into the closure. */ - def mapPartitionsWithContext[U: ClassManifest]( + def mapPartitionsWithContext[U: ClassTag]( f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -470,7 +474,7 @@ abstract class RDD[T: ClassManifest]( * of the original partition. */ @deprecated("use mapPartitionsWithIndex", "0.7.0") - def mapPartitionsWithSplit[U: ClassManifest]( + def mapPartitionsWithSplit[U: ClassTag]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { mapPartitionsWithIndex(f, preservesPartitioning) } @@ -480,14 +484,13 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def mapWith[A: ClassManifest, U: ClassManifest] + def mapWith[A: ClassTag, U: ClassTag] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => U): RDD[U] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.map(t => f(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) + }, preservesPartitioning) } /** @@ -495,14 +498,13 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def flatMapWith[A: ClassManifest, U: ClassManifest] + def flatMapWith[A: ClassTag, U: ClassTag] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => Seq[U]): RDD[U] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) + }, preservesPartitioning) } /** @@ -510,12 +512,11 @@ abstract class RDD[T: ClassManifest]( * This additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { - val a = constructA(context.partitionId) + def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) { + mapPartitionsWithIndex { (index, iter) => + val a = constructA(index) iter.map(t => {f(t, a); t}) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {}) + }.foreach(_ => {}) } /** @@ -523,12 +524,11 @@ abstract class RDD[T: ClassManifest]( * additional parameter is produced by constructA, which is called in each * partition with the index of that partition. */ - def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { - val a = constructA(context.partitionId) + def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.filter(t => p(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true) + }, preservesPartitioning = true) } /** @@ -537,7 +537,7 @@ abstract class RDD[T: ClassManifest]( * partitions* and the *same number of elements in each partition* (e.g. one was made through * a map on the other). */ - def zip[U: ClassManifest](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) + def zip[U: ClassTag](other: RDD[U]): RDD[(T, U)] = new ZippedRDD(sc, this, other) /** * Zip this RDD's partitions with one (or more) RDD(s) and return a new RDD by @@ -545,20 +545,30 @@ abstract class RDD[T: ClassManifest]( * *same number of partitions*, but does *not* require them to have the same number * of elements in each partition. */ - def zipPartitions[B: ClassManifest, V: ClassManifest] + def zipPartitions[B: ClassTag, V: ClassTag] (rdd2: RDD[B]) (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) + new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false) + + def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag] + (rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean) + (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning) - def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest] + def zipPartitions[B: ClassTag, C: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C]) (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false) + + def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag] + (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean) + (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning) - def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest] + def zipPartitions[B: ClassTag, C: ClassTag, D: ClassTag, V: ClassTag] (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false) // Actions (launch a job to return a value to the user program) @@ -593,7 +603,7 @@ abstract class RDD[T: ClassManifest]( /** * Return an RDD that contains all matching values by applying `f`. */ - def collect[U: ClassManifest](f: PartialFunction[T, U]): RDD[U] = { + def collect[U: ClassTag](f: PartialFunction[T, U]): RDD[U] = { filter(f.isDefinedAt).map(f) } @@ -683,7 +693,7 @@ abstract class RDD[T: ClassManifest]( * allowed to modify and return their first argument instead of creating a new U to avoid memory * allocation. */ - def aggregate[U: ClassManifest](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { + def aggregate[U: ClassTag](zeroValue: U)(seqOp: (U, T) => U, combOp: (U, U) => U): U = { // Clone the zero value since we will also be serializing it as part of tasks var jobResult = Utils.clone(zeroValue, sc.env.closureSerializer.newInstance()) val cleanSeqOp = sc.clean(seqOp) @@ -732,7 +742,7 @@ abstract class RDD[T: ClassManifest]( * combine step happens locally on the master, equivalent to running a single reduce task. */ def countByValue(): Map[T, Long] = { - if (elementClassManifest.erasure.isArray) { + if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValue() does not support arrays") } // TODO: This should perhaps be distributed by default. @@ -763,7 +773,7 @@ abstract class RDD[T: ClassManifest]( timeout: Long, confidence: Double = 0.95 ): PartialResult[Map[T, BoundedDouble]] = { - if (elementClassManifest.erasure.isArray) { + if (elementClassTag.runtimeClass.isArray) { throw new SparkException("countByValueApprox() does not support arrays") } val countPartition: (TaskContext, Iterator[T]) => OLMap[T] = { (ctx, iter) => @@ -928,14 +938,14 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. */ - private[spark] val origin = Utils.formatSparkCallSite + @transient private[spark] val origin = Utils.formatSparkCallSite - private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] + private[spark] def elementClassTag: ClassTag[T] = classTag[T] private[spark] var checkpointData: Option[RDDCheckpointData[T]] = None /** Returns the first parent RDD */ - protected[spark] def firstParent[U: ClassManifest] = { + protected[spark] def firstParent[U: ClassTag] = { dependencies.head.rdd.asInstanceOf[RDD[U]] } @@ -943,7 +953,7 @@ abstract class RDD[T: ClassManifest]( def context = sc // Avoid handling doCheckpoint multiple times to prevent excessive recursion - private var doCheckpointCalled = false + @transient private var doCheckpointCalled = false /** * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler @@ -997,7 +1007,7 @@ abstract class RDD[T: ClassManifest]( origin) def toJavaRDD() : JavaRDD[T] = { - new JavaRDD(this)(elementClassManifest) + new JavaRDD(this)(elementClassTag) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala index 6009a41570..3b56e45aa9 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDDCheckpointData.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag + import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration @@ -38,7 +40,7 @@ private[spark] object CheckpointState extends Enumeration { * manages the post-checkpoint state by providing the updated partitions, iterator and preferred locations * of the checkpointed RDD. */ -private[spark] class RDDCheckpointData[T: ClassManifest](rdd: RDD[T]) +private[spark] class RDDCheckpointData[T: ClassTag](rdd: RDD[T]) extends Logging with Serializable { import CheckpointState._ diff --git a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala index 2c5253ae30..d433670cc2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SampledRDD.scala @@ -17,6 +17,7 @@ package org.apache.spark.rdd +import scala.reflect.ClassTag import java.util.Random import cern.jet.random.Poisson @@ -29,9 +30,9 @@ class SampledRDDPartition(val prev: Partition, val seed: Int) extends Partition override val index: Int = prev.index } -class SampledRDD[T: ClassManifest]( +class SampledRDD[T: ClassTag]( prev: RDD[T], - withReplacement: Boolean, + withReplacement: Boolean, frac: Double, seed: Int) extends RDD[T](prev) { diff --git a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala index 5fe4676029..2d1bd5b481 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SequenceFileRDDFunctions.scala @@ -14,9 +14,10 @@ * See the License for the specific language governing permissions and * limitations under the License. */ - package org.apache.spark.rdd +import scala.reflect.{ ClassTag, classTag} + import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapred.SequenceFileOutputFormat import org.apache.hadoop.io.compress.CompressionCodec @@ -32,15 +33,15 @@ import org.apache.spark.Logging * * Import `org.apache.spark.SparkContext._` at the top of their program to use these functions. */ -class SequenceFileRDDFunctions[K <% Writable: ClassManifest, V <% Writable : ClassManifest]( +class SequenceFileRDDFunctions[K <% Writable: ClassTag, V <% Writable : ClassTag]( self: RDD[(K, V)]) extends Logging with Serializable { - private def getWritableClass[T <% Writable: ClassManifest](): Class[_ <: Writable] = { + private def getWritableClass[T <% Writable: ClassTag](): Class[_ <: Writable] = { val c = { - if (classOf[Writable].isAssignableFrom(classManifest[T].erasure)) { - classManifest[T].erasure + if (classOf[Writable].isAssignableFrom(classTag[T].runtimeClass)) { + classTag[T].runtimeClass } else { // We get the type of the Writable class by looking at the apply method which converts // from T to Writable. Since we have two apply methods we filter out the one which diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index a5d751a7bd..3682c84598 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -17,8 +17,10 @@ package org.apache.spark.rdd -import org.apache.spark.{Dependency, Partitioner, SparkEnv, ShuffleDependency, Partition, TaskContext} +import scala.reflect.ClassTag +import org.apache.spark.{Dependency, Partition, Partitioner, ShuffleDependency, + SparkEnv, TaskContext} private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { override val index = idx @@ -32,7 +34,7 @@ private[spark] class ShuffledRDDPartition(val idx: Int) extends Partition { * @tparam K the key class. * @tparam V the value class. */ -class ShuffledRDD[K, V, P <: Product2[K, V] : ClassManifest]( +class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( @transient var prev: RDD[P], part: Partitioner) extends RDD[P](prev.context, Nil) { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index 7af4d803e7..aab30b1bb4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -18,8 +18,11 @@ package org.apache.spark.rdd import java.util.{HashMap => JHashMap} + import scala.collection.JavaConversions._ import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + import org.apache.spark.Partitioner import org.apache.spark.Dependency import org.apache.spark.TaskContext @@ -45,7 +48,7 @@ import org.apache.spark.OneToOneDependency * you can use `rdd1`'s partitioner/partition size and not worry about running * out of memory because of the size of `rdd2`. */ -private[spark] class SubtractedRDD[K: ClassManifest, V: ClassManifest, W: ClassManifest]( +private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( @transient var rdd1: RDD[_ <: Product2[K, V]], @transient var rdd2: RDD[_ <: Product2[K, W]], part: Partitioner) diff --git a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala index ae8a9f36a6..08a41ac558 100644 --- a/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/UnionRDD.scala @@ -18,10 +18,13 @@ package org.apache.spark.rdd import scala.collection.mutable.ArrayBuffer +import scala.reflect.ClassTag + import org.apache.spark.{Dependency, RangeDependency, SparkContext, Partition, TaskContext} + import java.io.{ObjectOutputStream, IOException} -private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], splitIndex: Int) +private[spark] class UnionPartition[T: ClassTag](idx: Int, rdd: RDD[T], splitIndex: Int) extends Partition { var split: Partition = rdd.partitions(splitIndex) @@ -40,7 +43,7 @@ private[spark] class UnionPartition[T: ClassManifest](idx: Int, rdd: RDD[T], spl } } -class UnionRDD[T: ClassManifest]( +class UnionRDD[T: ClassTag]( sc: SparkContext, @transient var rdds: Seq[RDD[T]]) extends RDD[T](sc, Nil) { // Nil since we implement getDependencies diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 31e6fd519d..83be3c6eb4 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -19,10 +19,12 @@ package org.apache.spark.rdd import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext} import java.io.{ObjectOutputStream, IOException} +import scala.reflect.ClassTag private[spark] class ZippedPartitionsPartition( idx: Int, - @transient rdds: Seq[RDD[_]]) + @transient rdds: Seq[RDD[_]], + @transient val preferredLocations: Seq[String]) extends Partition { override val index: Int = idx @@ -37,33 +39,31 @@ private[spark] class ZippedPartitionsPartition( } } -abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( +abstract class ZippedPartitionsBaseRDD[V: ClassTag]( sc: SparkContext, - var rdds: Seq[RDD[_]]) + var rdds: Seq[RDD[_]], + preservesPartitioning: Boolean = false) extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { + override val partitioner = + if (preservesPartitioning) firstParent[Any].partitioner else None + override def getPartitions: Array[Partition] = { - val sizes = rdds.map(x => x.partitions.size) - if (!sizes.forall(x => x == sizes(0))) { + val numParts = rdds.head.partitions.size + if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } - val array = new Array[Partition](sizes(0)) - for (i <- 0 until sizes(0)) { - array(i) = new ZippedPartitionsPartition(i, rdds) + Array.tabulate[Partition](numParts) { i => + val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i))) + // Check whether there are any hosts that match all RDDs; otherwise return the union + val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) + val locs = if (!exactMatchLocations.isEmpty) exactMatchLocations else prefs.flatten.distinct + new ZippedPartitionsPartition(i, rdds, locs) } - array } override def getPreferredLocations(s: Partition): Seq[String] = { - val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions - val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) } - // Check whether there are any hosts that match all RDDs; otherwise return the union - val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) - if (!exactMatchLocations.isEmpty) { - exactMatchLocations - } else { - prefs.flatten.distinct - } + s.asInstanceOf[ZippedPartitionsPartition].preferredLocations } override def clearDependencies() { @@ -72,12 +72,13 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( } } -class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]( +class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]( sc: SparkContext, f: (Iterator[A], Iterator[B]) => Iterator[V], var rdd1: RDD[A], - var rdd2: RDD[B]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { + var rdd2: RDD[B], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions @@ -92,13 +93,14 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest] } class ZippedPartitionsRDD3 - [A: ClassManifest, B: ClassManifest, C: ClassManifest, V: ClassManifest]( + [A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag]( sc: SparkContext, f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], var rdd1: RDD[A], var rdd2: RDD[B], - var rdd3: RDD[C]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { + var rdd3: RDD[C], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions @@ -116,14 +118,15 @@ class ZippedPartitionsRDD3 } class ZippedPartitionsRDD4 - [A: ClassManifest, B: ClassManifest, C: ClassManifest, D:ClassManifest, V: ClassManifest]( + [A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag]( sc: SparkContext, f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V], var rdd1: RDD[A], var rdd2: RDD[B], var rdd3: RDD[C], - var rdd4: RDD[D]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { + var rdd4: RDD[D], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala index 567b67dfee..fb5b070c18 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedRDD.scala @@ -18,10 +18,12 @@ package org.apache.spark.rdd import org.apache.spark.{OneToOneDependency, SparkContext, Partition, TaskContext} + import java.io.{ObjectOutputStream, IOException} +import scala.reflect.ClassTag -private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest]( +private[spark] class ZippedPartition[T: ClassTag, U: ClassTag]( idx: Int, @transient rdd1: RDD[T], @transient rdd2: RDD[U] @@ -42,7 +44,7 @@ private[spark] class ZippedPartition[T: ClassManifest, U: ClassManifest]( } } -class ZippedRDD[T: ClassManifest, U: ClassManifest]( +class ZippedRDD[T: ClassTag, U: ClassTag]( sc: SparkContext, var rdd1: RDD[T], var rdd2: RDD[U]) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 42bb3884c8..963d15b76d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -21,9 +21,11 @@ import java.io.NotSerializableException import java.util.Properties import java.util.concurrent.atomic.AtomicInteger -import akka.actor._ -import akka.util.duration._ import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, Map} +import scala.concurrent.duration._ +import scala.reflect.ClassTag + +import akka.actor._ import org.apache.spark._ import org.apache.spark.rdd.RDD @@ -104,36 +106,16 @@ class DAGScheduler( // The time, in millis, to wait for fetch failure events to stop coming in after one is detected; // this is a simplistic way to avoid resubmitting tasks in the non-fetchable map stage one by one // as more failure events come in - val RESUBMIT_TIMEOUT = 50L + val RESUBMIT_TIMEOUT = 50.milliseconds // The time, in millis, to wake up between polls of the completion queue in order to potentially // resubmit failed stages val POLL_TIMEOUT = 10L - private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor { - override def preStart() { - context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) { - if (failed.size > 0) { - resubmitFailedStages() - } - } - } - - /** - * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure - * events and responds by launching tasks. This runs in a dedicated thread and receives events - * via the eventQueue. - */ - def receive = { - case event: DAGSchedulerEvent => - logDebug("Got event of type " + event.getClass.getName) + // Warns the user if a stage contains a task with size greater than this value (in KB) + val TASK_SIZE_TO_WARN = 100 - if (!processEvent(event)) - submitWaitingStages() - else - context.stop(self) - } - })) + private var eventProcessActor: ActorRef = _ private[scheduler] val nextJobId = new AtomicInteger(0) @@ -141,9 +123,13 @@ class DAGScheduler( private val nextStageId = new AtomicInteger(0) - private val stageIdToStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]] - private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]] + + private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage] + + private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] @@ -174,6 +160,57 @@ class DAGScheduler( val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup) + /** + * Starts the event processing actor. The actor has two responsibilities: + * + * 1. Waits for events like job submission, task finished, task failure etc., and calls + * [[org.apache.spark.scheduler.DAGScheduler.processEvent()]] to process them. + * 2. Schedules a periodical task to resubmit failed stages. + * + * NOTE: the actor cannot be started in the constructor, because the periodical task references + * some internal states of the enclosing [[org.apache.spark.scheduler.DAGScheduler]] object, thus + * cannot be scheduled until the [[org.apache.spark.scheduler.DAGScheduler]] is fully constructed. + */ + def start() { + eventProcessActor = env.actorSystem.actorOf(Props(new Actor { + /** + * A handle to the periodical task, used to cancel the task when the actor is stopped. + */ + var resubmissionTask: Cancellable = _ + + override def preStart() { + import context.dispatcher + /** + * A message is sent to the actor itself periodically to remind the actor to resubmit failed + * stages. In this way, stage resubmission can be done within the same thread context of + * other event processing logic to avoid unnecessary synchronization overhead. + */ + resubmissionTask = context.system.scheduler.schedule( + RESUBMIT_TIMEOUT, RESUBMIT_TIMEOUT, self, ResubmitFailedStages) + } + + /** + * The main event loop of the DAG scheduler. + */ + def receive = { + case event: DAGSchedulerEvent => + logDebug("Got event of type " + event.getClass.getName) + + /** + * All events are forwarded to `processEvent()`, so that the event processing logic can + * easily tested without starting a dedicated actor. Please refer to `DAGSchedulerSuite` + * for details. + */ + if (!processEvent(event)) { + submitWaitingStages() + } else { + resubmissionTask.cancel() + context.stop(self) + } + } + })) + } + def addSparkListener(listener: SparkListener) { listenerBus.addListener(listener) } @@ -202,16 +239,16 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId) + val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } } /** - * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or - * as a result stage for the final RDD used directly in an action. The stage will also be - * associated with the provided jobId. + * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation + * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided + * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly. */ private def newStage( rdd: RDD[_], @@ -221,21 +258,45 @@ class DAGScheduler( callSite: Option[String] = None) : Stage = { - if (shuffleDep != None) { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of partitions is unknown - logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) - } val id = nextStageId.getAndIncrement() val stage = new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) stageIdToStage(id) = stage + updateJobIdStageIdMaps(jobId, stage) stageToInfos(stage) = new StageInfo(stage) stage } /** + * Create a shuffle map Stage for the given RDD. The stage will also be associated with the + * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is + * present in the MapOutputTracker, then the number and location of available outputs are + * recovered from the MapOutputTracker + */ + private def newOrUsedStage( + rdd: RDD[_], + numTasks: Int, + shuffleDep: ShuffleDependency[_,_], + jobId: Int, + callSite: Option[String] = None) + : Stage = + { + val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) + if (mapOutputTracker.has(shuffleDep.shuffleId)) { + val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) + val locs = MapOutputTracker.deserializeMapStatuses(serLocs) + for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i)) + stage.numAvailableOutputs = locs.size + } else { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of partitions is unknown + logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) + } + stage + } + + /** * Get or create the list of parent stages for a given RDD. The stages will be assigned the * provided jobId if they haven't already been created with a lower jobId. */ @@ -287,6 +348,89 @@ class DAGScheduler( } /** + * Registers the given jobId among the jobs that need the given stage and + * all of that stage's ancestors. + */ + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) { + def updateJobIdStageIdMapsList(stages: List[Stage]) { + if (!stages.isEmpty) { + val s = stages.head + stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId + jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id + val parents = getParentStages(s.rdd, jobId) + val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) + updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) + } + } + updateJobIdStageIdMapsList(List(stage)) + } + + /** + * Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that + * were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation. + */ + private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { + val registeredStages = jobIdToStageIds(jobId) + val independentStages = new HashSet[Int]() + if (registeredStages.isEmpty) { + logError("No stages registered for job " + jobId) + } else { + stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach { + case (stageId, jobSet) => + if (!jobSet.contains(jobId)) { + logError("Job %d not registered for stage %d even though that stage was registered for the job" + .format(jobId, stageId)) + } else { + def removeStage(stageId: Int) { + // data structures based on Stage + stageIdToStage.get(stageId).foreach { s => + if (running.contains(s)) { + logDebug("Removing running stage %d".format(stageId)) + running -= s + } + stageToInfos -= s + shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove) + if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) { + logDebug("Removing pending status for stage %d".format(stageId)) + } + pendingTasks -= s + if (waiting.contains(s)) { + logDebug("Removing stage %d from waiting set.".format(stageId)) + waiting -= s + } + if (failed.contains(s)) { + logDebug("Removing stage %d from failed set.".format(stageId)) + failed -= s + } + } + // data structures based on StageId + stageIdToStage -= stageId + stageIdToJobIds -= stageId + + logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size)) + } + + jobSet -= jobId + if (jobSet.isEmpty) { // no other job needs this stage + independentStages += stageId + removeStage(stageId) + } + } + } + } + independentStages.toSet + } + + private def jobIdToStageIdsRemove(jobId: Int) { + if (!jobIdToStageIds.contains(jobId)) { + logDebug("Trying to remove unregistered job " + jobId) + } else { + removeJobAndIndependentStages(jobId) + jobIdToStageIds -= jobId + } + } + + /** * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object * can be used to block until the the job finishes executing or can be used to cancel the job. */ @@ -319,7 +463,7 @@ class DAGScheduler( waiter } - def runJob[T, U: ClassManifest]( + def runJob[T, U: ClassTag]( rdd: RDD[T], func: (TaskContext, Iterator[T]) => U, partitions: Seq[Int], @@ -375,13 +519,25 @@ class DAGScheduler( } /** - * Process one event retrieved from the event queue. - * Returns true if we should stop the event loop. + * Process one event retrieved from the event processing actor. + * + * @param event The event to be processed. + * @return `true` if we should stop the event loop. */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) + var finalStage: Stage = null + try { + // New stage creation at times and if its not protected, the scheduler thread is killed. + // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted + finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId, e) + listener.jobFailed(e) + return false + } val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + @@ -391,37 +547,31 @@ class DAGScheduler( logInfo("Missing parents: " + getMissingParentStages(finalStage)) if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { // Compute very short actions like first() or take() with no parent stages locally. + listenerBus.post(SparkListenerJobStart(job, Array(), properties)) runLocally(job) } else { - listenerBus.post(SparkListenerJobStart(job, properties)) idToActiveJob(jobId) = job activeJobs += job resultStageToJob(finalStage) = job + listenerBus.post(SparkListenerJobStart(job, jobIdToStageIds(jobId).toArray, properties)) submitStage(finalStage) } case JobCancelled(jobId) => - // Cancel a job: find all the running stages that are linked to this job, and cancel them. - running.filter(_.jobId == jobId).foreach { stage => - taskSched.cancelTasks(stage.id) - } + handleJobCancellation(jobId) case JobGroupCancelled(groupId) => // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. - val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) - .map(_.jobId) - if (!jobIds.isEmpty) { - running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage => - taskSched.cancelTasks(stage.id) - } - } + val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + val jobIds = activeInGroup.map(_.jobId) + jobIds.foreach { handleJobCancellation } case AllJobsCancelled => // Cancel all running jobs. - running.foreach { stage => - taskSched.cancelTasks(stage.id) - } + running.map(_.jobId).foreach { handleJobCancellation } + activeJobs.clear() // These should already be empty by this point, + idToActiveJob.clear() // but just in case we lost track of some jobs... case ExecutorGained(execId, host) => handleExecutorGained(execId, host) @@ -430,6 +580,18 @@ class DAGScheduler( handleExecutorLost(execId) case BeginEvent(task, taskInfo) => + for ( + job <- idToActiveJob.get(task.stageId); + stage <- stageIdToStage.get(task.stageId); + stageInfo <- stageToInfos.get(stage) + ) { + if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) { + stageInfo.emittedTaskSizeWarning = true + logWarning(("Stage %d (%s) contains a task of very large " + + "size (%d KB). The maximum recommended task size is %d KB.").format( + task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN)) + } + } listenerBus.post(SparkListenerTaskStart(task, taskInfo)) case GettingResultEvent(task, taskInfo) => @@ -440,7 +602,12 @@ class DAGScheduler( handleTaskCompletion(completion) case TaskSetFailed(taskSet, reason) => - abortStage(stageIdToStage(taskSet.stageId), reason) + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) } + + case ResubmitFailedStages => + if (failed.size > 0) { + resubmitFailedStages() + } case StopDAGScheduler => // Cancel any active jobs @@ -502,6 +669,7 @@ class DAGScheduler( // Broken out for easier testing in DAGSchedulerSuite. protected def runLocallyWithinThread(job: ActiveJob) { + var jobResult: JobResult = JobSucceeded try { SparkEnv.set(env) val rdd = job.finalStage.rdd @@ -516,31 +684,59 @@ class DAGScheduler( } } catch { case e: Exception => + jobResult = JobFailed(e, Some(job.finalStage)) job.listener.jobFailed(e) + } finally { + val s = job.finalStage + stageIdToJobIds -= s.id // clean up data structures that were populated for a local job, + stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through + stageToInfos -= s // completion events or stage abort + jobIdToStageIds -= job.jobId + listenerBus.post(SparkListenerJobEnd(job, jobResult)) + } + } + + /** Finds the earliest-created active job that needs the stage */ + // TODO: Probably should actually find among the active jobs that need this + // stage the one with the highest priority (highest-priority pool, earliest created). + // That should take care of at least part of the priority inversion problem with + // cross-job dependencies. + private def activeJobForStage(stage: Stage): Option[Int] = { + if (stageIdToJobIds.contains(stage.id)) { + val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted + jobsThatUseStage.find(idToActiveJob.contains(_)) + } else { + None } } /** Submits stage, but first recursively submits any missing parents. */ private def submitStage(stage: Stage) { - logDebug("submitStage(" + stage + ")") - if (!waiting(stage) && !running(stage) && !failed(stage)) { - val missing = getMissingParentStages(stage).sortBy(_.id) - logDebug("missing: " + missing) - if (missing == Nil) { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") - submitMissingTasks(stage) - running += stage - } else { - for (parent <- missing) { - submitStage(parent) + val jobId = activeJobForStage(stage) + if (jobId.isDefined) { + logDebug("submitStage(" + stage + ")") + if (!waiting(stage) && !running(stage) && !failed(stage)) { + val missing = getMissingParentStages(stage).sortBy(_.id) + logDebug("missing: " + missing) + if (missing == Nil) { + logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") + submitMissingTasks(stage, jobId.get) + running += stage + } else { + for (parent <- missing) { + submitStage(parent) + } + waiting += stage } - waiting += stage } + } else { + abortStage(stage, "No active job for stage " + stage.id) } } + /** Called when stage's parents are available and we can now do its task. */ - private def submitMissingTasks(stage: Stage) { + private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) @@ -561,7 +757,7 @@ class DAGScheduler( } } - val properties = if (idToActiveJob.contains(stage.jobId)) { + val properties = if (idToActiveJob.contains(jobId)) { idToActiveJob(stage.jobId).properties } else { //this stage will be assigned to "default" pool @@ -643,6 +839,7 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) + jobIdToStageIdsRemove(job.jobId) listenerBus.post(SparkListenerJobEnd(job, JobSucceeded)) } job.listener.taskSucceeded(rt.outputId, event.result) @@ -679,7 +876,7 @@ class DAGScheduler( changeEpoch = true) } clearCacheLocs() - if (stage.outputLocs.count(_ == Nil) != 0) { + if (stage.outputLocs.exists(_ == Nil)) { // Some tasks had failed; let's resubmit this stage // TODO: Lower-level scheduler should also deal with this logInfo("Resubmitting " + stage + " (" + stage.name + @@ -696,9 +893,12 @@ class DAGScheduler( } waiting --= newlyRunnable running ++= newlyRunnable - for (stage <- newlyRunnable.sortBy(_.id)) { + for { + stage <- newlyRunnable.sortBy(_.id) + jobId <- activeJobForStage(stage) + } { logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") - submitMissingTasks(stage) + submitMissingTasks(stage, jobId) } } } @@ -782,21 +982,42 @@ class DAGScheduler( } } + private def handleJobCancellation(jobId: Int) { + if (!jobIdToStageIds.contains(jobId)) { + logDebug("Trying to cancel unregistered job " + jobId) + } else { + val independentStages = removeJobAndIndependentStages(jobId) + independentStages.foreach { taskSched.cancelTasks } + val error = new SparkException("Job %d cancelled".format(jobId)) + val job = idToActiveJob(jobId) + job.listener.jobFailed(error) + jobIdToStageIds -= jobId + activeJobs -= job + idToActiveJob -= jobId + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage)))) + } + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ private def abortStage(failedStage: Stage, reason: String) { + if (!stageIdToStage.contains(failedStage.id)) { + // Skip all the actions if the stage has been removed. + return + } val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { val job = resultStageToJob(resultStage) val error = new SparkException("Job aborted: " + reason) job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) + jobIdToStageIdsRemove(job.jobId) idToActiveJob -= resultStage.jobId activeJobs -= job resultStageToJob -= resultStage + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") @@ -867,25 +1088,24 @@ class DAGScheduler( } private def cleanup(cleanupTime: Long) { - var sizeBefore = stageIdToStage.size - stageIdToStage.clearOldValues(cleanupTime) - logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size) - - sizeBefore = shuffleToMapStage.size - shuffleToMapStage.clearOldValues(cleanupTime) - logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) - - sizeBefore = pendingTasks.size - pendingTasks.clearOldValues(cleanupTime) - logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) - - sizeBefore = stageToInfos.size - stageToInfos.clearOldValues(cleanupTime) - logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size) + Map( + "stageIdToStage" -> stageIdToStage, + "shuffleToMapStage" -> shuffleToMapStage, + "pendingTasks" -> pendingTasks, + "stageToInfos" -> stageToInfos, + "jobIdToStageIds" -> jobIdToStageIds, + "stageIdToJobIds" -> stageIdToJobIds). + foreach { case(s, t) => { + val sizeBefore = t.size + t.clearOldValues(cleanupTime) + logInfo("%s %d --> %d".format(s, sizeBefore, t.size)) + }} } def stop() { - eventProcessActor ! StopDAGScheduler + if (eventProcessActor != null) { + eventProcessActor ! StopDAGScheduler + } metadataCleaner.cancel() taskSched.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 708d221d60..add1187613 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -65,12 +65,13 @@ private[scheduler] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent -private[scheduler] -case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent +private[scheduler] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent + private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 58f238d8cf..b026f860a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -31,6 +31,7 @@ private[spark] class JobWaiter[T]( private var finishedTasks = 0 // Is the job as a whole finished (succeeded or failed)? + @volatile private var _jobFinished = totalTasks == 0 def jobFinished = _jobFinished diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala index 0a786deb16..3832ee7ff6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulingMode.scala @@ -22,7 +22,7 @@ package org.apache.spark.scheduler * to order tasks amongst a Schedulable's sub-queues * "NONE" is used when the a Schedulable has no sub-queues. */ -object SchedulingMode extends Enumeration("FAIR", "FIFO", "NONE") { +object SchedulingMode extends Enumeration { type SchedulingMode = Value val FAIR,FIFO,NONE = Value diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 1dc71a0428..0f2deb4bcb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -167,6 +167,7 @@ private[spark] class ShuffleMapTask( var totalTime = 0L val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => writer.commit() + writer.close() val size = writer.fileSegment().length totalBytes += size totalTime += writer.timeWriting() @@ -184,14 +185,16 @@ private[spark] class ShuffleMapTask( } catch { case e: Exception => // If there is an exception from running the task, revert the partial writes // and throw the exception upstream to Spark. - if (shuffle != null) { - shuffle.writers.foreach(_.revertPartialWrites()) + if (shuffle != null && shuffle.writers != null) { + for (writer <- shuffle.writers) { + writer.revertPartialWrites() + writer.close() + } } throw e } finally { // Release the writers back to the shuffle block manager. if (shuffle != null && shuffle.writers != null) { - shuffle.writers.foreach(_.close()) shuffle.releaseWriters(success) } // Execute the callbacks on task completion. diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index a35081f7b1..ee63b3c4a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -37,7 +37,7 @@ case class SparkListenerTaskGettingResult( case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, taskMetrics: TaskMetrics) extends SparkListenerEvents -case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) +case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], properties: Properties = null) extends SparkListenerEvents case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) @@ -63,7 +63,7 @@ trait SparkListener { * Called when a task begins remotely fetching its result (will not be called for tasks that do * not need to fetch the result remotely). */ - def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } + def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) { } /** * Called when a task ends @@ -131,8 +131,8 @@ object StatsReportListener extends Logging { def showDistribution(heading: String, d: Distribution, formatNumber: Double => String) { val stats = d.statCounter - logInfo(heading + stats) val quantiles = d.getQuantiles(probabilities).map{formatNumber} + logInfo(heading + stats) logInfo(percentilesHeader) logInfo("\t" + quantiles.mkString("\t")) } @@ -173,8 +173,6 @@ object StatsReportListener extends Logging { showMillisDistribution(heading, extractLongDistribution(stage, getMetric)) } - - val seconds = 1000L val minutes = seconds * 60 val hours = minutes * 60 @@ -198,7 +196,6 @@ object StatsReportListener extends Logging { } - case class RuntimePercentage(executorPct: Double, fetchPct: Option[Double], other: Double) object RuntimePercentage { def apply(totalTime: Long, metrics: TaskMetrics): RuntimePercentage = { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala index d5824e7954..85687ea330 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListenerBus.scala @@ -91,4 +91,3 @@ private[spark] class SparkListenerBus() extends Logging { return true
}
}
-
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 93599dfdc8..e9f2198a00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -33,4 +33,5 @@ class StageInfo( val name = stage.name val numPartitions = stage.numPartitions val numTasks = stage.numTasks + var emittedTaskSizeWarning = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 4bae26f3a6..3c22edd524 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -46,6 +46,8 @@ class TaskInfo( var failed = false + var serializedSize: Int = 0 + def markGettingResult(time: Long = System.currentTimeMillis) { gettingResultTime = time } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala index 47b0f387aa..35de13c385 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskLocality.scala @@ -18,9 +18,7 @@ package org.apache.spark.scheduler -private[spark] object TaskLocality - extends Enumeration("PROCESS_LOCAL", "NODE_LOCAL", "RACK_LOCAL", "ANY") -{ +private[spark] object TaskLocality extends Enumeration { // process local is expected to be used ONLY within tasksetmanager for now. val PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL, ANY = Value diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index c1e65a3c48..66ab8ea4cd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -24,8 +24,7 @@ import java.util.{TimerTask, Timer} import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap import scala.collection.mutable.HashSet - -import akka.util.duration._ +import scala.concurrent.duration._ import org.apache.spark._ import org.apache.spark.TaskState.TaskState @@ -100,8 +99,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) this.dagScheduler = dagScheduler } - def initialize(context: SchedulerBackend) { - backend = context + def initialize(backend: SchedulerBackend) { + this.backend = backend // temporarily set rootPool name to empty rootPool = new Pool("", schedulingMode, 0, 0) schedulableBuilder = { @@ -122,7 +121,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) if (System.getProperty("spark.speculation", "false").toBoolean) { logInfo("Starting speculative execution thread") - + import sc.env.actorSystem.dispatcher sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, SPECULATION_INTERVAL milliseconds) { checkSpeculatableTasks() @@ -173,7 +172,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.killTask(tid, execId) } } - tsm.error("Stage %d was cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) + tsm.removeAllRunningTasks() + taskSetFinished(tsm) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 4c5eca8537..bf494aa64d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -377,6 +377,7 @@ private[spark] class ClusterTaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) + info.serializedSize = serializedTask.limit if (taskAttempts(index).size == 1) taskStarted(task,info) return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) @@ -528,10 +529,10 @@ private[spark] class ClusterTaskSetManager( addPendingTask(index) if (state != TaskState.KILLED) { numFailures(index) += 1 - if (numFailures(index) > MAX_TASK_FAILURES) { - logError("Task %s:%d failed more than %d times; aborting job".format( + if (numFailures(index) >= MAX_TASK_FAILURES) { + logError("Task %s:%d failed %d times; aborting job".format( taskSet.id, index, MAX_TASK_FAILURES)) - abort("Task %s:%d failed more than %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) + abort("Task %s:%d failed %d times".format(taskSet.id, index, MAX_TASK_FAILURES)) } } } else { @@ -573,7 +574,7 @@ private[spark] class ClusterTaskSetManager( runningTasks = runningTasksSet.size } - private def removeAllRunningTasks() { + private[cluster] def removeAllRunningTasks() { val numRunningTasks = runningTasksSet.size runningTasksSet.clear() if (parent != null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index a45bee536c..f5e8766f6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -20,13 +20,12 @@ package org.apache.spark.scheduler.cluster import java.util.concurrent.atomic.AtomicInteger import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet} +import scala.concurrent.Await +import scala.concurrent.duration._ import akka.actor._ -import akka.dispatch.Await import akka.pattern.ask -import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} -import akka.util.Duration -import akka.util.duration._ +import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import org.apache.spark.{SparkException, Logging, TaskState} import org.apache.spark.scheduler.TaskDescription @@ -53,15 +52,15 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac private val executorAddress = new HashMap[String, Address] private val executorHost = new HashMap[String, String] private val freeCores = new HashMap[String, Int] - private val actorToExecutorId = new HashMap[ActorRef, String] private val addressToExecutorId = new HashMap[Address, String] override def preStart() { // Listen for remote client disconnection events, since they don't go through Akka's watch() - context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) // Periodically revive offers to allow delay scheduling to work val reviveInterval = System.getProperty("spark.scheduler.revive.interval", "1000").toLong + import context.dispatcher context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } @@ -73,12 +72,10 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac } else { logInfo("Registered executor: " + sender + " with ID " + executorId) sender ! RegisteredExecutor(sparkProperties) - context.watch(sender) executorActor(executorId) = sender executorHost(executorId) = Utils.parseHostPort(hostPort)._1 freeCores(executorId) = cores executorAddress(executorId) = sender.path.address - actorToExecutorId(sender) = executorId addressToExecutorId(sender.path.address) = executorId totalCoreCount.addAndGet(cores) makeOffers() @@ -118,14 +115,9 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac removeExecutor(executorId, reason) sender ! true - case Terminated(actor) => - actorToExecutorId.get(actor).foreach(removeExecutor(_, "Akka actor terminated")) + case DisassociatedEvent(_, address, _) => + addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) - case RemoteClientDisconnected(transport, address) => - addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disconnected")) - - case RemoteClientShutdown(transport, address) => - addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client shutdown")) } // Make fake resource offers on all executors @@ -153,7 +145,6 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac if (executorActor.contains(executorId)) { logInfo("Executor " + executorId + " disconnected, so removing it") val numCores = freeCores(executorId) - actorToExecutorId -= executorActor(executorId) addressToExecutorId -= executorAddress(executorId) executorActor -= executorId executorHost -= executorId @@ -199,6 +190,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac } override def stop() { + stopExecutors() try { if (driverActor != null) { val future = driverActor.ask(StopDriver)(timeout) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 0ea35e2b7a..e8fecec4a6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -36,7 +36,7 @@ private[spark] class SimrSchedulerBackend( override def start() { super.start() - val driverUrl = "akka://spark@%s:%s/user/%s".format( + val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) @@ -62,7 +62,6 @@ private[spark] class SimrSchedulerBackend( val conf = new Configuration() val fs = FileSystem.get(conf) fs.delete(new Path(driverFilePath), false) - super.stopExecutors() super.stop() } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index cefa970bb9..7127a72d6d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -42,7 +42,7 @@ private[spark] class SparkDeploySchedulerBackend( super.start() // The endpoint for executors to talk to us - val driverUrl = "akka://spark@%s:%s/user/%s".format( + val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) val args = Seq(driverUrl, "{{EXECUTOR_ID}}", "{{HOSTNAME}}", "{{CORES}}") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala index 2064d97b49..e68c527713 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala @@ -71,7 +71,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche case cnf: ClassNotFoundException => val loader = Thread.currentThread.getContextClassLoader taskSetManager.abort("ClassNotFound with classloader: " + loader) - case ex => + case ex: Throwable => taskSetManager.abort("Exception while deserializing and fetching task: %s".format(ex)) } } @@ -95,7 +95,7 @@ private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterSche val loader = Thread.currentThread.getContextClassLoader logError( "Could not deserialize TaskEndReason: ClassNotFound with classloader " + loader) - case ex => {} + case ex: Throwable => {} } scheduler.handleFailedTask(taskSetManager, tid, taskState, reason) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index cd521e0f2b..84fe3094cc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -120,7 +120,7 @@ private[spark] class CoarseMesosSchedulerBackend( } val command = CommandInfo.newBuilder() .setEnvironment(environment) - val driverUrl = "akka://spark@%s:%s/user/%s".format( + val driverUrl = "akka.tcp://spark@%s:%s/user/%s".format( System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 2699f0b33e..01e95162c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -74,7 +74,7 @@ class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) } } -private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) +private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler with ExecutorBackend with Logging { @@ -144,7 +144,8 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: localActor ! KillTask(tid) } } - tsm.error("Stage %d was cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) + taskSetFinished(tsm) } } @@ -192,17 +193,19 @@ private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: synchronized { taskIdToTaskSetId.get(taskId) match { case Some(taskSetId) => - val taskSetManager = activeTaskSets(taskSetId) - taskSetTaskIds(taskSetId) -= taskId - - state match { - case TaskState.FINISHED => - taskSetManager.taskEnded(taskId, state, serializedData) - case TaskState.FAILED => - taskSetManager.taskFailed(taskId, state, serializedData) - case TaskState.KILLED => - taskSetManager.error("Task %d was killed".format(taskId)) - case _ => {} + val taskSetManager = activeTaskSets.get(taskSetId) + taskSetManager.foreach { tsm => + taskSetTaskIds(taskSetId) -= taskId + + state match { + case TaskState.FINISHED => + tsm.taskEnded(taskId, state, serializedData) + case TaskState.FAILED => + tsm.taskFailed(taskId, state, serializedData) + case TaskState.KILLED => + tsm.error("Task %d was killed".format(taskId)) + case _ => {} + } } case None => logInfo("Ignoring update from TID " + taskId + " because its task set is gone") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index 702aca8323..19a025a329 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -24,9 +24,9 @@ import scala.collection.mutable.{HashMap, ArrayBuffer} import scala.util.Random import akka.actor.{ActorSystem, Cancellable, Props} -import akka.dispatch.{Await, Future} -import akka.util.Duration -import akka.util.duration._ +import scala.concurrent.{Await, Future} +import scala.concurrent.duration.Duration +import scala.concurrent.duration._ import it.unimi.dsi.fastutil.io.{FastBufferedOutputStream, FastByteArrayOutputStream} @@ -924,4 +924,3 @@ private[spark] object BlockManager extends Logging { blockIdsToBlockManagers(blockIds, env, blockManagerMaster).mapValues(s => s.map(_.host)) } } - diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index 94038649b3..e05b842476 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -17,16 +17,17 @@ package org.apache.spark.storage -import akka.actor.ActorRef -import akka.dispatch.{Await, Future} +import scala.concurrent.{Await, Future} +import scala.concurrent.duration._ +import scala.concurrent.ExecutionContext.Implicits.global + +import akka.actor._ import akka.pattern.ask -import akka.util.Duration import org.apache.spark.{Logging, SparkException} import org.apache.spark.storage.BlockManagerMessages._ - -private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Logging { +private[spark] class BlockManagerMaster(var driverActor : Either[ActorRef, ActorSelection]) extends Logging { val AKKA_RETRY_ATTEMPTS: Int = System.getProperty("spark.akka.num.retries", "3").toInt val AKKA_RETRY_INTERVAL_MS: Int = System.getProperty("spark.akka.retry.wait", "3000").toInt @@ -156,7 +157,10 @@ private[spark] class BlockManagerMaster(var driverActor: ActorRef) extends Loggi while (attempts < AKKA_RETRY_ATTEMPTS) { attempts += 1 try { - val future = driverActor.ask(message)(timeout) + val future = driverActor match { + case Left(a: ActorRef) => a.ask(message)(timeout) + case Right(b: ActorSelection) => b.ask(message)(timeout) + } val result = Await.result(future, timeout) if (result == null) { throw new SparkException("BlockManagerMaster returned null") diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index f8cf14b503..154a3980e9 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -23,10 +23,10 @@ import scala.collection.mutable import scala.collection.JavaConversions._ import akka.actor.{Actor, ActorRef, Cancellable} -import akka.dispatch.Future import akka.pattern.ask -import akka.util.Duration -import akka.util.duration._ + +import scala.concurrent.duration._ +import scala.concurrent.Future import org.apache.spark.{Logging, SparkException} import org.apache.spark.storage.BlockManagerMessages._ @@ -65,6 +65,7 @@ class BlockManagerMasterActor(val isLocal: Boolean) extends Actor with Logging { override def preStart() { if (!BlockManager.getDisableHeartBeatsForTesting) { + import context.dispatcher timeoutCheckingTask = context.system.scheduler.schedule( 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) } diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 469e68fed7..b4451fc7b8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -93,6 +93,8 @@ class DiskBlockObjectWriter( def write(i: Int): Unit = callWithTiming(out.write(i)) override def write(b: Array[Byte]) = callWithTiming(out.write(b)) override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) + override def close() = out.close() + override def flush() = out.flush() } private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 2f1b049ce4..e828e1d1c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -62,7 +62,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. // TODO: Remove this once the shuffle file consolidation feature is stable. val consolidateShuffleFiles = - System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean + System.getProperty("spark.shuffle.consolidateFiles", "false").toBoolean private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 632ff047d1..b5596dffd3 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -101,7 +101,7 @@ class StorageLevel private( var result = "" result += (if (useDisk) "Disk " else "") result += (if (useMemory) "Memory " else "") - result += (if (deserialized) "Deserialized " else "Serialized") + result += (if (deserialized) "Deserialized " else "Serialized ") result += "%sx Replicated".format(replication) result } diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala index 1e4db4f66b..d52b3d8284 100644 --- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala +++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.storage import java.util.concurrent.atomic.AtomicLong diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index 860e680576..a8db37ded1 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -93,7 +93,7 @@ private[spark] object ThreadingTest { val actorSystem = ActorSystem("test") val serializer = new KryoSerializer val blockManagerMaster = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) + Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))) val blockManager = new BlockManager( "<driver>", actorSystem, blockManagerMaster, serializer, 1024 * 1024) val producers = (1 to numProducers).map(i => new ProducerThread(blockManager, i)) diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index 42e9be6e19..e596690bc3 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -76,7 +76,7 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { </tr> } - val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b) + val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId) val execTable = UIUtils.listingTable(execHead, execRow, execInfo) val content = @@ -99,16 +99,17 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors) } - def getExecInfo(a: Int): Seq[String] = { - val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId - val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort - val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString - val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString - val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString - val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString - val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0) - val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 0) - val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0) + def getExecInfo(statusId: Int): Seq[String] = { + val status = sc.getExecutorStorageStatus(statusId) + val execId = status.blockManagerId.executorId + val hostPort = status.blockManagerId.hostPort + val rddBlocks = status.blocks.size.toString + val memUsed = status.memUsed().toString + val maxMem = status.maxMem.toString + val diskUsed = status.diskUsed().toString + val activeTasks = listener.executorToTasksActive.getOrElse(execId, HashSet.empty[Long]).size + val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) + val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) val totalTasks = activeTasks + failedTasks + completedTasks Seq( diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala index e7eab374ad..c1ee2f3d00 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressUI.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.jobs -import akka.util.Duration +import scala.concurrent.duration._ import java.text.SimpleDateFormat diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index c1c7aa70e6..69f9446bab 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -60,11 +60,13 @@ private[spark] class StagePage(parent: JobProgressUI) { var activeTime = 0L listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) + val finishedTasks = listener.stageIdToTaskInfos(stageId).filter(_._1.finished) + val summary = <div> <ul class="unstyled"> <li> - <strong>CPU time: </strong> + <strong>Total duration across all tasks: </strong> {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)} </li> {if (hasShuffleRead) @@ -104,6 +106,33 @@ private[spark] class StagePage(parent: JobProgressUI) { val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map( ms => parent.formatDuration(ms.toLong)) + val gettingResultTimes = validTasks.map{case (info, metrics, exception) => + if (info.gettingResultTime > 0) { + (info.finishTime - info.gettingResultTime).toDouble + } else { + 0.0 + } + } + val gettingResultQuantiles = ("Time spent fetching task results" +: + Distribution(gettingResultTimes).get.getQuantiles().map( + millis => parent.formatDuration(millis.toLong))) + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the task result, + // if it needed to be fetched from the block manager on the worker). + val schedulerDelays = validTasks.map{case (info, metrics, exception) => + val totalExecutionTime = { + if (info.gettingResultTime > 0) { + (info.gettingResultTime - info.launchTime).toDouble + } else { + (info.finishTime - info.launchTime).toDouble + } + } + totalExecutionTime - metrics.get.executorRunTime + } + val schedulerDelayQuantiles = ("Scheduler delay" +: + Distribution(schedulerDelays).get.getQuantiles().map( + millis => parent.formatDuration(millis.toLong))) + def getQuantileCols(data: Seq[Double]) = Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong)) @@ -119,7 +148,10 @@ private[spark] class StagePage(parent: JobProgressUI) { } val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes) - val listings: Seq[Seq[String]] = Seq(serviceQuantiles, + val listings: Seq[Seq[String]] = Seq( + serviceQuantiles, + gettingResultQuantiles, + schedulerDelayQuantiles, if (hasShuffleRead) shuffleReadQuantiles else Nil, if (hasShuffleWrite) shuffleWriteQuantiles else Nil) @@ -133,7 +165,7 @@ private[spark] class StagePage(parent: JobProgressUI) { summary ++ <h4>Summary Metrics for {numCompleted} Completed Tasks</h4> ++ <div>{summaryTable.getOrElse("No tasks have reported metrics yet.")}</div> ++ - <h4>Tasks</h4> ++ taskTable; + <h4>Tasks</h4> ++ taskTable headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) } @@ -152,21 +184,18 @@ private[spark] class StagePage(parent: JobProgressUI) { else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) - var shuffleReadSortable: String = "" - var shuffleReadReadable: String = "" - if (shuffleRead) { - shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}.toString() - shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => - Utils.bytesToString(s.remoteBytesRead)}.getOrElse("") - } + val maybeShuffleRead = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead} + val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") + val shuffleReadReadable = maybeShuffleRead.map{Utils.bytesToString(_)}.getOrElse("") - var shuffleWriteSortable: String = "" - var shuffleWriteReadable: String = "" - if (shuffleWrite) { - shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}.toString() - shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("") - } + val maybeShuffleWrite = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten} + val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("") + val shuffleWriteReadable = maybeShuffleWrite.map{Utils.bytesToString(_)}.getOrElse("") + + val maybeWriteTime = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleWriteTime} + val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") + val writeTimeReadable = maybeWriteTime.map{ t => t / (1000 * 1000)}.map{ ms => + if (ms == 0) "" else parent.formatDuration(ms)}.getOrElse("") <tr> <td>{info.index}</td> @@ -187,8 +216,8 @@ private[spark] class StagePage(parent: JobProgressUI) { </td> }} {if (shuffleWrite) { - <td>{metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")} + <td sorttable_customkey={writeTimeSortable}> + {writeTimeReadable} </td> <td sorttable_customkey={shuffleWriteSortable}> {shuffleWriteReadable} diff --git a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala index 1d633d374a..a5446b3fc3 100644 --- a/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/storage/BlockManagerUI.scala @@ -17,7 +17,7 @@ package org.apache.spark.ui.storage -import akka.util.Duration +import scala.concurrent.duration._ import javax.servlet.http.HttpServletRequest diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index d4c5065c3f..74133cef6c 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -17,11 +17,8 @@ package org.apache.spark.util -import akka.actor.{ActorSystem, ExtendedActorSystem} +import akka.actor.{ActorSystem, ExtendedActorSystem, IndestructibleActorSystem} import com.typesafe.config.ConfigFactory -import akka.util.duration._ -import akka.remote.RemoteActorRefProvider - /** * Various utility classes for working with Akka. @@ -34,39 +31,57 @@ private[spark] object AkkaUtils { * * Note: the `name` parameter is important, as even if a client sends a message to right * host + port, if the system name is incorrect, Akka will drop the message. + * + * If indestructible is set to true, the Actor System will continue running in the event + * of a fatal exception. This is used by [[org.apache.spark.executor.Executor]]. */ - def createActorSystem(name: String, host: String, port: Int): (ActorSystem, Int) = { - val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt + def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false) + : (ActorSystem, Int) = { + + val akkaThreads = System.getProperty("spark.akka.threads", "4").toInt val akkaBatchSize = System.getProperty("spark.akka.batchSize", "15").toInt - val akkaTimeout = System.getProperty("spark.akka.timeout", "60").toInt + + val akkaTimeout = System.getProperty("spark.akka.timeout", "100").toInt + val akkaFrameSize = System.getProperty("spark.akka.frameSize", "10").toInt - val lifecycleEvents = if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" - // 10 seconds is the default akka timeout, but in a cluster, we need higher by default. - val akkaWriteTimeout = System.getProperty("spark.akka.writeTimeout", "30").toInt - - val akkaConf = ConfigFactory.parseString(""" - akka.daemonic = on - akka.event-handlers = ["akka.event.slf4j.Slf4jEventHandler"] - akka.stdout-loglevel = "ERROR" - akka.actor.provider = "akka.remote.RemoteActorRefProvider" - akka.remote.transport = "akka.remote.netty.NettyRemoteTransport" - akka.remote.netty.hostname = "%s" - akka.remote.netty.port = %d - akka.remote.netty.connection-timeout = %ds - akka.remote.netty.message-frame-size = %d MiB - akka.remote.netty.execution-pool-size = %d - akka.actor.default-dispatcher.throughput = %d - akka.remote.log-remote-lifecycle-events = %s - akka.remote.netty.write-timeout = %ds - """.format(host, port, akkaTimeout, akkaFrameSize, akkaThreads, akkaBatchSize, - lifecycleEvents, akkaWriteTimeout)) + val lifecycleEvents = + if (System.getProperty("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" + + val akkaHeartBeatPauses = System.getProperty("spark.akka.heartbeat.pauses", "600").toInt + val akkaFailureDetector = + System.getProperty("spark.akka.failure-detector.threshold", "300.0").toDouble + val akkaHeartBeatInterval = System.getProperty("spark.akka.heartbeat.interval", "1000").toInt - val actorSystem = ActorSystem(name, akkaConf) + val akkaConf = ConfigFactory.parseString( + s""" + |akka.daemonic = on + |akka.loggers = [""akka.event.slf4j.Slf4jLogger""] + |akka.stdout-loglevel = "ERROR" + |akka.jvm-exit-on-fatal-error = off + |akka.remote.transport-failure-detector.heartbeat-interval = $akkaHeartBeatInterval s + |akka.remote.transport-failure-detector.acceptable-heartbeat-pause = $akkaHeartBeatPauses s + |akka.remote.transport-failure-detector.threshold = $akkaFailureDetector + |akka.actor.provider = "akka.remote.RemoteActorRefProvider" + |akka.remote.netty.tcp.transport-class = "akka.remote.transport.netty.NettyTransport" + |akka.remote.netty.tcp.hostname = "$host" + |akka.remote.netty.tcp.port = $port + |akka.remote.netty.tcp.tcp-nodelay = on + |akka.remote.netty.tcp.connection-timeout = $akkaTimeout s + |akka.remote.netty.tcp.maximum-frame-size = ${akkaFrameSize}MiB + |akka.remote.netty.tcp.execution-pool-size = $akkaThreads + |akka.actor.default-dispatcher.throughput = $akkaBatchSize + |akka.remote.log-remote-lifecycle-events = $lifecycleEvents + """.stripMargin) + + val actorSystem = if (indestructible) { + IndestructibleActorSystem(name, akkaConf) + } else { + ActorSystem(name, akkaConf) + } - // Figure out the port number we bound to, in case port was passed as 0. This is a bit of a - // hack because Akka doesn't let you figure out the port through the public API yet. val provider = actorSystem.asInstanceOf[ExtendedActorSystem].provider - val boundPort = provider.asInstanceOf[RemoteActorRefProvider].transport.address.port.get - return (actorSystem, boundPort) + val boundPort = provider.getDefaultAddress.port.get + (actorSystem, boundPort) } + } diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala index f60deafc6f..8bb4ee3bfa 100644 --- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala @@ -35,6 +35,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi private var capacity = nextPowerOf2(initialCapacity) private var mask = capacity - 1 private var curSize = 0 + private var growThreshold = LOAD_FACTOR * capacity // Holds keys and values in the same array for memory locality; specifically, the order of // elements is key0, value0, key1, value1, key2, value2, etc. @@ -56,7 +57,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k == curKey) { + if (k.eq(curKey) || k.equals(curKey)) { return data(2 * pos + 1).asInstanceOf[V] } else if (curKey.eq(null)) { return null.asInstanceOf[V] @@ -80,9 +81,23 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi haveNullValue = true return } - val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef]) - if (isNewEntry) { - incrementSize() + var pos = rehash(key.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (curKey.eq(null)) { + data(2 * pos) = k + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + incrementSize() // Since we added a new key + return + } else if (k.eq(curKey) || k.equals(curKey)) { + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + return + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } } } @@ -104,7 +119,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k == curKey) { + if (k.eq(curKey) || k.equals(curKey)) { val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] return newValue @@ -161,45 +176,17 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi /** Increase table size by 1, rehashing if necessary */ private def incrementSize() { curSize += 1 - if (curSize > LOAD_FACTOR * capacity) { + if (curSize > growThreshold) { growTable() } } /** - * Re-hash a value to deal better with hash functions that don't differ - * in the lower bits, similar to java.util.HashMap + * Re-hash a value to deal better with hash functions that don't differ in the lower bits. + * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ private def rehash(h: Int): Int = { - val r = h ^ (h >>> 20) ^ (h >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) - } - - /** - * Put an entry into a table represented by data, returning true if - * this increases the size of the table or false otherwise. Assumes - * that "data" has at least one empty slot. - */ - private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = { - val mask = (data.length / 2) - 1 - var pos = rehash(key.hashCode) & mask - var i = 1 - while (true) { - val curKey = data(2 * pos) - if (curKey.eq(null)) { - data(2 * pos) = key - data(2 * pos + 1) = value.asInstanceOf[AnyRef] - return true - } else if (curKey.eq(key) || curKey == key) { - data(2 * pos + 1) = value.asInstanceOf[AnyRef] - return false - } else { - val delta = i - pos = (pos + delta) & mask - i += 1 - } - } - return false // Never reached but needed to keep compiler happy + it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) } /** Double the table's size and re-hash everything */ @@ -211,16 +198,36 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi throw new Exception("Can't make capacity bigger than 2^29 elements") } val newData = new Array[AnyRef](2 * newCapacity) - var pos = 0 - while (pos < capacity) { - if (!data(2 * pos).eq(null)) { - putInto(newData, data(2 * pos), data(2 * pos + 1)) + val newMask = newCapacity - 1 + // Insert all our old values into the new array. Note that because our old keys are + // unique, there's no need to check for equality here when we insert. + var oldPos = 0 + while (oldPos < capacity) { + if (!data(2 * oldPos).eq(null)) { + val key = data(2 * oldPos) + val value = data(2 * oldPos + 1) + var newPos = rehash(key.hashCode) & newMask + var i = 1 + var keepGoing = true + while (keepGoing) { + val curKey = newData(2 * newPos) + if (curKey.eq(null)) { + newData(2 * newPos) = key + newData(2 * newPos + 1) = value + keepGoing = false + } else { + val delta = i + newPos = (newPos + delta) & newMask + i += 1 + } + } } - pos += 1 + oldPos += 1 } data = newData capacity = newCapacity - mask = newCapacity - 1 + mask = newMask + growThreshold = LOAD_FACTOR * newCapacity } private def nextPowerOf2(n: Int): Int = { diff --git a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala index 0b51c23f7b..a38329df03 100644 --- a/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala +++ b/core/src/main/scala/org/apache/spark/util/BoundedPriorityQueue.scala @@ -34,6 +34,8 @@ class BoundedPriorityQueue[A](maxSize: Int)(implicit ord: Ordering[A]) override def iterator: Iterator[A] = underlying.iterator.asScala + override def size: Int = underlying.size + override def ++=(xs: TraversableOnce[A]): this.type = { xs.foreach { this += _ } this diff --git a/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala new file mode 100644 index 0000000000..bf71882ef7 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/IndestructibleActorSystem.scala @@ -0,0 +1,68 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +// Must be in akka.actor package as ActorSystemImpl is protected[akka]. +package akka.actor + +import scala.util.control.{ControlThrowable, NonFatal} + +import com.typesafe.config.Config + +/** + * An [[akka.actor.ActorSystem]] which refuses to shut down in the event of a fatal exception. + * This is necessary as Spark Executors are allowed to recover from fatal exceptions + * (see [[org.apache.spark.executor.Executor]]). + */ +object IndestructibleActorSystem { + def apply(name: String, config: Config): ActorSystem = + apply(name, config, ActorSystem.findClassLoader()) + + def apply(name: String, config: Config, classLoader: ClassLoader): ActorSystem = + new IndestructibleActorSystemImpl(name, config, classLoader).start() +} + +private[akka] class IndestructibleActorSystemImpl( + override val name: String, + applicationConfig: Config, + classLoader: ClassLoader) + extends ActorSystemImpl(name, applicationConfig, classLoader) { + + protected override def uncaughtExceptionHandler: Thread.UncaughtExceptionHandler = { + val fallbackHandler = super.uncaughtExceptionHandler + + new Thread.UncaughtExceptionHandler() { + def uncaughtException(thread: Thread, cause: Throwable): Unit = { + if (isFatalError(cause) && !settings.JvmExitOnFatalError) { + log.error(cause, "Uncaught fatal error from thread [{}] not shutting down " + + "ActorSystem [{}] tolerating and continuing.... ", thread.getName, name) + //shutdown() //TODO make it configurable + } else { + fallbackHandler.uncaughtException(thread, cause) + } + } + } + } + + def isFatalError(e: Throwable): Boolean = { + e match { + case NonFatal(_) | _: InterruptedException | _: NotImplementedError | _: ControlThrowable => + false + case _ => + true + } + } +} diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index 67a7f87a5c..7b41ef89f1 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -55,8 +55,7 @@ class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, clea } } -object MetadataCleanerType extends Enumeration("MapOutputTracker", "SparkContext", "HttpBroadcast", "DagScheduler", "ResultTask", - "ShuffleMapTask", "BlockManager", "DiskBlockManager", "BroadcastVars") { +object MetadataCleanerType extends Enumeration { val MAP_OUTPUT_TRACKER, SPARK_CONTEXT, HTTP_BROADCAST, DAG_SCHEDULER, RESULT_TASK, SHUFFLE_MAP_TASK, BLOCK_MANAGER, SHUFFLE_BLOCK_MANAGER, BROADCAST_VARS = Value diff --git a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala index 277de2f8a6..dbff571de9 100644 --- a/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/TimeStampedHashMap.scala @@ -85,7 +85,7 @@ class TimeStampedHashMap[A, B] extends Map[A, B]() with Logging { } override def filter(p: ((A, B)) => Boolean): Map[A, B] = { - JavaConversions.asScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) + JavaConversions.mapAsScalaConcurrentMap(internalMap).map(kv => (kv._1, kv._2._1)).filter(p) } override def empty: Map[A, B] = new TimeStampedHashMap[A, B]() diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fe932d8ede..3f7858d2de 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -22,10 +22,11 @@ import java.net.{InetAddress, URL, URI, NetworkInterface, Inet4Address} import java.util.{Locale, Random, UUID} import java.util.concurrent.{ConcurrentHashMap, Executors, ThreadPoolExecutor} +import scala.collection.JavaConversions._ import scala.collection.Map import scala.collection.mutable.ArrayBuffer -import scala.collection.JavaConversions._ import scala.io.Source +import scala.reflect.ClassTag import com.google.common.io.Files import com.google.common.util.concurrent.ThreadFactoryBuilder @@ -319,7 +320,7 @@ private[spark] object Utils extends Logging { * result in a new collection. Unlike scala.util.Random.shuffle, this method * uses a local random number generator, avoiding inter-thread contention. */ - def randomize[T: ClassManifest](seq: TraversableOnce[T]): Seq[T] = { + def randomize[T: ClassTag](seq: TraversableOnce[T]): Seq[T] = { randomizeInPlace(seq.toArray) } @@ -823,4 +824,28 @@ private[spark] object Utils extends Logging { return System.getProperties().clone() .asInstanceOf[java.util.Properties].toMap[String, String] } + + /** + * Method executed for repeating a task for side effects. + * Unlike a for comprehension, it permits JVM JIT optimization + */ + def times(numIters: Int)(f: => Unit): Unit = { + var i = 0 + while (i < numIters) { + f + i += 1 + } + } + + /** + * Timing method based on iterations that permit JVM JIT optimization. + * @param numIters number of iterations + * @param f function to be executed + */ + def timeIt(numIters: Int)(f: => Unit): Long = { + val start = System.currentTimeMillis + times(numIters)(f) + System.currentTimeMillis - start + } + } diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala new file mode 100644 index 0000000000..e9907e6c85 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.{Random => JavaRandom} +import org.apache.spark.util.Utils.timeIt + +/** + * This class implements a XORShift random number generator algorithm + * Source: + * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14. + * @see <a href="http://www.jstatsoft.org/v08/i14/paper">Paper</a> + * This implementation is approximately 3.5 times faster than + * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due + * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class + * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG + * for each thread. + */ +private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { + + def this() = this(System.nanoTime) + + private var seed = init + + // we need to just override next - this will be called by nextInt, nextDouble, + // nextGaussian, nextLong, etc. + override protected def next(bits: Int): Int = { + var nextSeed = seed ^ (seed << 21) + nextSeed ^= (nextSeed >>> 35) + nextSeed ^= (nextSeed << 4) + seed = nextSeed + (nextSeed & ((1L << bits) -1)).asInstanceOf[Int] + } +} + +/** Contains benchmark method and main method to run benchmark of the RNG */ +private[spark] object XORShiftRandom { + + /** + * Main method for running benchmark + * @param args takes one argument - the number of random numbers to generate + */ + def main(args: Array[String]): Unit = { + if (args.length != 1) { + println("Benchmark of XORShiftRandom vis-a-vis java.util.Random") + println("Usage: XORShiftRandom number_of_random_numbers_to_generate") + System.exit(1) + } + println(benchmark(args(0).toInt)) + } + + /** + * @param numIters Number of random numbers to generate while running the benchmark + * @return Map of execution times for {@link java.util.Random java.util.Random} + * and XORShift + */ + def benchmark(numIters: Int) = { + + val seed = 1L + val million = 1e6.toInt + val javaRand = new JavaRandom(seed) + val xorRand = new XORShiftRandom(seed) + + // this is just to warm up the JIT - we're not timing anything + timeIt(1e6.toInt) { + javaRand.nextInt() + xorRand.nextInt() + } + + val iters = timeIt(numIters)(_) + + /* Return results as a map instead of just printing to screen + in case the user wants to do something with them */ + Map("javaTime" -> iters {javaRand.nextInt()}, + "xorTime" -> iters {xorRand.nextInt()}) + + } + +}
\ No newline at end of file diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala index 80545c9688..c26f23d500 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashMap.scala @@ -17,6 +17,7 @@ package org.apache.spark.util.collection +import scala.reflect.ClassTag /** * A fast hash map implementation for nullable keys. This hash map supports insertions and updates, @@ -26,7 +27,7 @@ package org.apache.spark.util.collection * Under the hood, it uses our OpenHashSet implementation. */ private[spark] -class OpenHashMap[K >: Null : ClassManifest, @specialized(Long, Int, Double) V: ClassManifest]( +class OpenHashMap[K >: Null : ClassTag, @specialized(Long, Int, Double) V: ClassTag]( initialCapacity: Int) extends Iterable[(K, V)] with Serializable { diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 4592e4f939..87e009a4de 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -17,6 +17,7 @@ package org.apache.spark.util.collection +import scala.reflect._ /** * A simple, fast hash set optimized for non-null insertion-only use case, where keys are never @@ -36,7 +37,7 @@ package org.apache.spark.util.collection * to explore all spaces for each key (see http://en.wikipedia.org/wiki/Quadratic_probing). */ private[spark] -class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( +class OpenHashSet[@specialized(Long, Int) T: ClassTag]( initialCapacity: Int, loadFactor: Double) extends Serializable { @@ -62,14 +63,14 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( // throws: // scala.tools.nsc.symtab.Types$TypeError: type mismatch; // found : scala.reflect.AnyValManifest[Long] - // required: scala.reflect.ClassManifest[Int] + // required: scala.reflect.ClassTag[Int] // at scala.tools.nsc.typechecker.Contexts$Context.error(Contexts.scala:298) // at scala.tools.nsc.typechecker.Infer$Inferencer.error(Infer.scala:207) // ... - val mt = classManifest[T] - if (mt == ClassManifest.Long) { + val mt = classTag[T] + if (mt == ClassTag.Long) { (new LongHasher).asInstanceOf[Hasher[T]] - } else if (mt == ClassManifest.Int) { + } else if (mt == ClassTag.Int) { (new IntHasher).asInstanceOf[Hasher[T]] } else { new Hasher[T] @@ -79,6 +80,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( protected var _capacity = nextPowerOf2(initialCapacity) protected var _mask = _capacity - 1 protected var _size = 0 + protected var _growThreshold = (loadFactor * _capacity).toInt protected var _bitset = new BitSet(_capacity) @@ -115,7 +117,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( * @return The position where the key is placed, plus the highest order bit is set if the key * exists previously. */ - def addWithoutResize(k: T): Int = putInto(_bitset, _data, k) + def addWithoutResize(k: T): Int = { + var pos = hashcode(hasher.hash(k)) & _mask + var i = 1 + while (true) { + if (!_bitset.get(pos)) { + // This is a new key. + _data(pos) = k + _bitset.set(pos) + _size += 1 + return pos | NONEXISTENCE_MASK + } else if (_data(pos) == k) { + // Found an existing key. + return pos + } else { + val delta = i + pos = (pos + delta) & _mask + i += 1 + } + } + // Never reached here + assert(INVALID_POS != INVALID_POS) + INVALID_POS + } /** * Rehash the set if it is overloaded. @@ -126,7 +150,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( * to a new position (in the new data array). */ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { - if (_size > loadFactor * _capacity) { + if (_size > _growThreshold) { rehash(k, allocateFunc, moveFunc) } } @@ -161,37 +185,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos) /** - * Put an entry into the set. Return the position where the key is placed. In addition, the - * highest bit in the returned position is set if the key exists prior to this put. - * - * This function assumes the data array has at least one empty slot. - */ - private def putInto(bitset: BitSet, data: Array[T], k: T): Int = { - val mask = data.length - 1 - var pos = hashcode(hasher.hash(k)) & mask - var i = 1 - while (true) { - if (!bitset.get(pos)) { - // This is a new key. - data(pos) = k - bitset.set(pos) - _size += 1 - return pos | NONEXISTENCE_MASK - } else if (data(pos) == k) { - // Found an existing key. - return pos - } else { - val delta = i - pos = (pos + delta) & mask - i += 1 - } - } - // Never reached here - assert(INVALID_POS != INVALID_POS) - INVALID_POS - } - - /** * Double the table's size and re-hash everything. We are not really using k, but it is declared * so Scala compiler can specialize this method (which leads to calling the specialized version * of putInto). @@ -204,34 +197,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( */ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { val newCapacity = _capacity * 2 - require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") - allocateFunc(newCapacity) - val newData = new Array[T](newCapacity) val newBitset = new BitSet(newCapacity) - var pos = 0 - _size = 0 - while (pos < _capacity) { - if (_bitset.get(pos)) { - val newPos = putInto(newBitset, newData, _data(pos)) - moveFunc(pos, newPos & POSITION_MASK) + val newData = new Array[T](newCapacity) + val newMask = newCapacity - 1 + + var oldPos = 0 + while (oldPos < capacity) { + if (_bitset.get(oldPos)) { + val key = _data(oldPos) + var newPos = hashcode(hasher.hash(key)) & newMask + var i = 1 + var keepGoing = true + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!newBitset.get(newPos)) { + // Inserting the key at newPos + newData(newPos) = key + newBitset.set(newPos) + moveFunc(oldPos, newPos) + keepGoing = false + } else { + val delta = i + newPos = (newPos + delta) & newMask + i += 1 + } + } } - pos += 1 + oldPos += 1 } + _bitset = newBitset _data = newData _capacity = newCapacity - _mask = newCapacity - 1 + _mask = newMask + _growThreshold = (loadFactor * newCapacity).toInt } /** - * Re-hash a value to deal better with hash functions that don't differ - * in the lower bits, similar to java.util.HashMap + * Re-hash a value to deal better with hash functions that don't differ in the lower bits. + * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ - private def hashcode(h: Int): Int = { - val r = h ^ (h >>> 20) ^ (h >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) - } + private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) private def nextPowerOf2(n: Int): Int = { val highBit = Integer.highestOneBit(n) diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala index d76143e45a..2e1ef06cbc 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMap.scala @@ -17,6 +17,7 @@ package org.apache.spark.util.collection +import scala.reflect._ /** * A fast hash map implementation for primitive, non-null keys. This hash map supports @@ -26,15 +27,15 @@ package org.apache.spark.util.collection * Under the hood, it uses our OpenHashSet implementation. */ private[spark] -class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassManifest, - @specialized(Long, Int, Double) V: ClassManifest]( +class PrimitiveKeyOpenHashMap[@specialized(Long, Int) K: ClassTag, + @specialized(Long, Int, Double) V: ClassTag]( initialCapacity: Int) extends Iterable[(K, V)] with Serializable { def this() = this(64) - require(classManifest[K] == classManifest[Long] || classManifest[K] == classManifest[Int]) + require(classTag[K] == classTag[Long] || classTag[K] == classTag[Int]) // Init in constructor (instead of in declaration) to work around a Scala compiler specialization // bug that would generate two arrays (one for Object and one for specialized T). diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala index 20554f0aab..b84eb65c62 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala @@ -17,11 +17,13 @@ package org.apache.spark.util.collection +import scala.reflect.ClassTag + /** * An append-only, non-threadsafe, array-backed vector that is optimized for primitive types. */ private[spark] -class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) { +class PrimitiveVector[@specialized(Long, Int, Double) V: ClassTag](initialSize: Int = 64) { private var _numElements = 0 private var _array: Array[V] = _ diff --git a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala index 4434f3b87c..c443c5266e 100644 --- a/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/AccumulatorSuite.scala @@ -27,6 +27,21 @@ import org.apache.spark.SparkContext._ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkContext { + + implicit def setAccum[A] = new AccumulableParam[mutable.Set[A], A] { + def addInPlace(t1: mutable.Set[A], t2: mutable.Set[A]) : mutable.Set[A] = { + t1 ++= t2 + t1 + } + def addAccumulator(t1: mutable.Set[A], t2: A) : mutable.Set[A] = { + t1 += t2 + t1 + } + def zero(t: mutable.Set[A]) : mutable.Set[A] = { + new mutable.HashSet[A]() + } + } + test ("basic accumulation"){ sc = new SparkContext("local", "test") val acc : Accumulator[Int] = sc.accumulator(0) @@ -51,7 +66,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte } test ("add value to collection accumulators") { - import SetAccum._ val maxI = 1000 for (nThreads <- List(1, 10)) { //test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") @@ -68,22 +82,7 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte } } - implicit object SetAccum extends AccumulableParam[mutable.Set[Any], Any] { - def addInPlace(t1: mutable.Set[Any], t2: mutable.Set[Any]) : mutable.Set[Any] = { - t1 ++= t2 - t1 - } - def addAccumulator(t1: mutable.Set[Any], t2: Any) : mutable.Set[Any] = { - t1 += t2 - t1 - } - def zero(t: mutable.Set[Any]) : mutable.Set[Any] = { - new mutable.HashSet[Any]() - } - } - test ("value not readable in tasks") { - import SetAccum._ val maxI = 1000 for (nThreads <- List(1, 10)) { //test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") @@ -125,7 +124,6 @@ class AccumulatorSuite extends FunSuite with ShouldMatchers with LocalSparkConte } test ("localValue readable in tasks") { - import SetAccum._ val maxI = 1000 for (nThreads <- List(1, 10)) { //test single & multi-threaded sc = new SparkContext("local[" + nThreads + "]", "test") diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index f26c44d3e7..f25d921d3f 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -17,6 +17,7 @@ package org.apache.spark +import scala.reflect.ClassTag import org.scalatest.FunSuite import java.io.File import org.apache.spark.rdd._ @@ -62,8 +63,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testCheckpointing(_.sample(false, 0.5, 0)) testCheckpointing(_.glom()) testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithContextRDD(r, - (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false )) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) @@ -207,7 +206,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * not, but this is not done by default as usually the partitions do not refer to any RDD and * therefore never store the lineage. */ - def testCheckpointing[U: ClassManifest]( + def testCheckpointing[U: ClassTag]( op: (RDD[Int]) => RDD[U], testRDDSize: Boolean = true, testRDDPartitionSize: Boolean = false @@ -276,7 +275,7 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { * RDDs partitions. So even if the parent RDD is checkpointed and its partitions changed, * this RDD will remember the partitions and therefore potentially the whole lineage. */ - def testParentCheckpointing[U: ClassManifest]( + def testParentCheckpointing[U: ClassTag]( op: (RDD[Int]) => RDD[U], testRDDSize: Boolean, testRDDPartitionSize: Boolean diff --git a/core/src/test/scala/org/apache/spark/DistributedSuite.scala b/core/src/test/scala/org/apache/spark/DistributedSuite.scala index 480bac84f3..d9cb7fead5 100644 --- a/core/src/test/scala/org/apache/spark/DistributedSuite.scala +++ b/core/src/test/scala/org/apache/spark/DistributedSuite.scala @@ -122,7 +122,7 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter sc.parallelize(1 to 10, 10).foreach(x => println(x / 0)) } assert(thrown.getClass === classOf[SparkException]) - assert(thrown.getMessage.contains("more than 4 times")) + assert(thrown.getMessage.contains("failed 4 times")) } test("caching") { @@ -303,12 +303,13 @@ class DistributedSuite extends FunSuite with ShouldMatchers with BeforeAndAfter Thread.sleep(200) } } catch { - case _ => { Thread.sleep(10) } + case _: Throwable => { Thread.sleep(10) } // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } } } + } object DistributedSuite { diff --git a/core/src/test/scala/org/apache/spark/DriverSuite.scala b/core/src/test/scala/org/apache/spark/DriverSuite.scala index 01a72d8401..6d1695eae7 100644 --- a/core/src/test/scala/org/apache/spark/DriverSuite.scala +++ b/core/src/test/scala/org/apache/spark/DriverSuite.scala @@ -34,7 +34,7 @@ class DriverSuite extends FunSuite with Timeouts { // Regression test for SPARK-530: "Spark driver process doesn't exit after finishing" val masters = Table(("master"), ("local"), ("local-cluster[2,1,512]")) forAll(masters) { (master: String) => - failAfter(30 seconds) { + failAfter(60 seconds) { Utils.execute(Seq("./spark-class", "org.apache.spark.DriverWithoutCleanup", master), new File(System.getenv("SPARK_HOME"))) } diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 352036f182..4234f6eac7 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -365,6 +365,20 @@ public class JavaAPISuite implements Serializable { } @Test + public void javaDoubleRDDHistoGram() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + // Test using generated buckets + Tuple2<double[], long[]> results = rdd.histogram(2); + double[] expected_buckets = {1.0, 2.5, 4.0}; + long[] expected_counts = {2, 2}; + Assert.assertArrayEquals(expected_buckets, results._1, 0.1); + Assert.assertArrayEquals(expected_counts, results._2); + // Test with provided buckets + long[] histogram = rdd.histogram(expected_buckets); + Assert.assertArrayEquals(expected_counts, histogram); + } + + @Test public void map() { JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() { diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index d8a0e983b2..1121e06e2e 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -114,7 +114,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf // Once A is cancelled, job B should finish fairly quickly. assert(jobB.get() === 100) } - +/* test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched // sem2: make sure the first stage is not finished until cancel is issued @@ -148,7 +148,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf intercept[SparkException] { f1.get() } intercept[SparkException] { f2.get() } } - + */ def testCount() { // Cancel before launching any tasks { diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 459e257d79..8dd5786da6 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -30,7 +30,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self @transient var sc: SparkContext = _ override def beforeAll() { - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()); + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) super.beforeAll() } diff --git a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala index b7eb268bd5..271dc905bc 100644 --- a/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala +++ b/core/src/test/scala/org/apache/spark/MapOutputTrackerSuite.scala @@ -49,14 +49,14 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("master start and stop") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))) tracker.stop() } test("master register and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))) tracker.registerShuffle(10, 2) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) @@ -75,7 +75,7 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { test("master register and unregister and fetch") { val actorSystem = ActorSystem("test") val tracker = new MapOutputTrackerMaster() - tracker.trackerActor = actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker))) + tracker.trackerActor = Left(actorSystem.actorOf(Props(new MapOutputTrackerMasterActor(tracker)))) tracker.registerShuffle(10, 2) val compressedSize1000 = MapOutputTracker.compressSize(1000L) val compressedSize10000 = MapOutputTracker.compressSize(10000L) @@ -101,13 +101,13 @@ class MapOutputTrackerSuite extends FunSuite with LocalSparkContext { System.setProperty("spark.hostPort", hostname + ":" + boundPort) val masterTracker = new MapOutputTrackerMaster() - masterTracker.trackerActor = actorSystem.actorOf( - Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker") + masterTracker.trackerActor = Left(actorSystem.actorOf( + Props(new MapOutputTrackerMasterActor(masterTracker)), "MapOutputTracker")) val (slaveSystem, _) = AkkaUtils.createActorSystem("spark-slave", hostname, 0) val slaveTracker = new MapOutputTracker() - slaveTracker.trackerActor = slaveSystem.actorFor( - "akka://spark@localhost:" + boundPort + "/user/MapOutputTracker") + slaveTracker.trackerActor = Right(slaveSystem.actorSelection( + "akka.tcp://spark@localhost:" + boundPort + "/user/MapOutputTracker")) masterTracker.registerShuffle(10, 1) masterTracker.incrementEpoch() diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala deleted file mode 100644 index 21f16ef2c6..0000000000 --- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.scalatest.FunSuite -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{RDD, PartitionPruningRDD} - - -class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { - - test("Pruned Partitions inherit locality prefs correctly") { - class TestPartition(i: Int) extends Partition { - def index = i - } - val rdd = new RDD[Int](sc, Nil) { - override protected def getPartitions = { - Array[Partition]( - new TestPartition(1), - new TestPartition(2), - new TestPartition(3)) - } - def compute(split: Partition, context: TaskContext) = {Iterator()} - } - val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) - val p = prunedRDD.partitions(0) - assert(p.index == 2) - assert(prunedRDD.partitions.length == 1) - } -} diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 7d938917f2..1374d01774 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -142,11 +142,11 @@ class PartitioningSuite extends FunSuite with SharedSparkContext { .filter(_ >= 0.0) // Run the partitions, including the consecutive empty ones, through StatCounter - val stats: StatCounter = rdd.stats(); - assert(abs(6.0 - stats.sum) < 0.01); - assert(abs(6.0/2 - rdd.mean) < 0.01); - assert(abs(1.0 - rdd.variance) < 0.01); - assert(abs(1.0 - rdd.stdev) < 0.01); + val stats: StatCounter = rdd.stats() + assert(abs(6.0 - stats.sum) < 0.01) + assert(abs(6.0/2 - rdd.mean) < 0.01) + assert(abs(1.0 - rdd.variance) < 0.01) + assert(abs(1.0 - rdd.stdev) < 0.01) // Add other tests here for classes that should be able to handle empty partitions correctly } diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala new file mode 100644 index 0000000000..151af0d213 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -0,0 +1,140 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.{FunSuite, PrivateMethodTester} + +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.scheduler.cluster.{ClusterScheduler, SimrSchedulerBackend, SparkDeploySchedulerBackend} +import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import org.apache.spark.scheduler.local.LocalScheduler + +class SparkContextSchedulerCreationSuite + extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging { + + def createTaskScheduler(master: String): TaskScheduler = { + // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the + // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. + sc = new SparkContext("local", "test") + val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) + SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test") + } + + test("bad-master") { + val e = intercept[SparkException] { + createTaskScheduler("localhost:1234") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + + test("local") { + createTaskScheduler("local") match { + case s: LocalScheduler => + assert(s.threads === 1) + assert(s.maxFailures === 0) + case _ => fail() + } + } + + test("local-n") { + createTaskScheduler("local[5]") match { + case s: LocalScheduler => + assert(s.threads === 5) + assert(s.maxFailures === 0) + case _ => fail() + } + } + + test("local-n-failures") { + createTaskScheduler("local[4, 2]") match { + case s: LocalScheduler => + assert(s.threads === 4) + assert(s.maxFailures === 2) + case _ => fail() + } + } + + test("simr") { + createTaskScheduler("simr://uri") match { + case s: ClusterScheduler => + assert(s.backend.isInstanceOf[SimrSchedulerBackend]) + case _ => fail() + } + } + + test("local-cluster") { + createTaskScheduler("local-cluster[3, 14, 512]") match { + case s: ClusterScheduler => + assert(s.backend.isInstanceOf[SparkDeploySchedulerBackend]) + case _ => fail() + } + } + + def testYarn(master: String, expectedClassName: String) { + try { + createTaskScheduler(master) match { + case s: ClusterScheduler => + assert(s.getClass === Class.forName(expectedClassName)) + case _ => fail() + } + } catch { + case e: SparkException => + assert(e.getMessage.contains("YARN mode not available")) + logWarning("YARN not available, could not test actual YARN scheduler creation") + case e: Throwable => fail(e) + } + } + + test("yarn-standalone") { + testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") + } + + test("yarn-client") { + testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + } + + def testMesos(master: String, expectedClass: Class[_]) { + try { + createTaskScheduler(master) match { + case s: ClusterScheduler => + assert(s.backend.getClass === expectedClass) + case _ => fail() + } + } catch { + case e: UnsatisfiedLinkError => + assert(e.getMessage.contains("no mesos in")) + logWarning("Mesos not available, could not test actual Mesos scheduler creation") + case e: Throwable => fail(e) + } + } + + test("mesos fine-grained") { + System.setProperty("spark.mesos.coarse", "false") + testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend]) + } + + test("mesos coarse-grained") { + System.setProperty("spark.mesos.coarse", "true") + testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend]) + } + + test("mesos with zookeeper") { + System.setProperty("spark.mesos.coarse", "false") + testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend]) + } +} diff --git a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala index 46a2da1724..768ca3850e 100644 --- a/core/src/test/scala/org/apache/spark/UnpersistSuite.scala +++ b/core/src/test/scala/org/apache/spark/UnpersistSuite.scala @@ -37,7 +37,7 @@ class UnpersistSuite extends FunSuite with LocalSparkContext { Thread.sleep(200) } } catch { - case _ => { Thread.sleep(10) } + case _: Throwable => { Thread.sleep(10) } // Do nothing. We might see exceptions because block manager // is racing this thread to remove entries from the driver. } diff --git a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala index 8f0954122b..4cb4ddc9cd 100644 --- a/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala +++ b/core/src/test/scala/org/apache/spark/deploy/worker/ExecutorRunnerTest.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.deploy.worker import java.io.File diff --git a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala index da032b17d9..0d4c10db8e 100644 --- a/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/AsyncRDDActionsSuite.scala @@ -19,6 +19,8 @@ package org.apache.spark.rdd import java.util.concurrent.Semaphore +import scala.concurrent.{Await, TimeoutException} +import scala.concurrent.duration.Duration import scala.concurrent.ExecutionContext.Implicits.global import org.scalatest.{BeforeAndAfterAll, FunSuite} @@ -173,4 +175,28 @@ class AsyncRDDActionsSuite extends FunSuite with BeforeAndAfterAll with Timeouts sem.acquire(2) } } + + /** + * Awaiting FutureAction results + */ + test("FutureAction result, infinite wait") { + val f = sc.parallelize(1 to 100, 4) + .countAsync() + assert(Await.result(f, Duration.Inf) === 100) + } + + test("FutureAction result, finite wait") { + val f = sc.parallelize(1 to 100, 4) + .countAsync() + assert(Await.result(f, Duration(30, "seconds")) === 100) + } + + test("FutureAction result, timeout") { + val f = sc.parallelize(1 to 100, 4) + .mapPartitions(itr => { Thread.sleep(20); itr }) + .countAsync() + intercept[TimeoutException] { + Await.result(f, Duration(20, "milliseconds")) + } + } } diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala new file mode 100644 index 0000000000..7f50a5a47c --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -0,0 +1,271 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import scala.math.abs +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import org.apache.spark._ + +class DoubleRDDSuite extends FunSuite with SharedSparkContext { + // Verify tests on the histogram functionality. We test with both evenly + // and non-evenly spaced buckets as the bucket lookup function changes. + test("WorksOnEmpty") { + // Make sure that it works on an empty input + val rdd: RDD[Double] = sc.parallelize(Seq()) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithOneBucket") { + // Verify that if all of the elements are out of range the counts are zero + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithOneBucket") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(4) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithOneBucketExactMatch") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val buckets = Array(1.0, 4.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(4) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithTwoBuckets") { + // Verify that out of range works with two buckets + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0, 0) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithTwoUnEvenBuckets") { + // Verify that out of range works with two un even buckets + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 4.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(0, 0) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksInRangeWithTwoBuckets") { + // Make sure that it works with two equally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(3, 2) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithTwoBucketsAndNaN") { + // Make sure that it works with two equally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(3, 2) + assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) + } + + test("WorksInRangeWithTwoUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6)) + val buckets = Array(0.0, 5.0, 11.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(3, 2) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksMixedRangeWithTwoUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01)) + val buckets = Array(0.0, 5.0, 11.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 3) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksMixedRangeWithFourUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 3) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksMixedRangeWithUnevenBucketsAndNaN") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, Double.NaN)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 3) + assert(histogramResults === expectedHistogramResults) + } + // Make sure this works with a NaN end bucket + test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, Double.NaN)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 2, 3) + assert(histogramResults === expectedHistogramResults) + } + // Make sure this works with a NaN end bucket and an inifity + test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN)) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 2, 4) + assert(histogramResults === expectedHistogramResults) + } + + test("WorksWithOutOfRangeWithInfiniteBuckets") { + // Verify that out of range works with two buckets + val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN)) + val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(1, 1) + assert(histogramResults === expectedHistogramResults) + } + // Test the failure mode with an invalid bucket array + test("ThrowsExceptionOnInvalidBucketArray") { + val rdd = sc.parallelize(Seq(1.0)) + // Empty array + intercept[IllegalArgumentException] { + val buckets = Array.empty[Double] + val result = rdd.histogram(buckets) + } + // Single element array + intercept[IllegalArgumentException] { + val buckets = Array(1.0) + val result = rdd.histogram(buckets) + } + } + + // Test automatic histogram function + test("WorksWithoutBucketsBasic") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults = Array(4) + val expectedHistogramBuckets = Array(1.0, 4.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test automatic histogram function with a single element + test("WorksWithoutBucketsBasicSingleElement") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults = Array(1) + val expectedHistogramBuckets = Array(1.0, 1.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test automatic histogram function with a single element + test("WorksWithoutBucketsBasicNoRange") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 1, 1, 1)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults = Array(4) + val expectedHistogramBuckets = Array(1.0, 1.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + test("WorksWithoutBucketsBasicTwo") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val (histogramBuckets, histogramResults) = rdd.histogram(2) + val expectedHistogramResults = Array(2, 2) + val expectedHistogramBuckets = Array(1.0, 2.5, 4.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + test("WorksWithoutBucketsWithMoreRequestedThanElements") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd = sc.parallelize(Seq(1, 2)) + val (histogramBuckets, histogramResults) = rdd.histogram(10) + val expectedHistogramResults = + Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1) + val expectedHistogramBuckets = + Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + + // Test the failure mode with an invalid RDD + test("ThrowsExceptionOnInvalidRDDs") { + // infinity + intercept[UnsupportedOperationException] { + val rdd = sc.parallelize(Seq(1, 1.0/0.0)) + val result = rdd.histogram(1) + } + // NaN + intercept[UnsupportedOperationException] { + val rdd = sc.parallelize(Seq(1, Double.NaN)) + val result = rdd.histogram(1) + } + // Empty + intercept[UnsupportedOperationException] { + val rdd: RDD[Double] = sc.parallelize(Seq()) + val result = rdd.histogram(1) + } + } + +} diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala new file mode 100644 index 0000000000..53a7b7c44d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.scalatest.FunSuite +import org.apache.spark.{TaskContext, Partition, SharedSparkContext} + + +class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { + + + test("Pruned Partitions inherit locality prefs correctly") { + + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(0, 1), + new TestPartition(1, 1), + new TestPartition(2, 1)) + } + + def compute(split: Partition, context: TaskContext) = { + Iterator() + } + } + val prunedRDD = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) + assert(prunedRDD.partitions.length == 1) + val p = prunedRDD.partitions(0) + assert(p.index == 0) + assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2) + } + + + test("Pruned Partitions can be unioned ") { + + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(0, 4), + new TestPartition(1, 5), + new TestPartition(2, 6)) + } + + def compute(split: Partition, context: TaskContext) = { + List(split.asInstanceOf[TestPartition].testValue).iterator + } + } + val prunedRDD1 = PartitionPruningRDD.create(rdd, { + x => if (x == 0) true else false + }) + + val prunedRDD2 = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) + + val merged = prunedRDD1 ++ prunedRDD2 + assert(merged.count() == 2) + val take = merged.take(2) + assert(take.apply(0) == 4) + assert(take.apply(1) == 6) + } +} + +class TestPartition(i: Int, value: Int) extends Partition with Serializable { + def index = i + + def testValue = this.value + +} diff --git a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala index 88b36a6855..964ca9e91d 100644 --- a/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/RDDSuite.scala @@ -271,8 +271,8 @@ class RDDSuite extends FunSuite with SharedSparkContext { // test that you get over 90% locality in each group val minLocality = coalesced2.partitions .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.)((perc, loc) => math.min(perc,loc)) - assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.).toInt + "%") + .foldLeft(1.0)((perc, loc) => math.min(perc,loc)) + assert(minLocality >= 0.90, "Expected 90% locality but got " + (minLocality*100.0).toInt + "%") // test that the groups are load balanced with 100 +/- 20 elements in each val maxImbalance = coalesced2.partitions @@ -284,9 +284,9 @@ class RDDSuite extends FunSuite with SharedSparkContext { val coalesced3 = data3.coalesce(numMachines*2) val minLocality2 = coalesced3.partitions .map(part => part.asInstanceOf[CoalescedRDDPartition].localFraction) - .foldLeft(1.)((perc, loc) => math.min(perc,loc)) + .foldLeft(1.0)((perc, loc) => math.min(perc,loc)) assert(minLocality2 >= 0.90, "Expected 90% locality for derived RDD but got " + - (minLocality2*100.).toInt + "%") + (minLocality2*100.0).toInt + "%") } test("zipped RDDs") { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a4d41ebbff..706d84a58b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -206,6 +206,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont submit(rdd, Array(0)) complete(taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("local job") { @@ -219,6 +220,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont val jobId = scheduler.nextJobId.getAndIncrement() runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("run trivial job w/ dependency") { @@ -227,6 +229,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("cache location preferences w/ dependency") { @@ -239,12 +242,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) complete(taskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("trivial job failure") { submit(makeRdd(1, Nil), Array(0)) failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted: some failure") + assertDataStructuresEmpty } test("run trivial shuffle") { @@ -260,6 +265,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("run trivial shuffle with fetch failure") { @@ -285,6 +291,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty } test("ignore late map task completions") { @@ -313,6 +320,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty } test("run trivial shuffle with out-of-band failure and retry") { @@ -329,15 +337,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - // have hostC complete the resubmitted task - complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) - complete(taskSets(2), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("recursive shuffle failures") { + // have hostC complete the resubmitted task + complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + } + + test("recursive shuffle failures") { val shuffleOneRdd = makeRdd(2, Nil) val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) @@ -363,6 +372,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) complete(taskSets(5), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("cached post-shuffle") { @@ -394,6 +404,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) complete(taskSets(4), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } /** @@ -413,4 +424,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345, 0) + private def assertDataStructuresEmpty = { + assert(scheduler.pendingTasks.isEmpty) + assert(scheduler.activeJobs.isEmpty) + assert(scheduler.failed.isEmpty) + assert(scheduler.idToActiveJob.isEmpty) + assert(scheduler.jobIdToStageIds.isEmpty) + assert(scheduler.stageIdToJobIds.isEmpty) + assert(scheduler.stageIdToStage.isEmpty) + assert(scheduler.stageToInfos.isEmpty) + assert(scheduler.resultStageToJob.isEmpty) + assert(scheduler.running.isEmpty) + assert(scheduler.shuffleToMapStage.isEmpty) + assert(scheduler.waiting.isEmpty) + } } diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala index 984881861c..002368ff55 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + val WAIT_TIMEOUT_MILLIS = 10000 test("inner method") { sc = new SparkContext("local", "joblogger") @@ -92,6 +93,8 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } rdd.reduceByKey(_+_).collect() + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER) joblogger.getLogDir should be ("/tmp/spark-%s".format(user)) @@ -120,7 +123,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers sc.addSparkListener(joblogger) val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } rdd.reduceByKey(_+_).collect() - + + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + joblogger.onJobStartCount should be (1) joblogger.onJobEndCount should be (1) joblogger.onTaskEndCount should be (8) diff --git a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala index 1fd76420ea..2e41438a52 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/SparkListenerSuite.scala @@ -145,7 +145,7 @@ class SparkListenerSuite extends FunSuite with LocalSparkContext with ShouldMatc // Make a task whose result is larger than the akka frame size System.setProperty("spark.akka.frameSize", "1") val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x,y) => x) assert(result === 1.to(akkaFrameSize).toArray) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala index b97f2b19b5..29c4cc5d9c 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala @@ -283,7 +283,7 @@ class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Lo // Fail the task MAX_TASK_FAILURES times, and check that the task set is aborted // after the last failure. - (0 until manager.MAX_TASK_FAILURES).foreach { index => + (1 to manager.MAX_TASK_FAILURES).foreach { index => val offerResult = manager.resourceOffer("exec1", "host1", 1, ANY) assert(offerResult != None, "Expect resource offer on iteration %s to return a task".format(index)) diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala index ee150a3107..27c2d53361 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/TaskResultGetterSuite.scala @@ -82,7 +82,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA test("handling results larger than Akka frame size") { val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) assert(result === 1.to(akkaFrameSize).toArray) @@ -103,7 +103,7 @@ class TaskResultGetterSuite extends FunSuite with BeforeAndAfter with BeforeAndA } scheduler.taskResultGetter = new ResultDeletingTaskResultGetter(sc.env, scheduler) val akkaFrameSize = - sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.message-frame-size").toInt + sc.env.actorSystem.settings.config.getBytes("akka.remote.netty.tcp.maximum-frame-size").toInt val result = sc.parallelize(Seq(1), 1).map(x => 1.to(akkaFrameSize).toArray).reduce((x, y) => x) assert(result === 1.to(akkaFrameSize).toArray) diff --git a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala index cb76275e39..b647e8a672 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockIdSuite.scala @@ -39,7 +39,7 @@ class BlockIdSuite extends FunSuite { fail() } catch { case e: IllegalStateException => // OK - case _ => fail() + case _: Throwable => fail() } } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 484a654108..5b4d63b954 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -56,7 +56,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT System.setProperty("spark.hostPort", "localhost:" + boundPort) master = new BlockManagerMaster( - actorSystem.actorOf(Props(new BlockManagerMasterActor(true)))) + Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true))))) // Set the arch to 64-bit and compressedOops to true to get a deterministic test-case oldArch = System.setProperty("os.arch", "amd64") diff --git a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala index 0b9056344c..070982e798 100644 --- a/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/DiskBlockManagerSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.storage import java.io.{FileWriter, File} @@ -5,9 +22,9 @@ import java.io.{FileWriter, File} import scala.collection.mutable import com.google.common.io.Files -import org.scalatest.{BeforeAndAfterEach, FunSuite} +import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite} -class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { +class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll { val rootDir0 = Files.createTempDir() rootDir0.deleteOnExit() @@ -16,6 +33,12 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { val rootDirs = rootDir0.getName + "," + rootDir1.getName println("Created root dirs: " + rootDirs) + // This suite focuses primarily on consolidation features, + // so we coerce consolidation if not already enabled. + val consolidateProp = "spark.shuffle.consolidateFiles" + val oldConsolidate = Option(System.getProperty(consolidateProp)) + System.setProperty(consolidateProp, "true") + val shuffleBlockManager = new ShuffleBlockManager(null) { var idToSegmentMap = mutable.Map[ShuffleBlockId, FileSegment]() override def getBlockLocation(id: ShuffleBlockId) = idToSegmentMap(id) @@ -23,6 +46,10 @@ class DiskBlockManagerSuite extends FunSuite with BeforeAndAfterEach { var diskBlockManager: DiskBlockManager = _ + override def afterAll() { + oldConsolidate.map(c => System.setProperty(consolidateProp, c)) + } + override def beforeEach() { diskBlockManager = new DiskBlockManager(shuffleBlockManager, rootDirs) shuffleBlockManager.idToSegmentMap.clear() diff --git a/core/src/test/scala/org/apache/spark/ui/UISuite.scala b/core/src/test/scala/org/apache/spark/ui/UISuite.scala index 8f0ec6683b..3764f4d1a0 100644 --- a/core/src/test/scala/org/apache/spark/ui/UISuite.scala +++ b/core/src/test/scala/org/apache/spark/ui/UISuite.scala @@ -34,7 +34,6 @@ class UISuite extends FunSuite { } val (jettyServer1, boundPort1) = JettyUtils.startJettyServer("localhost", startPort, Seq()) val (jettyServer2, boundPort2) = JettyUtils.startJettyServer("localhost", startPort, Seq()) - // Allow some wiggle room in case ports on the machine are under contention assert(boundPort1 > startPort && boundPort1 < startPort + 10) assert(boundPort2 > boundPort1 && boundPort2 < boundPort1 + 10) diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index 4e40dcbdee..5aff26f9fc 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -63,54 +63,53 @@ class SizeEstimatorSuite } test("simple classes") { - assert(SizeEstimator.estimate(new DummyClass1) === 16) - assert(SizeEstimator.estimate(new DummyClass2) === 16) - assert(SizeEstimator.estimate(new DummyClass3) === 24) - assert(SizeEstimator.estimate(new DummyClass4(null)) === 24) - assert(SizeEstimator.estimate(new DummyClass4(new DummyClass3)) === 48) + expectResult(16)(SizeEstimator.estimate(new DummyClass1)) + expectResult(16)(SizeEstimator.estimate(new DummyClass2)) + expectResult(24)(SizeEstimator.estimate(new DummyClass3)) + expectResult(24)(SizeEstimator.estimate(new DummyClass4(null))) + expectResult(48)(SizeEstimator.estimate(new DummyClass4(new DummyClass3))) } // NOTE: The String class definition varies across JDK versions (1.6 vs. 1.7) and vendors // (Sun vs IBM). Use a DummyString class to make tests deterministic. test("strings") { - assert(SizeEstimator.estimate(DummyString("")) === 40) - assert(SizeEstimator.estimate(DummyString("a")) === 48) - assert(SizeEstimator.estimate(DummyString("ab")) === 48) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) + expectResult(40)(SizeEstimator.estimate(DummyString(""))) + expectResult(48)(SizeEstimator.estimate(DummyString("a"))) + expectResult(48)(SizeEstimator.estimate(DummyString("ab"))) + expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) } test("primitive arrays") { - assert(SizeEstimator.estimate(new Array[Byte](10)) === 32) - assert(SizeEstimator.estimate(new Array[Char](10)) === 40) - assert(SizeEstimator.estimate(new Array[Short](10)) === 40) - assert(SizeEstimator.estimate(new Array[Int](10)) === 56) - assert(SizeEstimator.estimate(new Array[Long](10)) === 96) - assert(SizeEstimator.estimate(new Array[Float](10)) === 56) - assert(SizeEstimator.estimate(new Array[Double](10)) === 96) - assert(SizeEstimator.estimate(new Array[Int](1000)) === 4016) - assert(SizeEstimator.estimate(new Array[Long](1000)) === 8016) + expectResult(32)(SizeEstimator.estimate(new Array[Byte](10))) + expectResult(40)(SizeEstimator.estimate(new Array[Char](10))) + expectResult(40)(SizeEstimator.estimate(new Array[Short](10))) + expectResult(56)(SizeEstimator.estimate(new Array[Int](10))) + expectResult(96)(SizeEstimator.estimate(new Array[Long](10))) + expectResult(56)(SizeEstimator.estimate(new Array[Float](10))) + expectResult(96)(SizeEstimator.estimate(new Array[Double](10))) + expectResult(4016)(SizeEstimator.estimate(new Array[Int](1000))) + expectResult(8016)(SizeEstimator.estimate(new Array[Long](1000))) } test("object arrays") { // Arrays containing nulls should just have one pointer per element - assert(SizeEstimator.estimate(new Array[String](10)) === 56) - assert(SizeEstimator.estimate(new Array[AnyRef](10)) === 56) - + expectResult(56)(SizeEstimator.estimate(new Array[String](10))) + expectResult(56)(SizeEstimator.estimate(new Array[AnyRef](10))) // For object arrays with non-null elements, each object should take one pointer plus // however many bytes that class takes. (Note that Array.fill calls the code in its // second parameter separately for each object, so we get distinct objects.) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass1)) === 216) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass2)) === 216) - assert(SizeEstimator.estimate(Array.fill(10)(new DummyClass3)) === 296) - assert(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2)) === 56) + expectResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass1))) + expectResult(216)(SizeEstimator.estimate(Array.fill(10)(new DummyClass2))) + expectResult(296)(SizeEstimator.estimate(Array.fill(10)(new DummyClass3))) + expectResult(56)(SizeEstimator.estimate(Array(new DummyClass1, new DummyClass2))) // Past size 100, our samples 100 elements, but we should still get the right size. - assert(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3)) === 28016) + expectResult(28016)(SizeEstimator.estimate(Array.fill(1000)(new DummyClass3))) // If an array contains the *same* element many times, we should only count it once. val d1 = new DummyClass1 - assert(SizeEstimator.estimate(Array.fill(10)(d1)) === 72) // 10 pointers plus 8-byte object - assert(SizeEstimator.estimate(Array.fill(100)(d1)) === 432) // 100 pointers plus 8-byte object + expectResult(72)(SizeEstimator.estimate(Array.fill(10)(d1))) // 10 pointers plus 8-byte object + expectResult(432)(SizeEstimator.estimate(Array.fill(100)(d1))) // 100 pointers plus 8-byte object // Same thing with huge array containing the same element many times. Note that this won't // return exactly 4032 because it can't tell that *all* the elements will equal the first @@ -128,11 +127,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - assert(SizeEstimator.estimate(DummyString("")) === 40) - assert(SizeEstimator.estimate(DummyString("a")) === 48) - assert(SizeEstimator.estimate(DummyString("ab")) === 48) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 56) - + expectResult(40)(SizeEstimator.estimate(DummyString(""))) + expectResult(48)(SizeEstimator.estimate(DummyString("a"))) + expectResult(48)(SizeEstimator.estimate(DummyString("ab"))) + expectResult(56)(SizeEstimator.estimate(DummyString("abcdefgh"))) resetOrClear("os.arch", arch) } @@ -145,10 +143,10 @@ class SizeEstimatorSuite val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() - assert(SizeEstimator.estimate(DummyString("")) === 56) - assert(SizeEstimator.estimate(DummyString("a")) === 64) - assert(SizeEstimator.estimate(DummyString("ab")) === 64) - assert(SizeEstimator.estimate(DummyString("abcdefgh")) === 72) + expectResult(56)(SizeEstimator.estimate(DummyString(""))) + expectResult(64)(SizeEstimator.estimate(DummyString("a"))) + expectResult(64)(SizeEstimator.estimate(DummyString("ab"))) + expectResult(72)(SizeEstimator.estimate(DummyString("abcdefgh"))) resetOrClear("os.arch", arch) resetOrClear("spark.test.useCompressedOops", oops) diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala new file mode 100644 index 0000000000..b78367b6ca --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.Random +import org.scalatest.FlatSpec +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.Utils.times + +class XORShiftRandomSuite extends FunSuite with ShouldMatchers { + + def fixture = new { + val seed = 1L + val xorRand = new XORShiftRandom(seed) + val hundMil = 1e8.toInt + } + + /* + * This test is based on a chi-squared test for randomness. The values are hard-coded + * so as not to create Spark's dependency on apache.commons.math3 just to call one + * method for calculating the exact p-value for a given number of random numbers + * and bins. In case one would want to move to a full-fledged test based on + * apache.commons.math3, the relevant class is here: + * org.apache.commons.math3.stat.inference.ChiSquareTest + */ + test ("XORShift generates valid random numbers") { + + val f = fixture + + val numBins = 10 + // create 10 bins + val bins = Array.fill(numBins)(0) + + // populate bins based on modulus of the random number + times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1} + + /* since the seed is deterministic, until the algorithm is changed, we know the result will be + * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, + * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%) + * significance level. However, should the RNG implementation change, the test should still + * pass at the same significance level. The chi-squared test done in R gave the following + * results: + * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, + * 10000790, 10002286, 9998699)) + * Chi-squared test for given probabilities + * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790, + * 10002286, 9998699) + * X-squared = 11.975, df = 9, p-value = 0.2147 + * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million + * random numbers + * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared + * is greater than or equal to that number. + */ + val binSize = f.hundMil/numBins + val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum + xSquared should be < (16.9196) + + } + +}
\ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index ca3f684668..e9b62ea70d 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -1,9 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite - -class OpenHashMapSuite extends FunSuite { +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.SizeEstimator + +class OpenHashMapSuite extends FunSuite with ShouldMatchers { + + test("size for specialized, primitive value (int)") { + val capacity = 1024 + val map = new OpenHashMap[String, Int](capacity) + val actualSize = SizeEstimator.estimate(map) + // 64 bit for pointers, 32 bit for ints, and 1 bit for the bitset. + val expectedSize = capacity * (64 + 32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } test("initialization") { val goodMap1 = new OpenHashMap[String, Int](1) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 4e11e8a628..1b24f8f287 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -1,9 +1,44 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util.collection import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers + +import org.apache.spark.util.SizeEstimator + +class OpenHashSetSuite extends FunSuite with ShouldMatchers { -class OpenHashSetSuite extends FunSuite { + test("size for specialized, primitive int") { + val loadFactor = 0.7 + val set = new OpenHashSet[Int](64, loadFactor) + for (i <- 0 until 1024) { + set.add(i) + } + assert(set.size === 1024) + assert(set.capacity > 1024) + val actualSize = SizeEstimator.estimate(set) + // 32 bits for the ints + 1 bit for the bitset + val expectedSize = set.capacity * (32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } test("primitive int") { val set = new OpenHashSet[Int] diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala index dfd6aed2c4..3b60decee9 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -1,9 +1,38 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.SizeEstimator + +class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers { -class PrimitiveKeyOpenHashSetSuite extends FunSuite { + test("size for specialized, primitive key, value (int, int)") { + val capacity = 1024 + val map = new PrimitiveKeyOpenHashMap[Int, Int](capacity) + val actualSize = SizeEstimator.estimate(map) + // 32 bit for keys, 32 bit for values, and 1 bit for the bitset. + val expectedSize = capacity * (32 + 32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } test("initialization") { val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1) |