aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHarvey Feng <harvey@databricks.com>2013-11-26 15:03:03 -0800
committerHarvey Feng <harvey@databricks.com>2013-11-26 15:03:03 -0800
commitafe4fe7f5ed9e5b82b641e72bbfc14f2d952c3be (patch)
treeacf60f9e27cd400fd7dce37419cfbad59177c3cc
parenta1a1c62a3eca61ccbd50a9fb391ab9a380dd1007 (diff)
parentcb976dfb50d62b8a686b6249ff3b24c0837bd613 (diff)
downloadspark-afe4fe7f5ed9e5b82b641e72bbfc14f2d952c3be.tar.gz
spark-afe4fe7f5ed9e5b82b641e72bbfc14f2d952c3be.tar.bz2
spark-afe4fe7f5ed9e5b82b641e72bbfc14f2d952c3be.zip
Merge remote-tracking branch 'origin/master' into yarn-2.2
Conflicts: yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
-rw-r--r--conf/metrics.properties.template8
-rw-r--r--core/pom.xml4
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala25
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala40
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala82
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala126
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala10
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala41
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala60
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala21
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala9
-rw-r--r--core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala93
-rw-r--r--core/src/main/scala/org/apache/spark/util/Utils.scala24
-rw-r--r--core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala94
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala107
-rw-r--r--core/src/test/scala/org/apache/spark/CheckpointSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/JavaAPISuite.java14
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala271
-rw-r--r--core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala76
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala16
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala20
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala (renamed from core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala)14
-rw-r--r--docs/monitoring.md1
-rw-r--r--docs/running-on-yarn.md27
-rw-r--r--docs/tuning.md5
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala11
-rw-r--r--pom.xml5
-rw-r--r--project/SparkBuild.scala1
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala10
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala40
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala246
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala47
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala109
35 files changed, 1466 insertions, 218 deletions
diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template
index ae10f615d1..1c3d94e1b0 100644
--- a/conf/metrics.properties.template
+++ b/conf/metrics.properties.template
@@ -80,6 +80,14 @@
# /metrics/aplications/json # App information
# /metrics/master/json # Master information
+# org.apache.spark.metrics.sink.GraphiteSink
+# Name: Default: Description:
+# host NONE Hostname of Graphite server
+# port NONE Port of Graphite server
+# period 10 Poll period
+# unit seconds Units of poll period
+# prefix EMPTY STRING Prefix to prepend to metric name
+
## Examples
# Enable JmxSink for all instances by class name
#*.sink.jmx.class=org.apache.spark.metrics.sink.JmxSink
diff --git a/core/pom.xml b/core/pom.xml
index 8621d257e5..6af229c71d 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -159,6 +159,10 @@
<artifactId>metrics-ganglia</artifactId>
</dependency>
<dependency>
+ <groupId>com.codahale.metrics</groupId>
+ <artifactId>metrics-graphite</artifactId>
+ </dependency>
+ <dependency>
<groupId>org.apache.derby</groupId>
<artifactId>derby</artifactId>
<scope>test</scope>
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index fad54683bc..c7f60fdcaa 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -226,6 +226,31 @@ class SparkContext(
scheduler.initialize(backend)
scheduler
+ case "yarn-client" =>
+ val scheduler = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(this).asInstanceOf[ClusterScheduler]
+
+ } catch {
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+
+ val backend = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
+ val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext])
+ cons.newInstance(scheduler, this).asInstanceOf[CoarseGrainedSchedulerBackend]
+ } catch {
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+
+ scheduler.initialize(backend)
+ scheduler
+
case MESOS_REGEX(mesosUrl) =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index 043cb183ba..9f02a9b7d3 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -26,6 +26,8 @@ import org.apache.spark.storage.StorageLevel
import java.lang.Double
import org.apache.spark.Partitioner
+import scala.collection.JavaConverters._
+
class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]]
@@ -182,6 +184,44 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0,50) [50,100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = {
+ val result = srdd.histogram(bucketCount)
+ (result._1, result._2)
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1,10,20,50] the buckets are [1,10) [10,20) [20,50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1,0,0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[scala.Double]): Array[Long] = {
+ srdd.histogram(buckets, false)
+ }
+
+ def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = {
+ srdd.histogram(buckets.map(_.toDouble), evenBuckets)
+ }
}
object JavaDoubleRDD {
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 0b4892f98f..c0ce46e379 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -61,50 +61,53 @@ object TaskMetrics {
class ShuffleReadMetrics extends Serializable {
/**
- * Time when shuffle finishs
+ * Absolute time when this task finished reading shuffle data
*/
var shuffleFinishTime: Long = _
/**
- * Total number of blocks fetched in a shuffle (remote or local)
+ * Number of blocks fetched in this shuffle by this task (remote or local)
*/
var totalBlocksFetched: Int = _
/**
- * Number of remote blocks fetched in a shuffle
+ * Number of remote blocks fetched in this shuffle by this task
*/
var remoteBlocksFetched: Int = _
/**
- * Local blocks fetched in a shuffle
+ * Number of local blocks fetched in this shuffle by this task
*/
var localBlocksFetched: Int = _
/**
- * Total time that is spent blocked waiting for shuffle to fetch data
+ * Time the task spent waiting for remote shuffle blocks. This only includes the time
+ * blocking on shuffle input data. For instance if block B is being fetched while the task is
+ * still not finished processing block A, it is not considered to be blocking on block B.
*/
var fetchWaitTime: Long = _
/**
- * The total amount of time for all the shuffle fetches. This adds up time from overlapping
- * shuffles, so can be longer than task time
+ * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all
+ * input blocks. Since block fetches are both pipelined and parallelized, this can
+ * exceed fetchWaitTime and executorRunTime.
*/
var remoteFetchTime: Long = _
/**
- * Total number of remote bytes read from a shuffle
+ * Total number of remote bytes read from the shuffle by this task
*/
var remoteBytesRead: Long = _
}
class ShuffleWriteMetrics extends Serializable {
/**
- * Number of bytes written for a shuffle
+ * Number of bytes written for the shuffle by this task
*/
var shuffleBytesWritten: Long = _
/**
- * Time spent blocking on writes to disk or buffer cache, in nanoseconds.
+ * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
*/
var shuffleWriteTime: Long = _
}
diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
new file mode 100644
index 0000000000..cdcfec8ca7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.metrics.sink
+
+import java.util.Properties
+import java.util.concurrent.TimeUnit
+import java.net.InetSocketAddress
+
+import com.codahale.metrics.MetricRegistry
+import com.codahale.metrics.graphite.{GraphiteReporter, Graphite}
+
+import org.apache.spark.metrics.MetricsSystem
+
+class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink {
+ val GRAPHITE_DEFAULT_PERIOD = 10
+ val GRAPHITE_DEFAULT_UNIT = "SECONDS"
+ val GRAPHITE_DEFAULT_PREFIX = ""
+
+ val GRAPHITE_KEY_HOST = "host"
+ val GRAPHITE_KEY_PORT = "port"
+ val GRAPHITE_KEY_PERIOD = "period"
+ val GRAPHITE_KEY_UNIT = "unit"
+ val GRAPHITE_KEY_PREFIX = "prefix"
+
+ def propertyToOption(prop: String) = Option(property.getProperty(prop))
+
+ if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) {
+ throw new Exception("Graphite sink requires 'host' property.")
+ }
+
+ if (!propertyToOption(GRAPHITE_KEY_PORT).isDefined) {
+ throw new Exception("Graphite sink requires 'port' property.")
+ }
+
+ val host = propertyToOption(GRAPHITE_KEY_HOST).get
+ val port = propertyToOption(GRAPHITE_KEY_PORT).get.toInt
+
+ val pollPeriod = propertyToOption(GRAPHITE_KEY_PERIOD) match {
+ case Some(s) => s.toInt
+ case None => GRAPHITE_DEFAULT_PERIOD
+ }
+
+ val pollUnit = propertyToOption(GRAPHITE_KEY_UNIT) match {
+ case Some(s) => TimeUnit.valueOf(s.toUpperCase())
+ case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT)
+ }
+
+ val prefix = propertyToOption(GRAPHITE_KEY_PREFIX).getOrElse(GRAPHITE_DEFAULT_PREFIX)
+
+ MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod)
+
+ val graphite: Graphite = new Graphite(new InetSocketAddress(host, port))
+
+ val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry)
+ .convertDurationsTo(TimeUnit.MILLISECONDS)
+ .convertRatesTo(TimeUnit.SECONDS)
+ .prefixedWith(prefix)
+ .build(graphite)
+
+ override def start() {
+ reporter.start(pollPeriod, pollUnit)
+ }
+
+ override def stop() {
+ reporter.stop()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index a4bec41752..02d75eccc5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator
import org.apache.spark.util.StatCounter
import org.apache.spark.{TaskContext, Logging}
+import scala.collection.immutable.NumericRange
+
/**
* Extra functions available on RDDs of Doubles through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
@@ -76,4 +78,128 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
val evaluator = new SumEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
+ // Compute the minimum and the maxium
+ val (max: Double, min: Double) = self.mapPartitions { items =>
+ Iterator(items.foldRight(-1/0.0, Double.NaN)((e: Double, x: Pair[Double, Double]) =>
+ (x._1.max(e), x._2.min(e))))
+ }.reduce { (maxmin1, maxmin2) =>
+ (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
+ }
+ if (max.isNaN() || max.isInfinity || min.isInfinity ) {
+ throw new UnsupportedOperationException(
+ "Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
+ }
+ val increment = (max-min)/bucketCount.toDouble
+ val range = if (increment != 0) {
+ Range.Double.inclusive(min, max, increment)
+ } else {
+ List(min, min)
+ }
+ val buckets = range.toArray
+ (buckets, histogram(buckets, true))
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1, 0, 0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = {
+ if (buckets.length < 2) {
+ throw new IllegalArgumentException("buckets array must have at least two elements")
+ }
+ // The histogramPartition function computes the partail histogram for a given
+ // partition. The provided bucketFunction determines which bucket in the array
+ // to increment or returns None if there is no bucket. This is done so we can
+ // specialize for uniformly distributed buckets and save the O(log n) binary
+ // search cost.
+ def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]):
+ Iterator[Array[Long]] = {
+ val counters = new Array[Long](buckets.length - 1)
+ while (iter.hasNext) {
+ bucketFunction(iter.next()) match {
+ case Some(x: Int) => {counters(x) += 1}
+ case _ => {}
+ }
+ }
+ Iterator(counters)
+ }
+ // Merge the counters.
+ def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = {
+ a1.indices.foreach(i => a1(i) += a2(i))
+ a1
+ }
+ // Basic bucket function. This works using Java's built in Array
+ // binary search. Takes log(size(buckets))
+ def basicBucketFunction(e: Double): Option[Int] = {
+ val location = java.util.Arrays.binarySearch(buckets, e)
+ if (location < 0) {
+ // If the location is less than 0 then the insertion point in the array
+ // to keep it sorted is -location-1
+ val insertionPoint = -location-1
+ // If we have to insert before the first element or after the last one
+ // its out of bounds.
+ // We do this rather than buckets.lengthCompare(insertionPoint)
+ // because Array[Double] fails to override it (for now).
+ if (insertionPoint > 0 && insertionPoint < buckets.length) {
+ Some(insertionPoint-1)
+ } else {
+ None
+ }
+ } else if (location < buckets.length - 1) {
+ // Exact match, just insert here
+ Some(location)
+ } else {
+ // Exact match to the last element
+ Some(location - 1)
+ }
+ }
+ // Determine the bucket function in constant time. Requires that buckets are evenly spaced
+ def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = {
+ // If our input is not a number unless the increment is also NaN then we fail fast
+ if (e.isNaN()) {
+ return None
+ }
+ val bucketNumber = (e - min)/(increment)
+ // We do this rather than buckets.lengthCompare(bucketNumber)
+ // because Array[Double] fails to override it (for now).
+ if (bucketNumber > count || bucketNumber < 0) {
+ None
+ } else {
+ Some(bucketNumber.toInt.min(count - 1))
+ }
+ }
+ // Decide which bucket function to pass to histogramPartition. We decide here
+ // rather than having a general function so that the decission need only be made
+ // once rather than once per shard
+ val bucketFunction = if (evenBuckets) {
+ fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _
+ } else {
+ basicBucketFunction _
+ }
+ self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index 203179c4ea..ae70d55951 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -20,18 +20,16 @@ package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
-private[spark]
-class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
+private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest](
prev: RDD[T],
- f: Iterator[T] => Iterator[U],
+ f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
preservesPartitioning: Boolean = false)
extends RDD[U](prev) {
- override val partitioner =
- if (preservesPartitioning) firstParent[T].partitioner else None
+ override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
override def getPartitions: Array[Partition] = firstParent[T].partitions
override def compute(split: Partition, context: TaskContext) =
- f(firstParent[T].iterator(split, context))
+ f(context, split.index, firstParent[T].iterator(split, context))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
deleted file mode 100644
index aea08ff81b..0000000000
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
+++ /dev/null
@@ -1,41 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{Partition, TaskContext}
-
-
-/**
- * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
- * TaskContext, the closure can either get access to the interruptible flag or get the index
- * of the partition in the RDD.
- */
-private[spark]
-class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest](
- prev: RDD[T],
- f: (TaskContext, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean
- ) extends RDD[U](prev) {
-
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- override val partitioner = if (preservesPartitioning) prev.partitioner else None
-
- override def compute(split: Partition, context: TaskContext) =
- f(context, firstParent[T].iterator(split, context))
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index 6e88be6f6a..5b1285307d 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -408,7 +408,6 @@ abstract class RDD[T: ClassManifest](
def pipe(command: String, env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
-
/**
* Return an RDD created by piping elements to a forked external process.
* The print behavior can be customized by providing two functions.
@@ -442,7 +441,8 @@ abstract class RDD[T: ClassManifest](
*/
def mapPartitions[U: ClassManifest](
f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
- new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
@@ -451,8 +451,8 @@ abstract class RDD[T: ClassManifest](
*/
def mapPartitionsWithIndex[U: ClassManifest](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
- val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
- new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
@@ -462,7 +462,8 @@ abstract class RDD[T: ClassManifest](
def mapPartitionsWithContext[U: ClassManifest](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
- new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
@@ -483,11 +484,10 @@ abstract class RDD[T: ClassManifest](
def mapWith[A: ClassManifest, U: ClassManifest]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.map(t => f(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }, preservesPartitioning)
}
/**
@@ -498,11 +498,10 @@ abstract class RDD[T: ClassManifest](
def flatMapWith[A: ClassManifest, U: ClassManifest]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.flatMap(t => f(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }, preservesPartitioning)
}
/**
@@ -511,11 +510,10 @@ abstract class RDD[T: ClassManifest](
* partition with the index of that partition.
*/
def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex { (index, iter) =>
+ val a = constructA(index)
iter.map(t => {f(t, a); t})
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
+ }.foreach(_ => {})
}
/**
@@ -524,11 +522,10 @@ abstract class RDD[T: ClassManifest](
* partition with the index of that partition.
*/
def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.filter(t => p(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
+ }, preservesPartitioning = true)
}
/**
@@ -546,19 +543,34 @@ abstract class RDD[T: ClassManifest](
* of elements in each partition.
*/
def zipPartitions[B: ClassManifest, V: ClassManifest]
+ (rdd2: RDD[B], preservesPartitioning: Boolean)
+ (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, preservesPartitioning)
+
+ def zipPartitions[B: ClassManifest, V: ClassManifest]
(rdd2: RDD[B])
(f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] =
- new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2)
+ new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false)
+
+ def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]
+ (rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean)
+ (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning)
def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest]
(rdd2: RDD[B], rdd3: RDD[C])
(f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] =
- new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3)
+ new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false)
+
+ def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]
+ (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean)
+ (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
+ new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning)
def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest]
(rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D])
(f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] =
- new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4)
+ new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false)
// Actions (launch a job to return a value to the user program)
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index 31e6fd519d..faeb316664 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -39,9 +39,13 @@ private[spark] class ZippedPartitionsPartition(
abstract class ZippedPartitionsBaseRDD[V: ClassManifest](
sc: SparkContext,
- var rdds: Seq[RDD[_]])
+ var rdds: Seq[RDD[_]],
+ preservesPartitioning: Boolean = false)
extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) {
+ override val partitioner =
+ if (preservesPartitioning) firstParent[Any].partitioner else None
+
override def getPartitions: Array[Partition] = {
val sizes = rdds.map(x => x.partitions.size)
if (!sizes.forall(x => x == sizes(0))) {
@@ -76,8 +80,9 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest]
sc: SparkContext,
f: (Iterator[A], Iterator[B]) => Iterator[V],
var rdd1: RDD[A],
- var rdd2: RDD[B])
- extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) {
+ var rdd2: RDD[B],
+ preservesPartitioning: Boolean = false)
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
@@ -97,8 +102,9 @@ class ZippedPartitionsRDD3
f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
- var rdd3: RDD[C])
- extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) {
+ var rdd3: RDD[C],
+ preservesPartitioning: Boolean = false)
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
@@ -122,8 +128,9 @@ class ZippedPartitionsRDD4
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
- var rdd4: RDD[D])
- extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) {
+ var rdd4: RDD[D],
+ preservesPartitioning: Boolean = false)
+ extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) {
override def compute(s: Partition, context: TaskContext): Iterator[V] = {
val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions
diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
index 1dc71a0428..0f2deb4bcb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala
@@ -167,6 +167,7 @@ private[spark] class ShuffleMapTask(
var totalTime = 0L
val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter =>
writer.commit()
+ writer.close()
val size = writer.fileSegment().length
totalBytes += size
totalTime += writer.timeWriting()
@@ -184,14 +185,16 @@ private[spark] class ShuffleMapTask(
} catch { case e: Exception =>
// If there is an exception from running the task, revert the partial writes
// and throw the exception upstream to Spark.
- if (shuffle != null) {
- shuffle.writers.foreach(_.revertPartialWrites())
+ if (shuffle != null && shuffle.writers != null) {
+ for (writer <- shuffle.writers) {
+ writer.revertPartialWrites()
+ writer.close()
+ }
}
throw e
} finally {
// Release the writers back to the shuffle block manager.
if (shuffle != null && shuffle.writers != null) {
- shuffle.writers.foreach(_.close())
shuffle.releaseWriters(success)
}
// Execute the callbacks on task completion.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
index 469e68fed7..b4451fc7b8 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala
@@ -93,6 +93,8 @@ class DiskBlockObjectWriter(
def write(i: Int): Unit = callWithTiming(out.write(i))
override def write(b: Array[Byte]) = callWithTiming(out.write(b))
override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len))
+ override def close() = out.close()
+ override def flush() = out.flush()
}
private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean
diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
index f60deafc6f..8bb4ee3bfa 100644
--- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
+++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala
@@ -35,6 +35,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
private var capacity = nextPowerOf2(initialCapacity)
private var mask = capacity - 1
private var curSize = 0
+ private var growThreshold = LOAD_FACTOR * capacity
// Holds keys and values in the same array for memory locality; specifically, the order of
// elements is key0, value0, key1, value1, key2, value2, etc.
@@ -56,7 +57,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
var i = 1
while (true) {
val curKey = data(2 * pos)
- if (k.eq(curKey) || k == curKey) {
+ if (k.eq(curKey) || k.equals(curKey)) {
return data(2 * pos + 1).asInstanceOf[V]
} else if (curKey.eq(null)) {
return null.asInstanceOf[V]
@@ -80,9 +81,23 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
haveNullValue = true
return
}
- val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef])
- if (isNewEntry) {
- incrementSize()
+ var pos = rehash(key.hashCode) & mask
+ var i = 1
+ while (true) {
+ val curKey = data(2 * pos)
+ if (curKey.eq(null)) {
+ data(2 * pos) = k
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ incrementSize() // Since we added a new key
+ return
+ } else if (k.eq(curKey) || k.equals(curKey)) {
+ data(2 * pos + 1) = value.asInstanceOf[AnyRef]
+ return
+ } else {
+ val delta = i
+ pos = (pos + delta) & mask
+ i += 1
+ }
}
}
@@ -104,7 +119,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
var i = 1
while (true) {
val curKey = data(2 * pos)
- if (k.eq(curKey) || k == curKey) {
+ if (k.eq(curKey) || k.equals(curKey)) {
val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V])
data(2 * pos + 1) = newValue.asInstanceOf[AnyRef]
return newValue
@@ -161,45 +176,17 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
/** Increase table size by 1, rehashing if necessary */
private def incrementSize() {
curSize += 1
- if (curSize > LOAD_FACTOR * capacity) {
+ if (curSize > growThreshold) {
growTable()
}
}
/**
- * Re-hash a value to deal better with hash functions that don't differ
- * in the lower bits, similar to java.util.HashMap
+ * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
+ * We use the Murmur Hash 3 finalization step that's also used in fastutil.
*/
private def rehash(h: Int): Int = {
- val r = h ^ (h >>> 20) ^ (h >>> 12)
- r ^ (r >>> 7) ^ (r >>> 4)
- }
-
- /**
- * Put an entry into a table represented by data, returning true if
- * this increases the size of the table or false otherwise. Assumes
- * that "data" has at least one empty slot.
- */
- private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = {
- val mask = (data.length / 2) - 1
- var pos = rehash(key.hashCode) & mask
- var i = 1
- while (true) {
- val curKey = data(2 * pos)
- if (curKey.eq(null)) {
- data(2 * pos) = key
- data(2 * pos + 1) = value.asInstanceOf[AnyRef]
- return true
- } else if (curKey.eq(key) || curKey == key) {
- data(2 * pos + 1) = value.asInstanceOf[AnyRef]
- return false
- } else {
- val delta = i
- pos = (pos + delta) & mask
- i += 1
- }
- }
- return false // Never reached but needed to keep compiler happy
+ it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
}
/** Double the table's size and re-hash everything */
@@ -211,16 +198,36 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi
throw new Exception("Can't make capacity bigger than 2^29 elements")
}
val newData = new Array[AnyRef](2 * newCapacity)
- var pos = 0
- while (pos < capacity) {
- if (!data(2 * pos).eq(null)) {
- putInto(newData, data(2 * pos), data(2 * pos + 1))
+ val newMask = newCapacity - 1
+ // Insert all our old values into the new array. Note that because our old keys are
+ // unique, there's no need to check for equality here when we insert.
+ var oldPos = 0
+ while (oldPos < capacity) {
+ if (!data(2 * oldPos).eq(null)) {
+ val key = data(2 * oldPos)
+ val value = data(2 * oldPos + 1)
+ var newPos = rehash(key.hashCode) & newMask
+ var i = 1
+ var keepGoing = true
+ while (keepGoing) {
+ val curKey = newData(2 * newPos)
+ if (curKey.eq(null)) {
+ newData(2 * newPos) = key
+ newData(2 * newPos + 1) = value
+ keepGoing = false
+ } else {
+ val delta = i
+ newPos = (newPos + delta) & newMask
+ i += 1
+ }
+ }
}
- pos += 1
+ oldPos += 1
}
data = newData
capacity = newCapacity
- mask = newCapacity - 1
+ mask = newMask
+ growThreshold = LOAD_FACTOR * newCapacity
}
private def nextPowerOf2(n: Int): Int = {
diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala
index fe932d8ede..a79e64e810 100644
--- a/core/src/main/scala/org/apache/spark/util/Utils.scala
+++ b/core/src/main/scala/org/apache/spark/util/Utils.scala
@@ -823,4 +823,28 @@ private[spark] object Utils extends Logging {
return System.getProperties().clone()
.asInstanceOf[java.util.Properties].toMap[String, String]
}
+
+ /**
+ * Method executed for repeating a task for side effects.
+ * Unlike a for comprehension, it permits JVM JIT optimization
+ */
+ def times(numIters: Int)(f: => Unit): Unit = {
+ var i = 0
+ while (i < numIters) {
+ f
+ i += 1
+ }
+ }
+
+ /**
+ * Timing method based on iterations that permit JVM JIT optimization.
+ * @param numIters number of iterations
+ * @param f function to be executed
+ */
+ def timeIt(numIters: Int)(f: => Unit): Long = {
+ val start = System.currentTimeMillis
+ times(numIters)(f)
+ System.currentTimeMillis - start
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
new file mode 100644
index 0000000000..e9907e6c85
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala
@@ -0,0 +1,94 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.{Random => JavaRandom}
+import org.apache.spark.util.Utils.timeIt
+
+/**
+ * This class implements a XORShift random number generator algorithm
+ * Source:
+ * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14.
+ * @see <a href="http://www.jstatsoft.org/v08/i14/paper">Paper</a>
+ * This implementation is approximately 3.5 times faster than
+ * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due
+ * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class
+ * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG
+ * for each thread.
+ */
+private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) {
+
+ def this() = this(System.nanoTime)
+
+ private var seed = init
+
+ // we need to just override next - this will be called by nextInt, nextDouble,
+ // nextGaussian, nextLong, etc.
+ override protected def next(bits: Int): Int = {
+ var nextSeed = seed ^ (seed << 21)
+ nextSeed ^= (nextSeed >>> 35)
+ nextSeed ^= (nextSeed << 4)
+ seed = nextSeed
+ (nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
+ }
+}
+
+/** Contains benchmark method and main method to run benchmark of the RNG */
+private[spark] object XORShiftRandom {
+
+ /**
+ * Main method for running benchmark
+ * @param args takes one argument - the number of random numbers to generate
+ */
+ def main(args: Array[String]): Unit = {
+ if (args.length != 1) {
+ println("Benchmark of XORShiftRandom vis-a-vis java.util.Random")
+ println("Usage: XORShiftRandom number_of_random_numbers_to_generate")
+ System.exit(1)
+ }
+ println(benchmark(args(0).toInt))
+ }
+
+ /**
+ * @param numIters Number of random numbers to generate while running the benchmark
+ * @return Map of execution times for {@link java.util.Random java.util.Random}
+ * and XORShift
+ */
+ def benchmark(numIters: Int) = {
+
+ val seed = 1L
+ val million = 1e6.toInt
+ val javaRand = new JavaRandom(seed)
+ val xorRand = new XORShiftRandom(seed)
+
+ // this is just to warm up the JIT - we're not timing anything
+ timeIt(1e6.toInt) {
+ javaRand.nextInt()
+ xorRand.nextInt()
+ }
+
+ val iters = timeIt(numIters)(_)
+
+ /* Return results as a map instead of just printing to screen
+ in case the user wants to do something with them */
+ Map("javaTime" -> iters {javaRand.nextInt()},
+ "xorTime" -> iters {xorRand.nextInt()})
+
+ }
+
+} \ No newline at end of file
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 4592e4f939..40986e3731 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -79,6 +79,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
protected var _capacity = nextPowerOf2(initialCapacity)
protected var _mask = _capacity - 1
protected var _size = 0
+ protected var _growThreshold = (loadFactor * _capacity).toInt
protected var _bitset = new BitSet(_capacity)
@@ -115,7 +116,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
* @return The position where the key is placed, plus the highest order bit is set if the key
* exists previously.
*/
- def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
+ def addWithoutResize(k: T): Int = {
+ var pos = hashcode(hasher.hash(k)) & _mask
+ var i = 1
+ while (true) {
+ if (!_bitset.get(pos)) {
+ // This is a new key.
+ _data(pos) = k
+ _bitset.set(pos)
+ _size += 1
+ return pos | NONEXISTENCE_MASK
+ } else if (_data(pos) == k) {
+ // Found an existing key.
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & _mask
+ i += 1
+ }
+ }
+ // Never reached here
+ assert(INVALID_POS != INVALID_POS)
+ INVALID_POS
+ }
/**
* Rehash the set if it is overloaded.
@@ -126,7 +149,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
* to a new position (in the new data array).
*/
def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
- if (_size > loadFactor * _capacity) {
+ if (_size > _growThreshold) {
rehash(k, allocateFunc, moveFunc)
}
}
@@ -161,37 +184,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
/**
- * Put an entry into the set. Return the position where the key is placed. In addition, the
- * highest bit in the returned position is set if the key exists prior to this put.
- *
- * This function assumes the data array has at least one empty slot.
- */
- private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
- val mask = data.length - 1
- var pos = hashcode(hasher.hash(k)) & mask
- var i = 1
- while (true) {
- if (!bitset.get(pos)) {
- // This is a new key.
- data(pos) = k
- bitset.set(pos)
- _size += 1
- return pos | NONEXISTENCE_MASK
- } else if (data(pos) == k) {
- // Found an existing key.
- return pos
- } else {
- val delta = i
- pos = (pos + delta) & mask
- i += 1
- }
- }
- // Never reached here
- assert(INVALID_POS != INVALID_POS)
- INVALID_POS
- }
-
- /**
* Double the table's size and re-hash everything. We are not really using k, but it is declared
* so Scala compiler can specialize this method (which leads to calling the specialized version
* of putInto).
@@ -204,34 +196,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest](
*/
private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
val newCapacity = _capacity * 2
- require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
-
allocateFunc(newCapacity)
- val newData = new Array[T](newCapacity)
val newBitset = new BitSet(newCapacity)
- var pos = 0
- _size = 0
- while (pos < _capacity) {
- if (_bitset.get(pos)) {
- val newPos = putInto(newBitset, newData, _data(pos))
- moveFunc(pos, newPos & POSITION_MASK)
+ val newData = new Array[T](newCapacity)
+ val newMask = newCapacity - 1
+
+ var oldPos = 0
+ while (oldPos < capacity) {
+ if (_bitset.get(oldPos)) {
+ val key = _data(oldPos)
+ var newPos = hashcode(hasher.hash(key)) & newMask
+ var i = 1
+ var keepGoing = true
+ // No need to check for equality here when we insert so this has one less if branch than
+ // the similar code path in addWithoutResize.
+ while (keepGoing) {
+ if (!newBitset.get(newPos)) {
+ // Inserting the key at newPos
+ newData(newPos) = key
+ newBitset.set(newPos)
+ moveFunc(oldPos, newPos)
+ keepGoing = false
+ } else {
+ val delta = i
+ newPos = (newPos + delta) & newMask
+ i += 1
+ }
+ }
}
- pos += 1
+ oldPos += 1
}
+
_bitset = newBitset
_data = newData
_capacity = newCapacity
- _mask = newCapacity - 1
+ _mask = newMask
+ _growThreshold = (loadFactor * newCapacity).toInt
}
/**
- * Re-hash a value to deal better with hash functions that don't differ
- * in the lower bits, similar to java.util.HashMap
+ * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
+ * We use the Murmur Hash 3 finalization step that's also used in fastutil.
*/
- private def hashcode(h: Int): Int = {
- val r = h ^ (h >>> 20) ^ (h >>> 12)
- r ^ (r >>> 7) ^ (r >>> 4)
- }
+ private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
private def nextPowerOf2(n: Int): Int = {
val highBit = Integer.highestOneBit(n)
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index f26c44d3e7..d2226aa5a5 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -62,8 +62,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithContextRDD(r,
- (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
testCheckpointing(_.pipe(Seq("cat")))
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 352036f182..4234f6eac7 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -365,6 +365,20 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void javaDoubleRDDHistoGram() {
+ JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+ // Test using generated buckets
+ Tuple2<double[], long[]> results = rdd.histogram(2);
+ double[] expected_buckets = {1.0, 2.5, 4.0};
+ long[] expected_counts = {2, 2};
+ Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
+ Assert.assertArrayEquals(expected_counts, results._2);
+ // Test with provided buckets
+ long[] histogram = rdd.histogram(expected_buckets);
+ Assert.assertArrayEquals(expected_counts, histogram);
+ }
+
+ @Test
public void map() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
new file mode 100644
index 0000000000..7f50a5a47c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.math.abs
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import org.apache.spark._
+
+class DoubleRDDSuite extends FunSuite with SharedSparkContext {
+ // Verify tests on the histogram functionality. We test with both evenly
+ // and non-evenly spaced buckets as the bucket lookup function changes.
+ test("WorksOnEmpty") {
+ // Make sure that it works on an empty input
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithOneBucket") {
+ // Verify that if all of the elements are out of range the counts are zero
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucket") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucketExactMatch") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(1.0, 4.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoUnEvenBuckets") {
+ // Verify that out of range works with two un even buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 4.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBuckets") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBucketsAndNaN") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithFourUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithUnevenBucketsAndNaN") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket and an inifity
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 4)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithInfiniteBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN))
+ val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(1, 1)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Test the failure mode with an invalid bucket array
+ test("ThrowsExceptionOnInvalidBucketArray") {
+ val rdd = sc.parallelize(Seq(1.0))
+ // Empty array
+ intercept[IllegalArgumentException] {
+ val buckets = Array.empty[Double]
+ val result = rdd.histogram(buckets)
+ }
+ // Single element array
+ intercept[IllegalArgumentException] {
+ val buckets = Array(1.0)
+ val result = rdd.histogram(buckets)
+ }
+ }
+
+ // Test automatic histogram function
+ test("WorksWithoutBucketsBasic") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicSingleElement") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(1)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicNoRange") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 1, 1, 1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsBasicTwo") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(2)
+ val expectedHistogramResults = Array(2, 2)
+ val expectedHistogramBuckets = Array(1.0, 2.5, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsWithMoreRequestedThanElements") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2))
+ val (histogramBuckets, histogramResults) = rdd.histogram(10)
+ val expectedHistogramResults =
+ Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1)
+ val expectedHistogramBuckets =
+ Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ // Test the failure mode with an invalid RDD
+ test("ThrowsExceptionOnInvalidRDDs") {
+ // infinity
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, 1.0/0.0))
+ val result = rdd.histogram(1)
+ }
+ // NaN
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, Double.NaN))
+ val result = rdd.histogram(1)
+ }
+ // Empty
+ intercept[UnsupportedOperationException] {
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val result = rdd.histogram(1)
+ }
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
new file mode 100644
index 0000000000..b78367b6ca
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala
@@ -0,0 +1,76 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.util
+
+import java.util.Random
+import org.scalatest.FlatSpec
+import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.Utils.times
+
+class XORShiftRandomSuite extends FunSuite with ShouldMatchers {
+
+ def fixture = new {
+ val seed = 1L
+ val xorRand = new XORShiftRandom(seed)
+ val hundMil = 1e8.toInt
+ }
+
+ /*
+ * This test is based on a chi-squared test for randomness. The values are hard-coded
+ * so as not to create Spark's dependency on apache.commons.math3 just to call one
+ * method for calculating the exact p-value for a given number of random numbers
+ * and bins. In case one would want to move to a full-fledged test based on
+ * apache.commons.math3, the relevant class is here:
+ * org.apache.commons.math3.stat.inference.ChiSquareTest
+ */
+ test ("XORShift generates valid random numbers") {
+
+ val f = fixture
+
+ val numBins = 10
+ // create 10 bins
+ val bins = Array.fill(numBins)(0)
+
+ // populate bins based on modulus of the random number
+ times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1}
+
+ /* since the seed is deterministic, until the algorithm is changed, we know the result will be
+ * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+ * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%)
+ * significance level. However, should the RNG implementation change, the test should still
+ * pass at the same significance level. The chi-squared test done in R gave the following
+ * results:
+ * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272,
+ * 10000790, 10002286, 9998699))
+ * Chi-squared test for given probabilities
+ * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790,
+ * 10002286, 9998699)
+ * X-squared = 11.975, df = 9, p-value = 0.2147
+ * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million
+ * random numbers
+ * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared
+ * is greater than or equal to that number.
+ */
+ val binSize = f.hundMil/numBins
+ val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum
+ xSquared should be < (16.9196)
+
+ }
+
+} \ No newline at end of file
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index ca3f684668..63e874fed3 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -2,8 +2,20 @@ package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
-
-class OpenHashMapSuite extends FunSuite {
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
+
+class OpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive value (int)") {
+ val capacity = 1024
+ val map = new OpenHashMap[String, Int](capacity)
+ val actualSize = SizeEstimator.estimate(map)
+ // 64 bit for pointers, 32 bit for ints, and 1 bit for the bitset.
+ val expectedSize = capacity * (64 + 32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
test("initialization") {
val goodMap1 = new OpenHashMap[String, Int](1)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
index 4e11e8a628..4768a1e60b 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -1,9 +1,27 @@
package org.apache.spark.util.collection
import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
-class OpenHashSetSuite extends FunSuite {
+
+class OpenHashSetSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive int") {
+ val loadFactor = 0.7
+ val set = new OpenHashSet[Int](64, loadFactor)
+ for (i <- 0 until 1024) {
+ set.add(i)
+ }
+ assert(set.size === 1024)
+ assert(set.capacity > 1024)
+ val actualSize = SizeEstimator.estimate(set)
+ // 32 bits for the ints + 1 bit for the bitset
+ val expectedSize = set.capacity * (32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
test("primitive int") {
val set = new OpenHashSet[Int]
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
index dfd6aed2c4..2220b4f0d5 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -2,8 +2,20 @@ package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
-class PrimitiveKeyOpenHashSetSuite extends FunSuite {
+class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive key, value (int, int)") {
+ val capacity = 1024
+ val map = new PrimitiveKeyOpenHashMap[Int, Int](capacity)
+ val actualSize = SizeEstimator.estimate(map)
+ // 32 bit for keys, 32 bit for values, and 1 bit for the bitset.
+ val expectedSize = capacity * (32 + 32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
test("initialization") {
val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1)
diff --git a/docs/monitoring.md b/docs/monitoring.md
index 5f456b999b..5ed0474477 100644
--- a/docs/monitoring.md
+++ b/docs/monitoring.md
@@ -50,6 +50,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the
* `GangliaSink`: Sends metrics to a Ganglia node or multicast group.
* `JmxSink`: Registers metrics for viewing in a JXM console.
* `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data.
+* `GraphiteSink`: Sends metrics to a Graphite node.
The syntax of the metrics configuration file is defined in an example configuration file,
`$SPARK_HOME/conf/metrics.conf.template`.
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 4056e9c15d..68fd6c2ab1 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -45,6 +45,10 @@ System Properties:
Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster.
This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager.
+There are two scheduler mode that can be used to launch spark application on YARN.
+
+## Launch spark application by YARN Client with yarn-standalone mode.
+
The command to launch the YARN Client is as follows:
SPARK_JAR=<SPARK_ASSEMBLY_JAR_FILE> ./spark-class org.apache.spark.deploy.yarn.Client \
@@ -52,6 +56,7 @@ The command to launch the YARN Client is as follows:
--class <APP_MAIN_CLASS> \
--args <APP_MAIN_ARGUMENTS> \
--num-workers <NUMBER_OF_WORKER_MACHINES> \
+ --master-class <ApplicationMaster_CLASS>
--master-memory <MEMORY_FOR_MASTER> \
--worker-memory <MEMORY_PER_WORKER> \
--worker-cores <CORES_PER_WORKER> \
@@ -85,11 +90,29 @@ For example:
$ cat $YARN_APP_LOGS_DIR/$YARN_APP_ID/container*_000001/stdout
Pi is roughly 3.13794
-The above starts a YARN Client programs which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running.
+The above starts a YARN Client programs which start the default Application Master. Then SparkPi will be run as a child thread of Application Master, YARN Client will periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running.
+
+With this mode, your application is actually run on the remote machine where the Application Master is run upon. Thus application that involve local interaction will not work well, e.g. spark-shell.
+
+## Launch spark application with yarn-client mode.
+
+With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR and SPARK_YARN_APP_JAR
+
+In order to tune worker core/number/memory etc. You need to export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY, SPARK_WORKER_INSTANCES e.g. by ./conf/spark-env.sh
+
+For example:
+
+ SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \
+ SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
+ ./run-example org.apache.spark.examples.SparkPi yarn-client
+
+
+ SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \
+ SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
+ MASTER=yarn-client ./spark-shell
# Important Notes
-- When your application instantiates a Spark context it must use a special "yarn-standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above.
- We do not requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed.
- The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored.
- The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN.
diff --git a/docs/tuning.md b/docs/tuning.md
index f491ae9b95..a4be188169 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -39,7 +39,8 @@ in your operations) and performance. It provides two serialization libraries:
for best performance.
You can switch to using Kryo by calling `System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")`
-*before* creating your SparkContext. The only reason it is not the default is because of the custom
+*before* creating your SparkContext. This setting configures the serializer used for not only shuffling data between worker
+nodes but also when serializing RDDs to disk. The only reason Kryo is not the default is because of the custom
registration requirement, but we recommend trying it in any network-intensive application.
Finally, to register your classes with Kryo, create a public class that extends
@@ -67,7 +68,7 @@ The [Kryo documentation](http://code.google.com/p/kryo/) describes more advanced
registration options, such as adding custom serialization code.
If your objects are large, you may also need to increase the `spark.kryoserializer.buffer.mb`
-system property. The default is 32, but this value needs to be large enough to hold the *largest*
+system property. The default is 2, but this value needs to be large enough to hold the *largest*
object you will serialize.
Finally, if you don't register your classes, Kryo will still work, but it will have to store the
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
index edbf77dbcc..0dee9399a8 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala
@@ -18,15 +18,16 @@
package org.apache.spark.mllib.clustering
import scala.collection.mutable.ArrayBuffer
-import scala.util.Random
+
+import org.jblas.DoubleMatrix
import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._
import org.apache.spark.rdd.RDD
import org.apache.spark.Logging
import org.apache.spark.mllib.util.MLUtils
+import org.apache.spark.util.XORShiftRandom
-import org.jblas.DoubleMatrix
/**
@@ -195,7 +196,7 @@ class KMeans private (
*/
private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Sample all the cluster centers in one pass to avoid repeated scans
- val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq
+ val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq
Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray)
}
@@ -210,7 +211,7 @@ class KMeans private (
*/
private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = {
// Initialize each run's center to a random point
- val seed = new Random().nextInt()
+ val seed = new XORShiftRandom().nextInt()
val sample = data.takeSample(true, runs, seed).toSeq
val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r)))
@@ -222,7 +223,7 @@ class KMeans private (
for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point))
}.reduceByKey(_ + _).collectAsMap()
val chosen = data.mapPartitionsWithIndex { (index, points) =>
- val rand = new Random(seed ^ (step << 16) ^ index)
+ val rand = new XORShiftRandom(seed ^ (step << 16) ^ index)
for {
p <- points
r <- 0 until runs
diff --git a/pom.xml b/pom.xml
index edcc3b35cd..42c1e00e9d 100644
--- a/pom.xml
+++ b/pom.xml
@@ -347,6 +347,11 @@
<version>3.0.0</version>
</dependency>
<dependency>
+ <groupId>com.codahale.metrics</groupId>
+ <artifactId>metrics-graphite</artifactId>
+ <version>3.0.0</version>
+ </dependency>
+ <dependency>
<groupId>org.scala-lang</groupId>
<artifactId>scala-compiler</artifactId>
<version>${scala.version}</version>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 3dd7c8c962..2619c705bc 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -244,6 +244,7 @@ object SparkBuild extends Build {
"com.codahale.metrics" % "metrics-jvm" % "3.0.0",
"com.codahale.metrics" % "metrics-json" % "3.0.0",
"com.codahale.metrics" % "metrics-ganglia" % "3.0.0",
+ "com.codahale.metrics" % "metrics-graphite" % "3.0.0",
"com.twitter" % "chill_2.9.3" % "0.3.1",
"com.twitter" % "chill-java" % "0.3.1"
)
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index b3a6d2b8eb..79dd038065 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -57,10 +57,12 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
// Staging directory is private! -> rwx--------
val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short)
+
// App files are world-wide readable and owner writable -> rw-r--r--
val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short)
- def run() {
+ // for client user who want to monitor app status by itself.
+ def runApp() = {
validateArgs()
init(yarnConf)
@@ -82,7 +84,11 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
submitApp(appContext)
+ appId
+ }
+ def run() {
+ val appId = runApp()
monitorApplication(appId)
System.exit(0)
}
@@ -384,7 +390,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
val commands = List[String](javaCommand +
" -server " +
JAVA_OPTS +
- " org.apache.spark.deploy.yarn.ApplicationMaster" +
+ " " + args.amClass +
" --class " + args.userClass +
" --jar " + args.userJar +
userArgsToString(args) +
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 852dbd7dab..b9dbc3fb87 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -35,6 +35,7 @@ class ClientArguments(val args: Array[String]) {
var numWorkers = 2
var amQueue = System.getProperty("QUEUE", "default")
var amMemory: Int = 512
+ var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster"
var appName: String = "Spark"
// TODO
var inputFormatInfo: List[InputFormatInfo] = null
@@ -62,18 +63,22 @@ class ClientArguments(val args: Array[String]) {
userArgsBuffer += value
args = tail
- case ("--master-memory") :: MemoryParam(value) :: tail =>
- amMemory = value
+ case ("--master-class") :: value :: tail =>
+ amClass = value
args = tail
- case ("--num-workers") :: IntParam(value) :: tail =>
- numWorkers = value
+ case ("--master-memory") :: MemoryParam(value) :: tail =>
+ amMemory = value
args = tail
case ("--worker-memory") :: MemoryParam(value) :: tail =>
workerMemory = value
args = tail
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
case ("--worker-cores") :: IntParam(value) :: tail =>
workerCores = value
args = tail
@@ -119,19 +124,20 @@ class ClientArguments(val args: Array[String]) {
System.err.println(
"Usage: org.apache.spark.deploy.yarn.Client [options] \n" +
"Options:\n" +
- " --jar JAR_PATH Path to your application's JAR file (required)\n" +
- " --class CLASS_NAME Name of your application's main class (required)\n" +
- " --args ARGS Arguments to be passed to your application's main class.\n" +
- " Mutliple invocations are possible, each will be passed in order.\n" +
- " --num-workers NUM Number of workers to start (Default: 2)\n" +
- " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
- " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
- " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
- " --name NAME The name of your application (Default: Spark)\n" +
- " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
- " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
- " --files files Comma separated list of files to be distributed with the job.\n" +
- " --archives archives Comma separated list of archives to be distributed with the job."
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+ " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" +
+ " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+ " --name NAME The name of your application (Default: Spark)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
+ " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
+ " --files files Comma separated list of files to be distributed with the job.\n" +
+ " --archives archives Comma separated list of archives to be distributed with the job."
)
System.exit(exitCode)
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
new file mode 100644
index 0000000000..421a83c87a
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.Socket
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import akka.actor._
+import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
+import akka.remote.RemoteClientShutdown
+import akka.actor.Terminated
+import akka.remote.RemoteClientDisconnected
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.scheduler.SplitInfo
+
+class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+ def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+ private val rpc: YarnRPC = YarnRPC.create(conf)
+ private var resourceManager: AMRMProtocol = null
+ private var appAttemptId: ApplicationAttemptId = null
+ private var reporterThread: Thread = null
+ private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ private var yarnAllocator: YarnAllocationHandler = null
+ private var driverClosed:Boolean = false
+
+ val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0)._1
+ var actor: ActorRef = null
+
+ // This actor just working as a monitor to watch on Driver Actor.
+ class MonitorActor(driverUrl: String) extends Actor {
+
+ var driver: ActorRef = null
+
+ override def preStart() {
+ logInfo("Listen to driver: " + driverUrl)
+ driver = context.actorFor(driverUrl)
+ driver ! "hello"
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(driver) // Doesn't work with remote actors, but useful for testing
+ }
+
+ override def receive = {
+ case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
+ logInfo("Driver terminated or disconnected! Shutting down.")
+ driverClosed = true
+ }
+ }
+
+ def run() {
+
+ appAttemptId = getApplicationAttemptId()
+ resourceManager = registerWithResourceManager()
+ val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
+
+ // Compute number of threads for akka
+ val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
+
+ if (minimumMemory > 0) {
+ val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+ val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
+
+ if (numCore > 0) {
+ // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
+ // TODO: Uncomment when hadoop is on a version which has this fixed.
+ // args.workerCores = numCore
+ }
+ }
+
+ waitForSparkMaster()
+
+ // Allocate all containers
+ allocateWorkers()
+
+ // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout
+ // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
+
+ val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+ // must be <= timeoutInterval/ 2.
+ // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
+ // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
+ val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+ reporterThread = launchReporterThread(interval)
+
+ // Wait for the reporter thread to Finish.
+ reporterThread.join()
+
+ finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
+ actorSystem.shutdown()
+
+ logInfo("Exited")
+ System.exit(0)
+ }
+
+ private def getApplicationAttemptId(): ApplicationAttemptId = {
+ val envs = System.getenv()
+ val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+ return appAttemptId
+ }
+
+ private def registerWithResourceManager(): AMRMProtocol = {
+ val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
+ YarnConfiguration.RM_SCHEDULER_ADDRESS,
+ YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
+ logInfo("Connecting to ResourceManager at " + rmAddress)
+ return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
+ }
+
+ private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+ logInfo("Registering the ApplicationMaster")
+ val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
+ .asInstanceOf[RegisterApplicationMasterRequest]
+ appMasterRequest.setApplicationAttemptId(appAttemptId)
+ // Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
+ // Users can then monitor stderr/stdout on that node if required.
+ appMasterRequest.setHost(Utils.localHostName())
+ appMasterRequest.setRpcPort(0)
+ // What do we provide here ? Might make sense to expose something sensible later ?
+ appMasterRequest.setTrackingUrl("")
+ return resourceManager.registerApplicationMaster(appMasterRequest)
+ }
+
+ private def waitForSparkMaster() {
+ logInfo("Waiting for spark driver to be reachable.")
+ var driverUp = false
+ val hostport = args.userArgs(0)
+ val (driverHost, driverPort) = Utils.parseHostPort(hostport)
+ while(!driverUp) {
+ try {
+ val socket = new Socket(driverHost, driverPort)
+ socket.close()
+ logInfo("Master now available: " + driverHost + ":" + driverPort)
+ driverUp = true
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
+ Thread.sleep(100)
+ }
+ }
+ System.setProperty("spark.driver.host", driverHost)
+ System.setProperty("spark.driver.port", driverPort.toString)
+
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM")
+ }
+
+
+ private def allocateWorkers() {
+
+ // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now.
+ val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()
+
+ yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, preferredNodeLocationData)
+
+ logInfo("Allocating " + args.numWorkers + " workers.")
+ // Wait until all containers have finished
+ // TODO: This is a bit ugly. Can we make it nicer?
+ // TODO: Handle container failure
+ while(yarnAllocator.getNumWorkersRunning < args.numWorkers) {
+ yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
+ Thread.sleep(100)
+ }
+
+ logInfo("All workers have launched.")
+
+ }
+
+ // TODO: We might want to extend this to allocate more containers in case they die !
+ private def launchReporterThread(_sleepTime: Long): Thread = {
+ val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
+
+ val t = new Thread {
+ override def run() {
+ while (!driverClosed) {
+ val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
+ if (missingWorkerCount > 0) {
+ logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
+ yarnAllocator.allocateContainers(missingWorkerCount)
+ }
+ else sendProgress()
+ Thread.sleep(sleepTime)
+ }
+ }
+ }
+ // setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.start()
+ logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ return t
+ }
+
+ private def sendProgress() {
+ logDebug("Sending progress")
+ // simulated with an allocate request with no nodes requested ...
+ yarnAllocator.allocateContainers(0)
+ }
+
+ def finishApplicationMaster(status: FinalApplicationStatus) {
+
+ logInfo("finish ApplicationMaster with " + status)
+ val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
+ .asInstanceOf[FinishApplicationMasterRequest]
+ finishReq.setAppAttemptId(appAttemptId)
+ finishReq.setFinishApplicationStatus(status)
+ resourceManager.finishApplicationMaster(finishReq)
+ }
+
+}
+
+
+object WorkerLauncher {
+ def main(argStrings: Array[String]) {
+ val args = new ApplicationMasterArguments(argStrings)
+ new WorkerLauncher(args).run()
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
new file mode 100644
index 0000000000..63a0449e5a
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark._
+import org.apache.hadoop.conf.Configuration
+import org.apache.spark.deploy.yarn.YarnAllocationHandler
+import org.apache.spark.util.Utils
+
+/**
+ *
+ * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM.
+ */
+private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+
+ def this(sc: SparkContext) = this(sc, new Configuration())
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ val retval = YarnAllocationHandler.lookupRack(conf, host)
+ if (retval != null) Some(retval) else None
+ }
+
+ override def postStartHook() {
+
+ // The yarn application is running, but the worker might not yet ready
+ // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
+ Thread.sleep(2000L)
+ logInfo("YarnClientClusterScheduler.postStartHook done")
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
new file mode 100644
index 0000000000..b206780c78
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
+import org.apache.spark.{SparkException, Logging, SparkContext}
+import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+
+private[spark] class YarnClientSchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ with Logging {
+
+ var client: Client = null
+ var appId: ApplicationId = null
+
+ override def start() {
+ super.start()
+
+ val defalutWorkerCores = "2"
+ val defalutWorkerMemory = "512m"
+ val defaultWorkerNumber = "1"
+
+ val userJar = System.getenv("SPARK_YARN_APP_JAR")
+ var workerCores = System.getenv("SPARK_WORKER_CORES")
+ var workerMemory = System.getenv("SPARK_WORKER_MEMORY")
+ var workerNumber = System.getenv("SPARK_WORKER_INSTANCES")
+
+ if (userJar == null)
+ throw new SparkException("env SPARK_YARN_APP_JAR is not set")
+
+ if (workerCores == null)
+ workerCores = defalutWorkerCores
+ if (workerMemory == null)
+ workerMemory = defalutWorkerMemory
+ if (workerNumber == null)
+ workerNumber = defaultWorkerNumber
+
+ val driverHost = System.getProperty("spark.driver.host")
+ val driverPort = System.getProperty("spark.driver.port")
+ val hostport = driverHost + ":" + driverPort
+
+ val argsArray = Array[String](
+ "--class", "notused",
+ "--jar", userJar,
+ "--args", hostport,
+ "--worker-memory", workerMemory,
+ "--worker-cores", workerCores,
+ "--num-workers", workerNumber,
+ "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher"
+ )
+
+ val args = new ClientArguments(argsArray)
+ client = new Client(args)
+ appId = client.runApp()
+ waitForApp()
+ }
+
+ def waitForApp() {
+
+ // TODO : need a better way to find out whether the workers are ready or not
+ // maybe by resource usage report?
+ while(true) {
+ val report = client.getApplicationReport(appId)
+
+ logInfo("Application report from ASM: \n" +
+ "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+ "\t appStartTime: " + report.getStartTime() + "\n" +
+ "\t yarnAppState: " + report.getYarnApplicationState() + "\n"
+ )
+
+ // Ready to go, or already gone.
+ val state = report.getYarnApplicationState()
+ if (state == YarnApplicationState.RUNNING) {
+ return
+ } else if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ throw new SparkException("Yarn application already ended," +
+ "might be killed or not able to launch application master.")
+ }
+
+ Thread.sleep(1000)
+ }
+ }
+
+ override def stop() {
+ super.stop()
+ client.stop()
+ logInfo("Stoped")
+ }
+
+}