aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--core/src/main/scala/org/apache/spark/SparkContext.scala25
-rw-r--r--core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala40
-rw-r--r--core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala149
-rw-r--r--core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala23
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala126
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala11
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala42
-rw-r--r--core/src/main/scala/org/apache/spark/rdd/RDD.scala43
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala15
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala2
-rw-r--r--core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala1
-rw-r--r--core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala107
-rw-r--r--core/src/test/scala/org/apache/spark/CheckpointSuite.scala2
-rw-r--r--core/src/test/scala/org/apache/spark/JavaAPISuite.java14
-rw-r--r--core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala271
-rw-r--r--core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala7
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala16
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala20
-rw-r--r--core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala (renamed from core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala)14
-rw-r--r--docs/running-on-yarn.md27
-rw-r--r--docs/tuning.md3
-rw-r--r--python/epydoc.conf2
-rw-r--r--python/pyspark/accumulators.py6
-rw-r--r--python/pyspark/context.py71
-rw-r--r--python/pyspark/rdd.py97
-rw-r--r--python/pyspark/serializers.py301
-rw-r--r--python/pyspark/tests.py3
-rw-r--r--python/pyspark/worker.py44
-rwxr-xr-xpython/run-tests1
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala13
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala40
-rw-r--r--yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala246
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala47
-rw-r--r--yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala109
35 files changed, 1530 insertions, 409 deletions
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index b9fe7f604e..6fd7a0d15a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -228,6 +228,31 @@ class SparkContext(
scheduler.initialize(backend)
scheduler
+ case "yarn-client" =>
+ val scheduler = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler")
+ val cons = clazz.getConstructor(classOf[SparkContext])
+ cons.newInstance(this).asInstanceOf[ClusterScheduler]
+
+ } catch {
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+
+ val backend = try {
+ val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend")
+ val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext])
+ cons.newInstance(scheduler, this).asInstanceOf[CoarseGrainedSchedulerBackend]
+ } catch {
+ case th: Throwable => {
+ throw new SparkException("YARN mode not available ?", th)
+ }
+ }
+
+ scheduler.initialize(backend)
+ scheduler
+
case MESOS_REGEX(mesosUrl) =>
MesosNativeLibrary.load()
val scheduler = new ClusterScheduler(this)
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
index e5e20dbb66..da30cf619a 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala
@@ -29,6 +29,8 @@ import org.apache.spark.storage.StorageLevel
import java.lang.Double
import org.apache.spark.Partitioner
+import scala.collection.JavaConverters._
+
class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] {
override val classTag: ClassTag[Double] = implicitly[ClassTag[Double]]
@@ -185,6 +187,44 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav
/** (Experimental) Approximate operation to return the sum within a timeout. */
def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout)
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0,50) [50,100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = {
+ val result = srdd.histogram(bucketCount)
+ (result._1, result._2)
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1,10,20,50] the buckets are [1,10) [10,20) [20,50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1,0,0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[scala.Double]): Array[Long] = {
+ srdd.histogram(buckets, false)
+ }
+
+ def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = {
+ srdd.histogram(buckets.map(_.toDouble), evenBuckets)
+ }
}
object JavaDoubleRDD {
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 53b53df9ac..2bf7ac256e 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -28,12 +28,11 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark._
import org.apache.spark.rdd.RDD
-import org.apache.spark.rdd.PipedRDD
import org.apache.spark.util.Utils
private[spark] class PythonRDD[T: ClassTag](
parent: RDD[T],
- command: Seq[String],
+ command: Array[Byte],
envVars: JMap[String, String],
pythonIncludes: JList[String],
preservePartitoning: Boolean,
@@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassTag](
val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt
- // Similar to Runtime.exec(), if we are given a single string, split it into words
- // using a standard StringTokenizer (i.e. by spaces)
- def this(parent: RDD[T], command: String, envVars: JMap[String, String],
- pythonIncludes: JList[String],
- preservePartitoning: Boolean, pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
- accumulator: Accumulator[JList[Array[Byte]]]) =
- this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec,
- broadcastVars, accumulator)
-
override def getPartitions = parent.partitions
override val partitioner = if (preservePartitoning) parent.partitioner else None
-
override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = {
val startTime = System.currentTimeMillis
val env = SparkEnv.get
@@ -71,11 +59,10 @@ private[spark] class PythonRDD[T: ClassTag](
SparkEnv.set(env)
val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
val dataOut = new DataOutputStream(stream)
- val printOut = new PrintWriter(stream)
// Partition index
dataOut.writeInt(split.index)
// sparkFilesDir
- PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut)
+ dataOut.writeUTF(SparkFiles.getRootDirectory)
// Broadcast variables
dataOut.writeInt(broadcastVars.length)
for (broadcast <- broadcastVars) {
@@ -85,21 +72,16 @@ private[spark] class PythonRDD[T: ClassTag](
}
// Python includes (*.zip and *.egg files)
dataOut.writeInt(pythonIncludes.length)
- for (f <- pythonIncludes) {
- PythonRDD.writeAsPickle(f, dataOut)
- }
+ pythonIncludes.foreach(dataOut.writeUTF)
dataOut.flush()
- // Serialized user code
- for (elem <- command) {
- printOut.println(elem)
- }
- printOut.flush()
+ // Serialized command:
+ dataOut.writeInt(command.length)
+ dataOut.write(command)
// Data values
for (elem <- parent.iterator(split, context)) {
- PythonRDD.writeAsPickle(elem, dataOut)
+ PythonRDD.writeToStream(elem, dataOut)
}
dataOut.flush()
- printOut.flush()
worker.shutdownOutput()
} catch {
case e: IOException =>
@@ -132,7 +114,7 @@ private[spark] class PythonRDD[T: ClassTag](
val obj = new Array[Byte](length)
stream.readFully(obj)
obj
- case -3 =>
+ case SpecialLengths.TIMING_DATA =>
// Timing data from worker
val bootTime = stream.readLong()
val initTime = stream.readLong()
@@ -143,24 +125,24 @@ private[spark] class PythonRDD[T: ClassTag](
val total = finishTime - startTime
logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish))
read
- case -2 =>
+ case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
// Signals that an exception has been thrown in python
val exLength = stream.readInt()
val obj = new Array[Byte](exLength)
stream.readFully(obj)
throw new PythonException(new String(obj))
- case -1 =>
+ case SpecialLengths.END_OF_DATA_SECTION =>
// We've finished the data section of the output, but we can still
- // read some accumulator updates; let's do that, breaking when we
- // get a negative length record.
- var len2 = stream.readInt()
- while (len2 >= 0) {
- val update = new Array[Byte](len2)
+ // read some accumulator updates:
+ val numAccumulatorUpdates = stream.readInt()
+ (1 to numAccumulatorUpdates).foreach { _ =>
+ val updateLen = stream.readInt()
+ val update = new Array[Byte](updateLen)
stream.readFully(update)
accumulator += Collections.singletonList(update)
- len2 = stream.readInt()
+
}
- new Array[Byte](0)
+ Array.empty[Byte]
}
} catch {
case eof: EOFException => {
@@ -197,62 +179,15 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends
val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this)
}
-private[spark] object PythonRDD {
-
- /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */
- def stripPickle(arr: Array[Byte]) : Array[Byte] = {
- arr.slice(2, arr.length - 1)
- }
+private object SpecialLengths {
+ val END_OF_DATA_SECTION = -1
+ val PYTHON_EXCEPTION_THROWN = -2
+ val TIMING_DATA = -3
+}
- /**
- * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream.
- * The data format is a 32-bit integer representing the pickled object's length (in bytes),
- * followed by the pickled data.
- *
- * Pickle module:
- *
- * http://docs.python.org/2/library/pickle.html
- *
- * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules:
- *
- * http://hg.python.org/cpython/file/2.6/Lib/pickle.py
- * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py
- *
- * @param elem the object to write
- * @param dOut a data output stream
- */
- def writeAsPickle(elem: Any, dOut: DataOutputStream) {
- if (elem.isInstanceOf[Array[Byte]]) {
- val arr = elem.asInstanceOf[Array[Byte]]
- dOut.writeInt(arr.length)
- dOut.write(arr)
- } else if (elem.isInstanceOf[scala.Tuple2[_, _]]) {
- val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]
- val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(PythonRDD.stripPickle(t._1))
- dOut.write(PythonRDD.stripPickle(t._2))
- dOut.writeByte(Pickle.TUPLE2)
- dOut.writeByte(Pickle.STOP)
- } else if (elem.isInstanceOf[String]) {
- // For uniformity, strings are wrapped into Pickles.
- val s = elem.asInstanceOf[String].getBytes("UTF-8")
- val length = 2 + 1 + 4 + s.length + 1
- dOut.writeInt(length)
- dOut.writeByte(Pickle.PROTO)
- dOut.writeByte(Pickle.TWO)
- dOut.write(Pickle.BINUNICODE)
- dOut.writeInt(Integer.reverseBytes(s.length))
- dOut.write(s)
- dOut.writeByte(Pickle.STOP)
- } else {
- throw new SparkException("Unexpected RDD type")
- }
- }
+private[spark] object PythonRDD {
- def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) :
+ def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int):
JavaRDD[Array[Byte]] = {
val file = new DataInputStream(new FileInputStream(filename))
val objs = new collection.mutable.ArrayBuffer[Array[Byte]]
@@ -270,15 +205,32 @@ private[spark] object PythonRDD {
JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism))
}
- def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) {
+ def writeToStream(elem: Any, dataOut: DataOutputStream) {
+ elem match {
+ case bytes: Array[Byte] =>
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ case pair: (Array[Byte], Array[Byte]) =>
+ dataOut.writeInt(pair._1.length)
+ dataOut.write(pair._1)
+ dataOut.writeInt(pair._2.length)
+ dataOut.write(pair._2)
+ case str: String =>
+ dataOut.writeUTF(str)
+ case other =>
+ throw new SparkException("Unexpected element type " + other.getClass)
+ }
+ }
+
+ def writeToFile[T](items: java.util.Iterator[T], filename: String) {
import scala.collection.JavaConverters._
- writeIteratorToPickleFile(items.asScala, filename)
+ writeToFile(items.asScala, filename)
}
- def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) {
+ def writeToFile[T](items: Iterator[T], filename: String) {
val file = new DataOutputStream(new FileOutputStream(filename))
for (item <- items) {
- writeAsPickle(item, file)
+ writeToStream(item, file)
}
file.close()
}
@@ -289,17 +241,6 @@ private[spark] object PythonRDD {
}
}
-private object Pickle {
- val PROTO: Byte = 0x80.toByte
- val TWO: Byte = 0x02.toByte
- val BINUNICODE: Byte = 'X'
- val STOP: Byte = '.'
- val TUPLE2: Byte = 0x86.toByte
- val EMPTY_LIST: Byte = ']'
- val MARK: Byte = '('
- val APPENDS: Byte = 'e'
-}
-
private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] {
override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8")
}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 0b4892f98f..c0ce46e379 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -61,50 +61,53 @@ object TaskMetrics {
class ShuffleReadMetrics extends Serializable {
/**
- * Time when shuffle finishs
+ * Absolute time when this task finished reading shuffle data
*/
var shuffleFinishTime: Long = _
/**
- * Total number of blocks fetched in a shuffle (remote or local)
+ * Number of blocks fetched in this shuffle by this task (remote or local)
*/
var totalBlocksFetched: Int = _
/**
- * Number of remote blocks fetched in a shuffle
+ * Number of remote blocks fetched in this shuffle by this task
*/
var remoteBlocksFetched: Int = _
/**
- * Local blocks fetched in a shuffle
+ * Number of local blocks fetched in this shuffle by this task
*/
var localBlocksFetched: Int = _
/**
- * Total time that is spent blocked waiting for shuffle to fetch data
+ * Time the task spent waiting for remote shuffle blocks. This only includes the time
+ * blocking on shuffle input data. For instance if block B is being fetched while the task is
+ * still not finished processing block A, it is not considered to be blocking on block B.
*/
var fetchWaitTime: Long = _
/**
- * The total amount of time for all the shuffle fetches. This adds up time from overlapping
- * shuffles, so can be longer than task time
+ * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all
+ * input blocks. Since block fetches are both pipelined and parallelized, this can
+ * exceed fetchWaitTime and executorRunTime.
*/
var remoteFetchTime: Long = _
/**
- * Total number of remote bytes read from a shuffle
+ * Total number of remote bytes read from the shuffle by this task
*/
var remoteBytesRead: Long = _
}
class ShuffleWriteMetrics extends Serializable {
/**
- * Number of bytes written for a shuffle
+ * Number of bytes written for the shuffle by this task
*/
var shuffleBytesWritten: Long = _
/**
- * Time spent blocking on writes to disk or buffer cache, in nanoseconds.
+ * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds
*/
var shuffleWriteTime: Long = _
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
index a4bec41752..02d75eccc5 100644
--- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala
@@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator
import org.apache.spark.util.StatCounter
import org.apache.spark.{TaskContext, Logging}
+import scala.collection.immutable.NumericRange
+
/**
* Extra functions available on RDDs of Doubles through an implicit conversion.
* Import `org.apache.spark.SparkContext._` at the top of your program to use these functions.
@@ -76,4 +78,128 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable {
val evaluator = new SumEvaluator(self.partitions.size, confidence)
self.context.runApproximateJob(self, processPartition, evaluator, timeout)
}
+
+ /**
+ * Compute a histogram of the data using bucketCount number of buckets evenly
+ * spaced between the minimum and maximum of the RDD. For example if the min
+ * value is 0 and the max is 100 and there are two buckets the resulting
+ * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1
+ * If the RDD contains infinity, NaN throws an exception
+ * If the elements in RDD do not vary (max == min) always returns a single bucket.
+ */
+ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = {
+ // Compute the minimum and the maxium
+ val (max: Double, min: Double) = self.mapPartitions { items =>
+ Iterator(items.foldRight(-1/0.0, Double.NaN)((e: Double, x: Pair[Double, Double]) =>
+ (x._1.max(e), x._2.min(e))))
+ }.reduce { (maxmin1, maxmin2) =>
+ (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2))
+ }
+ if (max.isNaN() || max.isInfinity || min.isInfinity ) {
+ throw new UnsupportedOperationException(
+ "Histogram on either an empty RDD or RDD containing +/-infinity or NaN")
+ }
+ val increment = (max-min)/bucketCount.toDouble
+ val range = if (increment != 0) {
+ Range.Double.inclusive(min, max, increment)
+ } else {
+ List(min, min)
+ }
+ val buckets = range.toArray
+ (buckets, histogram(buckets, true))
+ }
+
+ /**
+ * Compute a histogram using the provided buckets. The buckets are all open
+ * to the left except for the last which is closed
+ * e.g. for the array
+ * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50]
+ * e.g 1<=x<10 , 10<=x<20, 20<=x<50
+ * And on the input of 1 and 50 we would have a histogram of 1, 0, 0
+ *
+ * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched
+ * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets
+ * to true.
+ * buckets must be sorted and not contain any duplicates.
+ * buckets array must be at least two elements
+ * All NaN entries are treated the same. If you have a NaN bucket it must be
+ * the maximum value of the last position and all NaN entries will be counted
+ * in that bucket.
+ */
+ def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = {
+ if (buckets.length < 2) {
+ throw new IllegalArgumentException("buckets array must have at least two elements")
+ }
+ // The histogramPartition function computes the partail histogram for a given
+ // partition. The provided bucketFunction determines which bucket in the array
+ // to increment or returns None if there is no bucket. This is done so we can
+ // specialize for uniformly distributed buckets and save the O(log n) binary
+ // search cost.
+ def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]):
+ Iterator[Array[Long]] = {
+ val counters = new Array[Long](buckets.length - 1)
+ while (iter.hasNext) {
+ bucketFunction(iter.next()) match {
+ case Some(x: Int) => {counters(x) += 1}
+ case _ => {}
+ }
+ }
+ Iterator(counters)
+ }
+ // Merge the counters.
+ def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = {
+ a1.indices.foreach(i => a1(i) += a2(i))
+ a1
+ }
+ // Basic bucket function. This works using Java's built in Array
+ // binary search. Takes log(size(buckets))
+ def basicBucketFunction(e: Double): Option[Int] = {
+ val location = java.util.Arrays.binarySearch(buckets, e)
+ if (location < 0) {
+ // If the location is less than 0 then the insertion point in the array
+ // to keep it sorted is -location-1
+ val insertionPoint = -location-1
+ // If we have to insert before the first element or after the last one
+ // its out of bounds.
+ // We do this rather than buckets.lengthCompare(insertionPoint)
+ // because Array[Double] fails to override it (for now).
+ if (insertionPoint > 0 && insertionPoint < buckets.length) {
+ Some(insertionPoint-1)
+ } else {
+ None
+ }
+ } else if (location < buckets.length - 1) {
+ // Exact match, just insert here
+ Some(location)
+ } else {
+ // Exact match to the last element
+ Some(location - 1)
+ }
+ }
+ // Determine the bucket function in constant time. Requires that buckets are evenly spaced
+ def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = {
+ // If our input is not a number unless the increment is also NaN then we fail fast
+ if (e.isNaN()) {
+ return None
+ }
+ val bucketNumber = (e - min)/(increment)
+ // We do this rather than buckets.lengthCompare(bucketNumber)
+ // because Array[Double] fails to override it (for now).
+ if (bucketNumber > count || bucketNumber < 0) {
+ None
+ } else {
+ Some(bucketNumber.toInt.min(count - 1))
+ }
+ }
+ // Decide which bucket function to pass to histogramPartition. We decide here
+ // rather than having a general function so that the decission need only be made
+ // once rather than once per shard
+ val bucketFunction = if (evenBuckets) {
+ fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _
+ } else {
+ basicBucketFunction _
+ }
+ self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters)
+ }
+
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
index cdb5946b49..db15baf503 100644
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala
@@ -20,19 +20,16 @@ package org.apache.spark.rdd
import org.apache.spark.{Partition, TaskContext}
import scala.reflect.ClassTag
-
-private[spark]
-class MapPartitionsRDD[U: ClassTag, T: ClassTag](
+private[spark] class MapPartitionsRDD[U: ClassTag, T: ClassTag](
prev: RDD[T],
- f: Iterator[T] => Iterator[U],
+ f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator)
preservesPartitioning: Boolean = false)
extends RDD[U](prev) {
- override val partitioner =
- if (preservesPartitioning) firstParent[T].partitioner else None
+ override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None
override def getPartitions: Array[Partition] = firstParent[T].partitions
override def compute(split: Partition, context: TaskContext) =
- f(firstParent[T].iterator(split, context))
+ f(context, split.index, firstParent[T].iterator(split, context))
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
deleted file mode 100644
index 67636751bb..0000000000
--- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala
+++ /dev/null
@@ -1,42 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.rdd
-
-import org.apache.spark.{Partition, TaskContext}
-import scala.reflect.ClassTag
-
-
-/**
- * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the
- * TaskContext, the closure can either get access to the interruptible flag or get the index
- * of the partition in the RDD.
- */
-private[spark]
-class MapPartitionsWithContextRDD[U: ClassTag, T: ClassTag](
- prev: RDD[T],
- f: (TaskContext, Iterator[T]) => Iterator[U],
- preservesPartitioning: Boolean
- ) extends RDD[U](prev) {
-
- override def getPartitions: Array[Partition] = firstParent[T].partitions
-
- override val partitioner = if (preservesPartitioning) prev.partitioner else None
-
- override def compute(split: Partition, context: TaskContext) =
- f(context, firstParent[T].iterator(split, context))
-}
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index da18d45e65..f80d3d601c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -411,7 +411,6 @@ abstract class RDD[T: ClassTag](
def pipe(command: String, env: Map[String, String]): RDD[String] =
new PipedRDD(this, command, env)
-
/**
* Return an RDD created by piping elements to a forked external process.
* The print behavior can be customized by providing two functions.
@@ -443,9 +442,10 @@ abstract class RDD[T: ClassTag](
/**
* Return a new RDD by applying a function to each partition of this RDD.
*/
- def mapPartitions[U: ClassTag](f: Iterator[T] => Iterator[U],
- preservesPartitioning: Boolean = false): RDD[U] = {
- new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning)
+ def mapPartitions[U: ClassTag](
+ f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
@@ -454,8 +454,8 @@ abstract class RDD[T: ClassTag](
*/
def mapPartitionsWithIndex[U: ClassTag](
f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = {
- val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter)
- new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
@@ -465,7 +465,8 @@ abstract class RDD[T: ClassTag](
def mapPartitionsWithContext[U: ClassTag](
f: (TaskContext, Iterator[T]) => Iterator[U],
preservesPartitioning: Boolean = false): RDD[U] = {
- new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning)
+ val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter)
+ new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning)
}
/**
@@ -486,11 +487,10 @@ abstract class RDD[T: ClassTag](
def mapWith[A: ClassTag, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => U): RDD[U] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.map(t => f(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }, preservesPartitioning)
}
/**
@@ -501,11 +501,10 @@ abstract class RDD[T: ClassTag](
def flatMapWith[A: ClassTag, U: ClassTag]
(constructA: Int => A, preservesPartitioning: Boolean = false)
(f: (T, A) => Seq[U]): RDD[U] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.flatMap(t => f(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning)
+ }, preservesPartitioning)
}
/**
@@ -514,11 +513,10 @@ abstract class RDD[T: ClassTag](
* partition with the index of that partition.
*/
def foreachWith[A: ClassTag](constructA: Int => A)(f: (T, A) => Unit) {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex { (index, iter) =>
+ val a = constructA(index)
iter.map(t => {f(t, a); t})
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {})
+ }.foreach(_ => {})
}
/**
@@ -527,11 +525,10 @@ abstract class RDD[T: ClassTag](
* partition with the index of that partition.
*/
def filterWith[A: ClassTag](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = {
- def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = {
- val a = constructA(context.partitionId)
+ mapPartitionsWithIndex((index, iter) => {
+ val a = constructA(index)
iter.filter(t => p(t, a))
- }
- new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true)
+ }, preservesPartitioning = true)
}
/**
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 773e9ec182..201572d16a 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -112,6 +112,9 @@ class DAGScheduler(
// resubmit failed stages
val POLL_TIMEOUT = 10L
+ // Warns the user if a stage contains a task with size greater than this value (in KB)
+ val TASK_SIZE_TO_WARN = 100
+
private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor {
override def preStart() {
import context.dispatcher
@@ -433,6 +436,18 @@ class DAGScheduler(
handleExecutorLost(execId)
case BeginEvent(task, taskInfo) =>
+ for (
+ job <- idToActiveJob.get(task.stageId);
+ stage <- stageIdToStage.get(task.stageId);
+ stageInfo <- stageToInfos.get(stage)
+ ) {
+ if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) {
+ stageInfo.emittedTaskSizeWarning = true
+ logWarning(("Stage %d (%s) contains a task of very large " +
+ "size (%d KB). The maximum recommended task size is %d KB.").format(
+ task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN))
+ }
+ }
listenerBus.post(SparkListenerTaskStart(task, taskInfo))
case GettingResultEvent(task, taskInfo) =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
index 93599dfdc8..e9f2198a00 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala
@@ -33,4 +33,5 @@ class StageInfo(
val name = stage.name
val numPartitions = stage.numPartitions
val numTasks = stage.numTasks
+ var emittedTaskSizeWarning = false
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
index 4bae26f3a6..3c22edd524 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala
@@ -46,6 +46,8 @@ class TaskInfo(
var failed = false
+ var serializedSize: Int = 0
+
def markGettingResult(time: Long = System.currentTimeMillis) {
gettingResultTime = time
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
index 4c5eca8537..8884ea85a3 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala
@@ -377,6 +377,7 @@ private[spark] class ClusterTaskSetManager(
logInfo("Serialized task %s:%d as %d bytes in %d ms".format(
taskSet.id, index, serializedTask.limit, timeTaken))
val taskName = "task %s:%d".format(taskSet.id, index)
+ info.serializedSize = serializedTask.limit
if (taskAttempts(index).size == 1)
taskStarted(task,info)
return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask))
diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
index 49d95afdb9..87e009a4de 100644
--- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
+++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala
@@ -80,6 +80,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
protected var _capacity = nextPowerOf2(initialCapacity)
protected var _mask = _capacity - 1
protected var _size = 0
+ protected var _growThreshold = (loadFactor * _capacity).toInt
protected var _bitset = new BitSet(_capacity)
@@ -116,7 +117,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
* @return The position where the key is placed, plus the highest order bit is set if the key
* exists previously.
*/
- def addWithoutResize(k: T): Int = putInto(_bitset, _data, k)
+ def addWithoutResize(k: T): Int = {
+ var pos = hashcode(hasher.hash(k)) & _mask
+ var i = 1
+ while (true) {
+ if (!_bitset.get(pos)) {
+ // This is a new key.
+ _data(pos) = k
+ _bitset.set(pos)
+ _size += 1
+ return pos | NONEXISTENCE_MASK
+ } else if (_data(pos) == k) {
+ // Found an existing key.
+ return pos
+ } else {
+ val delta = i
+ pos = (pos + delta) & _mask
+ i += 1
+ }
+ }
+ // Never reached here
+ assert(INVALID_POS != INVALID_POS)
+ INVALID_POS
+ }
/**
* Rehash the set if it is overloaded.
@@ -127,7 +150,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
* to a new position (in the new data array).
*/
def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
- if (_size > loadFactor * _capacity) {
+ if (_size > _growThreshold) {
rehash(k, allocateFunc, moveFunc)
}
}
@@ -162,37 +185,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos)
/**
- * Put an entry into the set. Return the position where the key is placed. In addition, the
- * highest bit in the returned position is set if the key exists prior to this put.
- *
- * This function assumes the data array has at least one empty slot.
- */
- private def putInto(bitset: BitSet, data: Array[T], k: T): Int = {
- val mask = data.length - 1
- var pos = hashcode(hasher.hash(k)) & mask
- var i = 1
- while (true) {
- if (!bitset.get(pos)) {
- // This is a new key.
- data(pos) = k
- bitset.set(pos)
- _size += 1
- return pos | NONEXISTENCE_MASK
- } else if (data(pos) == k) {
- // Found an existing key.
- return pos
- } else {
- val delta = i
- pos = (pos + delta) & mask
- i += 1
- }
- }
- // Never reached here
- assert(INVALID_POS != INVALID_POS)
- INVALID_POS
- }
-
- /**
* Double the table's size and re-hash everything. We are not really using k, but it is declared
* so Scala compiler can specialize this method (which leads to calling the specialized version
* of putInto).
@@ -205,34 +197,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassTag](
*/
private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) {
val newCapacity = _capacity * 2
- require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements")
-
allocateFunc(newCapacity)
- val newData = new Array[T](newCapacity)
val newBitset = new BitSet(newCapacity)
- var pos = 0
- _size = 0
- while (pos < _capacity) {
- if (_bitset.get(pos)) {
- val newPos = putInto(newBitset, newData, _data(pos))
- moveFunc(pos, newPos & POSITION_MASK)
+ val newData = new Array[T](newCapacity)
+ val newMask = newCapacity - 1
+
+ var oldPos = 0
+ while (oldPos < capacity) {
+ if (_bitset.get(oldPos)) {
+ val key = _data(oldPos)
+ var newPos = hashcode(hasher.hash(key)) & newMask
+ var i = 1
+ var keepGoing = true
+ // No need to check for equality here when we insert so this has one less if branch than
+ // the similar code path in addWithoutResize.
+ while (keepGoing) {
+ if (!newBitset.get(newPos)) {
+ // Inserting the key at newPos
+ newData(newPos) = key
+ newBitset.set(newPos)
+ moveFunc(oldPos, newPos)
+ keepGoing = false
+ } else {
+ val delta = i
+ newPos = (newPos + delta) & newMask
+ i += 1
+ }
+ }
}
- pos += 1
+ oldPos += 1
}
+
_bitset = newBitset
_data = newData
_capacity = newCapacity
- _mask = newCapacity - 1
+ _mask = newMask
+ _growThreshold = (loadFactor * newCapacity).toInt
}
/**
- * Re-hash a value to deal better with hash functions that don't differ
- * in the lower bits, similar to java.util.HashMap
+ * Re-hash a value to deal better with hash functions that don't differ in the lower bits.
+ * We use the Murmur Hash 3 finalization step that's also used in fastutil.
*/
- private def hashcode(h: Int): Int = {
- val r = h ^ (h >>> 20) ^ (h >>> 12)
- r ^ (r >>> 7) ^ (r >>> 4)
- }
+ private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h)
private def nextPowerOf2(n: Int): Int = {
val highBit = Integer.highestOneBit(n)
diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
index fcfc2c9893..f25d921d3f 100644
--- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
+++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala
@@ -63,8 +63,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging {
testCheckpointing(_.sample(false, 0.5, 0))
testCheckpointing(_.glom())
testCheckpointing(_.mapPartitions(_.map(_.toString)))
- testCheckpointing(r => new MapPartitionsWithContextRDD(r,
- (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false ))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString))
testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x))
testCheckpointing(_.pipe(Seq("cat")))
diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
index 352036f182..4234f6eac7 100644
--- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java
+++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java
@@ -365,6 +365,20 @@ public class JavaAPISuite implements Serializable {
}
@Test
+ public void javaDoubleRDDHistoGram() {
+ JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0));
+ // Test using generated buckets
+ Tuple2<double[], long[]> results = rdd.histogram(2);
+ double[] expected_buckets = {1.0, 2.5, 4.0};
+ long[] expected_counts = {2, 2};
+ Assert.assertArrayEquals(expected_buckets, results._1, 0.1);
+ Assert.assertArrayEquals(expected_counts, results._2);
+ // Test with provided buckets
+ long[] histogram = rdd.histogram(expected_buckets);
+ Assert.assertArrayEquals(expected_counts, histogram);
+ }
+
+ @Test
public void map() {
JavaRDD<Integer> rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5));
JavaDoubleRDD doubles = rdd.map(new DoubleFunction<Integer>() {
diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
new file mode 100644
index 0000000000..7f50a5a47c
--- /dev/null
+++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala
@@ -0,0 +1,271 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.rdd
+
+import scala.math.abs
+import scala.collection.mutable.ArrayBuffer
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.SparkContext._
+import org.apache.spark.rdd._
+import org.apache.spark._
+
+class DoubleRDDSuite extends FunSuite with SharedSparkContext {
+ // Verify tests on the histogram functionality. We test with both evenly
+ // and non-evenly spaced buckets as the bucket lookup function changes.
+ test("WorksOnEmpty") {
+ // Make sure that it works on an empty input
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithOneBucket") {
+ // Verify that if all of the elements are out of range the counts are zero
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucket") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(0.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithOneBucketExactMatch") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val buckets = Array(1.0, 4.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(4)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithTwoUnEvenBuckets") {
+ // Verify that out of range works with two un even buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01))
+ val buckets = Array(0.0, 4.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(0, 0)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBuckets") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoBucketsAndNaN") {
+ // Make sure that it works with two equally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN))
+ val buckets = Array(0.0, 5.0, 10.0)
+ val histogramResults = rdd.histogram(buckets)
+ val histogramResults2 = rdd.histogram(buckets, true)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramResults2 === expectedHistogramResults)
+ }
+
+ test("WorksInRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(3, 2)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithTwoUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01))
+ val buckets = Array(0.0, 5.0, 11.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithFourUnevenBuckets") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksMixedRangeWithUnevenBucketsAndNaN") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 3)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Make sure this works with a NaN end bucket and an inifity
+ test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") {
+ // Make sure that it works with two unequally spaced buckets and elements in each
+ val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0,
+ 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN))
+ val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(4, 2, 1, 2, 4)
+ assert(histogramResults === expectedHistogramResults)
+ }
+
+ test("WorksWithOutOfRangeWithInfiniteBuckets") {
+ // Verify that out of range works with two buckets
+ val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN))
+ val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0)
+ val histogramResults = rdd.histogram(buckets)
+ val expectedHistogramResults = Array(1, 1)
+ assert(histogramResults === expectedHistogramResults)
+ }
+ // Test the failure mode with an invalid bucket array
+ test("ThrowsExceptionOnInvalidBucketArray") {
+ val rdd = sc.parallelize(Seq(1.0))
+ // Empty array
+ intercept[IllegalArgumentException] {
+ val buckets = Array.empty[Double]
+ val result = rdd.histogram(buckets)
+ }
+ // Single element array
+ intercept[IllegalArgumentException] {
+ val buckets = Array(1.0)
+ val result = rdd.histogram(buckets)
+ }
+ }
+
+ // Test automatic histogram function
+ test("WorksWithoutBucketsBasic") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicSingleElement") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(1)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+ // Test automatic histogram function with a single element
+ test("WorksWithoutBucketsBasicNoRange") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 1, 1, 1))
+ val (histogramBuckets, histogramResults) = rdd.histogram(1)
+ val expectedHistogramResults = Array(4)
+ val expectedHistogramBuckets = Array(1.0, 1.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsBasicTwo") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2, 3, 4))
+ val (histogramBuckets, histogramResults) = rdd.histogram(2)
+ val expectedHistogramResults = Array(2, 2)
+ val expectedHistogramBuckets = Array(1.0, 2.5, 4.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ test("WorksWithoutBucketsWithMoreRequestedThanElements") {
+ // Verify the basic case of one bucket and all elements in that bucket works
+ val rdd = sc.parallelize(Seq(1, 2))
+ val (histogramBuckets, histogramResults) = rdd.histogram(10)
+ val expectedHistogramResults =
+ Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1)
+ val expectedHistogramBuckets =
+ Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0)
+ assert(histogramResults === expectedHistogramResults)
+ assert(histogramBuckets === expectedHistogramBuckets)
+ }
+
+ // Test the failure mode with an invalid RDD
+ test("ThrowsExceptionOnInvalidRDDs") {
+ // infinity
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, 1.0/0.0))
+ val result = rdd.histogram(1)
+ }
+ // NaN
+ intercept[UnsupportedOperationException] {
+ val rdd = sc.parallelize(Seq(1, Double.NaN))
+ val result = rdd.histogram(1)
+ }
+ // Empty
+ intercept[UnsupportedOperationException] {
+ val rdd: RDD[Double] = sc.parallelize(Seq())
+ val result = rdd.histogram(1)
+ }
+ }
+
+}
diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
index 984881861c..002368ff55 100644
--- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
+++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD
class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers {
+ val WAIT_TIMEOUT_MILLIS = 10000
test("inner method") {
sc = new SparkContext("local", "joblogger")
@@ -92,6 +93,8 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER)
joblogger.getLogDir should be ("/tmp/spark-%s".format(user))
@@ -120,7 +123,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers
sc.addSparkListener(joblogger)
val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) }
rdd.reduceByKey(_+_).collect()
-
+
+ assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS))
+
joblogger.onJobStartCount should be (1)
joblogger.onJobEndCount should be (1)
joblogger.onTaskEndCount should be (8)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
index ca3f684668..63e874fed3 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala
@@ -2,8 +2,20 @@ package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
-
-class OpenHashMapSuite extends FunSuite {
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
+
+class OpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive value (int)") {
+ val capacity = 1024
+ val map = new OpenHashMap[String, Int](capacity)
+ val actualSize = SizeEstimator.estimate(map)
+ // 64 bit for pointers, 32 bit for ints, and 1 bit for the bitset.
+ val expectedSize = capacity * (64 + 32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
test("initialization") {
val goodMap1 = new OpenHashMap[String, Int](1)
diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
index 4e11e8a628..4768a1e60b 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala
@@ -1,9 +1,27 @@
package org.apache.spark.util.collection
import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
-class OpenHashSetSuite extends FunSuite {
+
+class OpenHashSetSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive int") {
+ val loadFactor = 0.7
+ val set = new OpenHashSet[Int](64, loadFactor)
+ for (i <- 0 until 1024) {
+ set.add(i)
+ }
+ assert(set.size === 1024)
+ assert(set.capacity > 1024)
+ val actualSize = SizeEstimator.estimate(set)
+ // 32 bits for the ints + 1 bit for the bitset
+ val expectedSize = set.capacity * (32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
test("primitive int") {
val set = new OpenHashSet[Int]
diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
index dfd6aed2c4..2220b4f0d5 100644
--- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala
+++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala
@@ -2,8 +2,20 @@ package org.apache.spark.util.collection
import scala.collection.mutable.HashSet
import org.scalatest.FunSuite
+import org.scalatest.matchers.ShouldMatchers
+import org.apache.spark.util.SizeEstimator
-class PrimitiveKeyOpenHashSetSuite extends FunSuite {
+class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers {
+
+ test("size for specialized, primitive key, value (int, int)") {
+ val capacity = 1024
+ val map = new PrimitiveKeyOpenHashMap[Int, Int](capacity)
+ val actualSize = SizeEstimator.estimate(map)
+ // 32 bit for keys, 32 bit for values, and 1 bit for the bitset.
+ val expectedSize = capacity * (32 + 32 + 1) / 8
+ // Make sure we are not allocating a significant amount of memory beyond our expected.
+ actualSize should be <= (expectedSize * 1.1).toLong
+ }
test("initialization") {
val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1)
diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md
index 4056e9c15d..68fd6c2ab1 100644
--- a/docs/running-on-yarn.md
+++ b/docs/running-on-yarn.md
@@ -45,6 +45,10 @@ System Properties:
Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster.
This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager.
+There are two scheduler mode that can be used to launch spark application on YARN.
+
+## Launch spark application by YARN Client with yarn-standalone mode.
+
The command to launch the YARN Client is as follows:
SPARK_JAR=<SPARK_ASSEMBLY_JAR_FILE> ./spark-class org.apache.spark.deploy.yarn.Client \
@@ -52,6 +56,7 @@ The command to launch the YARN Client is as follows:
--class <APP_MAIN_CLASS> \
--args <APP_MAIN_ARGUMENTS> \
--num-workers <NUMBER_OF_WORKER_MACHINES> \
+ --master-class <ApplicationMaster_CLASS>
--master-memory <MEMORY_FOR_MASTER> \
--worker-memory <MEMORY_PER_WORKER> \
--worker-cores <CORES_PER_WORKER> \
@@ -85,11 +90,29 @@ For example:
$ cat $YARN_APP_LOGS_DIR/$YARN_APP_ID/container*_000001/stdout
Pi is roughly 3.13794
-The above starts a YARN Client programs which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running.
+The above starts a YARN Client programs which start the default Application Master. Then SparkPi will be run as a child thread of Application Master, YARN Client will periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running.
+
+With this mode, your application is actually run on the remote machine where the Application Master is run upon. Thus application that involve local interaction will not work well, e.g. spark-shell.
+
+## Launch spark application with yarn-client mode.
+
+With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR and SPARK_YARN_APP_JAR
+
+In order to tune worker core/number/memory etc. You need to export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY, SPARK_WORKER_INSTANCES e.g. by ./conf/spark-env.sh
+
+For example:
+
+ SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \
+ SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
+ ./run-example org.apache.spark.examples.SparkPi yarn-client
+
+
+ SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \
+ SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \
+ MASTER=yarn-client ./spark-shell
# Important Notes
-- When your application instantiates a Spark context it must use a special "yarn-standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above.
- We do not requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed.
- The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored.
- The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN.
diff --git a/docs/tuning.md b/docs/tuning.md
index f33fda37eb..a4be188169 100644
--- a/docs/tuning.md
+++ b/docs/tuning.md
@@ -39,7 +39,8 @@ in your operations) and performance. It provides two serialization libraries:
for best performance.
You can switch to using Kryo by calling `System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")`
-*before* creating your SparkContext. The only reason it is not the default is because of the custom
+*before* creating your SparkContext. This setting configures the serializer used for not only shuffling data between worker
+nodes but also when serializing RDDs to disk. The only reason Kryo is not the default is because of the custom
registration requirement, but we recommend trying it in any network-intensive application.
Finally, to register your classes with Kryo, create a public class that extends
diff --git a/python/epydoc.conf b/python/epydoc.conf
index 1d0d002d36..0b42e729f8 100644
--- a/python/epydoc.conf
+++ b/python/epydoc.conf
@@ -32,6 +32,6 @@ target: docs/
private: no
-exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers
+exclude: pyspark.cloudpickle pyspark.worker pyspark.join
pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test
pyspark.rddsampler pyspark.daemon
diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py
index da3d96689a..2204e9c9ca 100644
--- a/python/pyspark/accumulators.py
+++ b/python/pyspark/accumulators.py
@@ -90,9 +90,11 @@ import struct
import SocketServer
import threading
from pyspark.cloudpickle import CloudPickler
-from pyspark.serializers import read_int, read_with_length, load_pickle
+from pyspark.serializers import read_int, PickleSerializer
+pickleSer = PickleSerializer()
+
# Holds accumulators registered on the current machine, keyed by ID. This is then used to send
# the local accumulator updates back to the driver program at the end of a task.
_accumulatorRegistry = {}
@@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler):
from pyspark.accumulators import _accumulatorRegistry
num_updates = read_int(self.rfile)
for _ in range(num_updates):
- (aid, update) = load_pickle(read_with_length(self.rfile))
+ (aid, update) = pickleSer._read_with_length(self.rfile)
_accumulatorRegistry[aid] += update
# Write a byte in acknowledgement
self.wfile.write(struct.pack("!b", 1))
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index a7ca8bc888..cbd41e58c4 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator
from pyspark.broadcast import Broadcast
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
-from pyspark.serializers import dump_pickle, write_with_length, batched
+from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
@@ -42,7 +42,7 @@ class SparkContext(object):
_gateway = None
_jvm = None
- _writeIteratorToPickleFile = None
+ _writeToFile = None
_takePartition = None
_next_accum_id = 0
_active_spark_context = None
@@ -51,7 +51,7 @@ class SparkContext(object):
def __init__(self, master, jobName, sparkHome=None, pyFiles=None,
- environment=None, batchSize=1024):
+ environment=None, batchSize=1024, serializer=PickleSerializer()):
"""
Create a new SparkContext.
@@ -67,6 +67,7 @@ class SparkContext(object):
@param batchSize: The number of Python objects represented as a single
Java object. Set 1 to disable batching or -1 to use an
unlimited batch size.
+ @param serializer: The serializer for RDDs.
>>> from pyspark.context import SparkContext
@@ -83,7 +84,13 @@ class SparkContext(object):
self.jobName = jobName
self.sparkHome = sparkHome or None # None becomes null in Py4J
self.environment = environment or {}
- self.batchSize = batchSize # -1 represents a unlimited batch size
+ self._batchSize = batchSize # -1 represents an unlimited batch size
+ self._unbatched_serializer = serializer
+ if batchSize == 1:
+ self.serializer = self._unbatched_serializer
+ else:
+ self.serializer = BatchedSerializer(self._unbatched_serializer,
+ batchSize)
# Create the Java SparkContext through Py4J
empty_string_array = self._gateway.new_array(self._jvm.String, 0)
@@ -125,8 +132,8 @@ class SparkContext(object):
if not SparkContext._gateway:
SparkContext._gateway = launch_gateway()
SparkContext._jvm = SparkContext._gateway.jvm
- SparkContext._writeIteratorToPickleFile = \
- SparkContext._jvm.PythonRDD.writeIteratorToPickleFile
+ SparkContext._writeToFile = \
+ SparkContext._jvm.PythonRDD.writeToFile
SparkContext._takePartition = \
SparkContext._jvm.PythonRDD.takePartition
@@ -184,15 +191,17 @@ class SparkContext(object):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
- batchSize = min(len(c) // numSlices, self.batchSize)
+ batchSize = min(len(c) // numSlices, self._batchSize)
if batchSize > 1:
- c = batched(c, batchSize)
- for x in c:
- write_with_length(dump_pickle(x), tempFile)
+ serializer = BatchedSerializer(self._unbatched_serializer,
+ batchSize)
+ else:
+ serializer = self._unbatched_serializer
+ serializer.dump_stream(c, tempFile)
tempFile.close()
- readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile
- jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices)
- return RDD(jrdd, self)
+ readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
+ jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices)
+ return RDD(jrdd, self, serializer)
def textFile(self, name, minSplits=None):
"""
@@ -201,21 +210,39 @@ class SparkContext(object):
RDD of Strings.
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
- jrdd = self._jsc.textFile(name, minSplits)
- return RDD(jrdd, self)
+ return RDD(self._jsc.textFile(name, minSplits), self,
+ MUTF8Deserializer())
- def _checkpointFile(self, name):
+ def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
- return RDD(jrdd, self)
+ return RDD(jrdd, self, input_deserializer)
def union(self, rdds):
"""
Build the union of a list of RDDs.
+
+ This supports unions() of RDDs with different serialized formats,
+ although this forces them to be reserialized using the default
+ serializer:
+
+ >>> path = os.path.join(tempdir, "union-text.txt")
+ >>> with open(path, "w") as testFile:
+ ... testFile.write("Hello")
+ >>> textFile = sc.textFile(path)
+ >>> textFile.collect()
+ [u'Hello']
+ >>> parallelized = sc.parallelize(["World!"])
+ >>> sorted(sc.union([textFile, parallelized]).collect())
+ [u'Hello', 'World!']
"""
+ first_jrdd_deserializer = rdds[0]._jrdd_deserializer
+ if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds):
+ rdds = [x._reserialize() for x in rdds]
first = rdds[0]._jrdd
rest = [x._jrdd for x in rdds[1:]]
- rest = ListConverter().convert(rest, self.gateway._gateway_client)
- return RDD(self._jsc.union(first, rest), self)
+ rest = ListConverter().convert(rest, self._gateway._gateway_client)
+ return RDD(self._jsc.union(first, rest), self,
+ rdds[0]._jrdd_deserializer)
def broadcast(self, value):
"""
@@ -223,7 +250,9 @@ class SparkContext(object):
object for reading it in distributed functions. The variable will be
sent to each cluster only once.
"""
- jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value)))
+ pickleSer = PickleSerializer()
+ pickled = pickleSer.dumps(value)
+ jbroadcast = self._jsc.broadcast(bytearray(pickled))
return Broadcast(jbroadcast.id(), value, jbroadcast,
self._pickled_broadcast_vars)
@@ -235,7 +264,7 @@ class SparkContext(object):
and floating-point numbers if you do not provide one. For other types,
a custom AccumulatorParam can be used.
"""
- if accum_param == None:
+ if accum_param is None:
if isinstance(value, int):
accum_param = accumulators.INT_ACCUMULATOR_PARAM
elif isinstance(value, float):
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 245a132dfd..d2cb5f191a 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -18,7 +18,7 @@
from base64 import standard_b64encode as b64enc
import copy
from collections import defaultdict
-from itertools import chain, ifilter, imap, product
+from itertools import chain, ifilter, imap
import operator
import os
import sys
@@ -27,9 +27,8 @@ from subprocess import Popen, PIPE
from tempfile import NamedTemporaryFile
from threading import Thread
-from pyspark import cloudpickle
-from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \
- read_from_pickle_file, pack_long
+from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
+ BatchedSerializer, CloudPickleSerializer, pack_long
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
@@ -48,12 +47,12 @@ class RDD(object):
operated on in parallel.
"""
- def __init__(self, jrdd, ctx):
+ def __init__(self, jrdd, ctx, jrdd_deserializer):
self._jrdd = jrdd
self.is_cached = False
self.is_checkpointed = False
self.ctx = ctx
- self._partitionFunc = None
+ self._jrdd_deserializer = jrdd_deserializer
@property
def context(self):
@@ -247,7 +246,23 @@ class RDD(object):
>>> rdd.union(rdd).collect()
[1, 1, 2, 3, 1, 1, 2, 3]
"""
- return RDD(self._jrdd.union(other._jrdd), self.ctx)
+ if self._jrdd_deserializer == other._jrdd_deserializer:
+ rdd = RDD(self._jrdd.union(other._jrdd), self.ctx,
+ self._jrdd_deserializer)
+ return rdd
+ else:
+ # These RDDs contain data in different serialized formats, so we
+ # must normalize them to the default serializer.
+ self_copy = self._reserialize()
+ other_copy = other._reserialize()
+ return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx,
+ self.ctx.serializer)
+
+ def _reserialize(self):
+ if self._jrdd_deserializer == self.ctx.serializer:
+ return self
+ else:
+ return self.map(lambda x: x, preservesPartitioning=True)
def __add__(self, other):
"""
@@ -334,17 +349,9 @@ class RDD(object):
[(1, 1), (1, 2), (2, 1), (2, 2)]
"""
# Due to batching, we can't use the Java cartesian method.
- java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx)
- def unpack_batches(pair):
- (x, y) = pair
- if type(x) == Batch or type(y) == Batch:
- xs = x.items if type(x) == Batch else [x]
- ys = y.items if type(y) == Batch else [y]
- for pair in product(xs, ys):
- yield pair
- else:
- yield pair
- return java_cartesian.flatMap(unpack_batches)
+ deserializer = CartesianDeserializer(self._jrdd_deserializer,
+ other._jrdd_deserializer)
+ return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer)
def groupBy(self, f, numPartitions=None):
"""
@@ -391,8 +398,8 @@ class RDD(object):
"""
Return a list that contains all of the elements in this RDD.
"""
- picklesInJava = self._jrdd.collect().iterator()
- return list(self._collect_iterator_through_file(picklesInJava))
+ bytesInJava = self._jrdd.collect().iterator()
+ return list(self._collect_iterator_through_file(bytesInJava))
def _collect_iterator_through_file(self, iterator):
# Transferring lots of data through Py4J can be slow because
@@ -400,10 +407,10 @@ class RDD(object):
# file and read it back.
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir)
tempFile.close()
- self.ctx._writeIteratorToPickleFile(iterator, tempFile.name)
+ self.ctx._writeToFile(iterator, tempFile.name)
# Read the data into Python and deserialize it:
with open(tempFile.name, 'rb') as tempFile:
- for item in read_from_pickle_file(tempFile):
+ for item in self._jrdd_deserializer.load_stream(tempFile):
yield item
os.unlink(tempFile.name)
@@ -571,7 +578,7 @@ class RDD(object):
items = []
for partition in range(mapped._jrdd.splits().size()):
iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition)
- items.extend(self._collect_iterator_through_file(iterator))
+ items.extend(mapped._collect_iterator_through_file(iterator))
if len(items) >= num:
break
return items[:num]
@@ -735,6 +742,7 @@ class RDD(object):
# Transferring O(n) objects to Java is too expensive. Instead, we'll
# form the hash buckets in Python, transferring O(numPartitions) objects
# to Java. Each object is a (splitNumber, [objects]) pair.
+ outputSerializer = self.ctx._unbatched_serializer
def add_shuffle_key(split, iterator):
buckets = defaultdict(list)
@@ -743,14 +751,14 @@ class RDD(object):
buckets[partitionFunc(k) % numPartitions].append((k, v))
for (split, items) in buckets.iteritems():
yield pack_long(split)
- yield dump_pickle(Batch(items))
+ yield outputSerializer.dumps(items)
keyed = PipelinedRDD(self, add_shuffle_key)
keyed._bypass_serializer = True
pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD()
partitioner = self.ctx._jvm.PythonPartitioner(numPartitions,
id(partitionFunc))
jrdd = pairRDD.partitionBy(partitioner).values()
- rdd = RDD(jrdd, self.ctx)
+ rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer))
# This is required so that id(partitionFunc) remains unique, even if
# partitionFunc is a lambda:
rdd._partitionFunc = partitionFunc
@@ -787,7 +795,8 @@ class RDD(object):
numPartitions = self.ctx.defaultParallelism
def combineLocally(iterator):
combiners = {}
- for (k, v) in iterator:
+ for x in iterator:
+ (k, v) = x
if k not in combiners:
combiners[k] = createCombiner(v)
else:
@@ -929,38 +938,39 @@ class PipelinedRDD(RDD):
20
"""
def __init__(self, prev, func, preservesPartitioning=False):
- if isinstance(prev, PipelinedRDD) and prev._is_pipelinable():
+ if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable():
+ # This transformation is the first in its stage:
+ self.func = func
+ self.preservesPartitioning = preservesPartitioning
+ self._prev_jrdd = prev._jrdd
+ self._prev_jrdd_deserializer = prev._jrdd_deserializer
+ else:
prev_func = prev.func
def pipeline_func(split, iterator):
return func(split, prev_func(split, iterator))
self.func = pipeline_func
self.preservesPartitioning = \
prev.preservesPartitioning and preservesPartitioning
- self._prev_jrdd = prev._prev_jrdd
- else:
- self.func = func
- self.preservesPartitioning = preservesPartitioning
- self._prev_jrdd = prev._jrdd
+ self._prev_jrdd = prev._prev_jrdd # maintain the pipeline
+ self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer
self.is_cached = False
self.is_checkpointed = False
self.ctx = prev.ctx
self.prev = prev
self._jrdd_val = None
+ self._jrdd_deserializer = self.ctx.serializer
self._bypass_serializer = False
@property
def _jrdd(self):
if self._jrdd_val:
return self._jrdd_val
- func = self.func
- if not self._bypass_serializer and self.ctx.batchSize != 1:
- oldfunc = self.func
- batchSize = self.ctx.batchSize
- def batched_func(split, iterator):
- return batched(oldfunc(split, iterator), batchSize)
- func = batched_func
- cmds = [func, self._bypass_serializer]
- pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds)
+ if self._bypass_serializer:
+ serializer = NoOpSerializer()
+ else:
+ serializer = self.ctx.serializer
+ command = (self.func, self._prev_jrdd_deserializer, serializer)
+ pickled_command = CloudPickleSerializer().dumps(command)
broadcast_vars = ListConverter().convert(
[x._jbroadcast for x in self.ctx._pickled_broadcast_vars],
self.ctx._gateway._gateway_client)
@@ -971,8 +981,9 @@ class PipelinedRDD(RDD):
includes = ListConverter().convert(self.ctx._python_includes,
self.ctx._gateway._gateway_client)
python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(),
- pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec,
- broadcast_vars, self.ctx._javaAccumulator, class_tag)
+ bytearray(pickled_command), env, includes, self.preservesPartitioning,
+ self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator,
+ class_tag)
self._jrdd_val = python_rdd.asJavaRDD()
return self._jrdd_val
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 54fed1c9c7..811fa6f018 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -15,45 +15,269 @@
# limitations under the License.
#
-import struct
+"""
+PySpark supports custom serializers for transferring data; this can improve
+performance.
+
+By default, PySpark uses L{PickleSerializer} to serialize objects using Python's
+C{cPickle} serializer, which can serialize nearly any Python object.
+Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be
+faster.
+
+The serializer is chosen when creating L{SparkContext}:
+
+>>> from pyspark.context import SparkContext
+>>> from pyspark.serializers import MarshalSerializer
+>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer())
+>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10)
+[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
+>>> sc.stop()
+
+By default, PySpark serialize objects in batches; the batch size can be
+controlled through SparkContext's C{batchSize} parameter
+(the default size is 1024 objects):
+
+>>> sc = SparkContext('local', 'test', batchSize=2)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+
+Behind the scenes, this creates a JavaRDD with four partitions, each of
+which contains two batches of two objects:
+
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+8L
+>>> sc.stop()
+
+A batch size of -1 uses an unlimited batch size, and a size of 1 disables
+batching:
+
+>>> sc = SparkContext('local', 'test', batchSize=1)
+>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
+>>> rdd.glom().collect()
+[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
+>>> rdd._jrdd.count()
+16L
+"""
+
import cPickle
+from itertools import chain, izip, product
+import marshal
+import struct
+from pyspark import cloudpickle
+
+
+__all__ = ["PickleSerializer", "MarshalSerializer"]
+
+
+class SpecialLengths(object):
+ END_OF_DATA_SECTION = -1
+ PYTHON_EXCEPTION_THROWN = -2
+ TIMING_DATA = -3
+
+
+class Serializer(object):
+
+ def dump_stream(self, iterator, stream):
+ """
+ Serialize an iterator of objects to the output stream.
+ """
+ raise NotImplementedError
+
+ def load_stream(self, stream):
+ """
+ Return an iterator of deserialized objects from the input stream.
+ """
+ raise NotImplementedError
+
+
+ def _load_stream_without_unbatching(self, stream):
+ return self.load_stream(stream)
+
+ # Note: our notion of "equality" is that output generated by
+ # equal serializers can be deserialized using the same serializer.
+
+ # This default implementation handles the simple cases;
+ # subclasses should override __eq__ as appropriate.
+
+ def __eq__(self, other):
+ return isinstance(other, self.__class__)
+
+ def __ne__(self, other):
+ return not self.__eq__(other)
+
+
+class FramedSerializer(Serializer):
+ """
+ Serializer that writes objects as a stream of (length, data) pairs,
+ where C{length} is a 32-bit integer and data is C{length} bytes.
+ """
+
+ def dump_stream(self, iterator, stream):
+ for obj in iterator:
+ self._write_with_length(obj, stream)
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self._read_with_length(stream)
+ except EOFError:
+ return
+
+ def _write_with_length(self, obj, stream):
+ serialized = self.dumps(obj)
+ write_int(len(serialized), stream)
+ stream.write(serialized)
+
+ def _read_with_length(self, stream):
+ length = read_int(stream)
+ obj = stream.read(length)
+ if obj == "":
+ raise EOFError
+ return self.loads(obj)
+
+ def dumps(self, obj):
+ """
+ Serialize an object into a byte array.
+ When batching is used, this will be called with an array of objects.
+ """
+ raise NotImplementedError
+
+ def loads(self, obj):
+ """
+ Deserialize an object from a byte array.
+ """
+ raise NotImplementedError
+
+
+class BatchedSerializer(Serializer):
+ """
+ Serializes a stream of objects in batches by calling its wrapped
+ Serializer with streams of objects.
+ """
+
+ UNLIMITED_BATCH_SIZE = -1
+
+ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
+ self.serializer = serializer
+ self.batchSize = batchSize
+
+ def _batched(self, iterator):
+ if self.batchSize == self.UNLIMITED_BATCH_SIZE:
+ yield list(iterator)
+ else:
+ items = []
+ count = 0
+ for item in iterator:
+ items.append(item)
+ count += 1
+ if count == self.batchSize:
+ yield items
+ items = []
+ count = 0
+ if items:
+ yield items
+
+ def dump_stream(self, iterator, stream):
+ self.serializer.dump_stream(self._batched(iterator), stream)
+
+ def load_stream(self, stream):
+ return chain.from_iterable(self._load_stream_without_unbatching(stream))
+
+ def _load_stream_without_unbatching(self, stream):
+ return self.serializer.load_stream(stream)
+
+ def __eq__(self, other):
+ return isinstance(other, BatchedSerializer) and \
+ other.serializer == self.serializer
+
+ def __str__(self):
+ return "BatchedSerializer<%s>" % str(self.serializer)
-class Batch(object):
+class CartesianDeserializer(FramedSerializer):
"""
- Used to store multiple RDD entries as a single Java object.
+ Deserializes the JavaRDD cartesian() of two PythonRDDs.
+ """
+
+ def __init__(self, key_ser, val_ser):
+ self.key_ser = key_ser
+ self.val_ser = val_ser
+
+ def load_stream(self, stream):
+ key_stream = self.key_ser._load_stream_without_unbatching(stream)
+ val_stream = self.val_ser._load_stream_without_unbatching(stream)
+ key_is_batched = isinstance(self.key_ser, BatchedSerializer)
+ val_is_batched = isinstance(self.val_ser, BatchedSerializer)
+ for (keys, vals) in izip(key_stream, val_stream):
+ keys = keys if key_is_batched else [keys]
+ vals = vals if val_is_batched else [vals]
+ for pair in product(keys, vals):
+ yield pair
+
+ def __eq__(self, other):
+ return isinstance(other, CartesianDeserializer) and \
+ self.key_ser == other.key_ser and self.val_ser == other.val_ser
+
+ def __str__(self):
+ return "CartesianDeserializer<%s, %s>" % \
+ (str(self.key_ser), str(self.val_ser))
- This relieves us from having to explicitly track whether an RDD
- is stored as batches of objects and avoids problems when processing
- the union() of batched and unbatched RDDs (e.g. the union() of textFile()
- with another RDD).
+
+class NoOpSerializer(FramedSerializer):
+
+ def loads(self, obj): return obj
+ def dumps(self, obj): return obj
+
+
+class PickleSerializer(FramedSerializer):
"""
- def __init__(self, items):
- self.items = items
+ Serializes objects using Python's cPickle serializer:
+ http://docs.python.org/2/library/pickle.html
-def batched(iterator, batchSize):
- if batchSize == -1: # unlimited batch size
- yield Batch(list(iterator))
- else:
- items = []
- count = 0
- for item in iterator:
- items.append(item)
- count += 1
- if count == batchSize:
- yield Batch(items)
- items = []
- count = 0
- if items:
- yield Batch(items)
+ This serializer supports nearly any Python object, but may
+ not be as fast as more specialized serializers.
+ """
+ def dumps(self, obj): return cPickle.dumps(obj, 2)
+ loads = cPickle.loads
-def dump_pickle(obj):
- return cPickle.dumps(obj, 2)
+class CloudPickleSerializer(PickleSerializer):
+ def dumps(self, obj): return cloudpickle.dumps(obj, 2)
-load_pickle = cPickle.loads
+
+class MarshalSerializer(FramedSerializer):
+ """
+ Serializes objects using Python's Marshal serializer:
+
+ http://docs.python.org/2/library/marshal.html
+
+ This serializer is faster than PickleSerializer but supports fewer datatypes.
+ """
+
+ dumps = marshal.dumps
+ loads = marshal.loads
+
+
+class MUTF8Deserializer(Serializer):
+ """
+ Deserializes streams written by Java's DataOutputStream.writeUTF().
+ """
+
+ def loads(self, stream):
+ length = struct.unpack('>H', stream.read(2))[0]
+ return stream.read(length).decode('utf8')
+
+ def load_stream(self, stream):
+ while True:
+ try:
+ yield self.loads(stream)
+ except struct.error:
+ return
+ except EOFError:
+ return
def read_long(stream):
@@ -84,25 +308,4 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
- stream.write(obj)
-
-
-def read_with_length(stream):
- length = read_int(stream)
- obj = stream.read(length)
- if obj == "":
- raise EOFError
- return obj
-
-
-def read_from_pickle_file(stream):
- try:
- while True:
- obj = load_pickle(read_with_length(stream))
- if type(obj) == Batch: # We don't care about inheritance
- for item in obj.items:
- yield item
- else:
- yield obj
- except EOFError:
- return
+ stream.write(obj) \ No newline at end of file
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 29d6a128f6..621e1cb58c 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase):
time.sleep(1) # 1 second
self.assertTrue(flatMappedRDD.getCheckpointFile() is not None)
- recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile())
+ recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(),
+ flatMappedRDD._jrdd_deserializer)
self.assertEquals([1, 2, 3, 4], recovered.collect())
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index d63c2aaef7..f2b3f3c142 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -23,23 +23,22 @@ import sys
import time
import socket
import traceback
-from base64 import standard_b64decode
# CloudPickler needs to be imported so that depicklers are registered using the
# copy_reg module.
from pyspark.accumulators import _accumulatorRegistry
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.cloudpickle import CloudPickler
from pyspark.files import SparkFiles
-from pyspark.serializers import write_with_length, read_with_length, write_int, \
- read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file
+from pyspark.serializers import write_with_length, write_int, read_long, \
+ write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer
-def load_obj(infile):
- return load_pickle(standard_b64decode(infile.readline().strip()))
+pickleSer = PickleSerializer()
+mutf8_deserializer = MUTF8Deserializer()
def report_times(outfile, boot, init, finish):
- write_int(-3, outfile)
+ write_int(SpecialLengths.TIMING_DATA, outfile)
write_long(1000 * boot, outfile)
write_long(1000 * init, outfile)
write_long(1000 * finish, outfile)
@@ -52,7 +51,7 @@ def main(infile, outfile):
return
# fetch name of workdir
- spark_files_dir = load_pickle(read_with_length(infile))
+ spark_files_dir = mutf8_deserializer.loads(infile)
SparkFiles._root_directory = spark_files_dir
SparkFiles._is_running_on_worker = True
@@ -60,38 +59,33 @@ def main(infile, outfile):
num_broadcast_variables = read_int(infile)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
- value = read_with_length(infile)
- _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value))
+ value = pickleSer._read_with_length(infile)
+ _broadcastRegistry[bid] = Broadcast(bid, value)
# fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
sys.path.append(spark_files_dir) # *.py files that were added will be copied here
num_python_includes = read_int(infile)
for _ in range(num_python_includes):
- sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile))))
+ filename = mutf8_deserializer.loads(infile)
+ sys.path.append(os.path.join(spark_files_dir, filename))
- # now load function
- func = load_obj(infile)
- bypassSerializer = load_obj(infile)
- if bypassSerializer:
- dumps = lambda x: x
- else:
- dumps = dump_pickle
+ command = pickleSer._read_with_length(infile)
+ (func, deserializer, serializer) = command
init_time = time.time()
- iterator = read_from_pickle_file(infile)
try:
- for obj in func(split_index, iterator):
- write_with_length(dumps(obj), outfile)
+ iterator = deserializer.load_stream(infile)
+ serializer.dump_stream(func(split_index, iterator), outfile)
except Exception as e:
- write_int(-2, outfile)
+ write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
write_with_length(traceback.format_exc(), outfile)
sys.exit(-1)
finish_time = time.time()
report_times(outfile, boot_time, init_time, finish_time)
# Mark the beginning of the accumulators section of the output
- write_int(-1, outfile)
- for aid, accum in _accumulatorRegistry.items():
- write_with_length(dump_pickle((aid, accum._value)), outfile)
- write_int(-1, outfile)
+ write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+ write_int(len(_accumulatorRegistry), outfile)
+ for (aid, accum) in _accumulatorRegistry.items():
+ pickleSer._write_with_length((aid, accum._value), outfile)
if __name__ == '__main__':
diff --git a/python/run-tests b/python/run-tests
index cbc554ea9d..d4dad672d2 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -37,6 +37,7 @@ run_test "pyspark/rdd.py"
run_test "pyspark/context.py"
run_test "-m doctest pyspark/broadcast.py"
run_test "-m doctest pyspark/accumulators.py"
+run_test "-m doctest pyspark/serializers.py"
run_test "pyspark/tests.py"
if [[ $FAILED != 0 ]]; then
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
index 94e353af2e..bb73f6d337 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala
@@ -54,9 +54,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
// staging directory is private! -> rwx--------
val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short)
// app files are world-wide readable and owner writable -> rw-r--r--
- val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short)
+ val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short)
- def run() {
+ // for client user who want to monitor app status by itself.
+ def runApp() = {
validateArgs()
init(yarnConf)
@@ -78,7 +79,11 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName())
submitApp(appContext)
-
+ appId
+ }
+
+ def run() {
+ val appId = runApp()
monitorApplication(appId)
System.exit(0)
}
@@ -372,7 +377,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl
val commands = List[String](javaCommand +
" -server " +
JAVA_OPTS +
- " org.apache.spark.deploy.yarn.ApplicationMaster" +
+ " " + args.amClass +
" --class " + args.userClass +
" --jar " + args.userJar +
userArgsToString(args) +
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
index 852dbd7dab..b9dbc3fb87 100644
--- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala
@@ -35,6 +35,7 @@ class ClientArguments(val args: Array[String]) {
var numWorkers = 2
var amQueue = System.getProperty("QUEUE", "default")
var amMemory: Int = 512
+ var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster"
var appName: String = "Spark"
// TODO
var inputFormatInfo: List[InputFormatInfo] = null
@@ -62,18 +63,22 @@ class ClientArguments(val args: Array[String]) {
userArgsBuffer += value
args = tail
- case ("--master-memory") :: MemoryParam(value) :: tail =>
- amMemory = value
+ case ("--master-class") :: value :: tail =>
+ amClass = value
args = tail
- case ("--num-workers") :: IntParam(value) :: tail =>
- numWorkers = value
+ case ("--master-memory") :: MemoryParam(value) :: tail =>
+ amMemory = value
args = tail
case ("--worker-memory") :: MemoryParam(value) :: tail =>
workerMemory = value
args = tail
+ case ("--num-workers") :: IntParam(value) :: tail =>
+ numWorkers = value
+ args = tail
+
case ("--worker-cores") :: IntParam(value) :: tail =>
workerCores = value
args = tail
@@ -119,19 +124,20 @@ class ClientArguments(val args: Array[String]) {
System.err.println(
"Usage: org.apache.spark.deploy.yarn.Client [options] \n" +
"Options:\n" +
- " --jar JAR_PATH Path to your application's JAR file (required)\n" +
- " --class CLASS_NAME Name of your application's main class (required)\n" +
- " --args ARGS Arguments to be passed to your application's main class.\n" +
- " Mutliple invocations are possible, each will be passed in order.\n" +
- " --num-workers NUM Number of workers to start (Default: 2)\n" +
- " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
- " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
- " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
- " --name NAME The name of your application (Default: Spark)\n" +
- " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
- " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
- " --files files Comma separated list of files to be distributed with the job.\n" +
- " --archives archives Comma separated list of archives to be distributed with the job."
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" +
+ " --class CLASS_NAME Name of your application's main class (required)\n" +
+ " --args ARGS Arguments to be passed to your application's main class.\n" +
+ " Mutliple invocations are possible, each will be passed in order.\n" +
+ " --num-workers NUM Number of workers to start (Default: 2)\n" +
+ " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" +
+ " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" +
+ " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" +
+ " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" +
+ " --name NAME The name of your application (Default: Spark)\n" +
+ " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" +
+ " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" +
+ " --files files Comma separated list of files to be distributed with the job.\n" +
+ " --archives archives Comma separated list of archives to be distributed with the job."
)
System.exit(exitCode)
}
diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
new file mode 100644
index 0000000000..421a83c87a
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala
@@ -0,0 +1,246 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.yarn
+
+import java.net.Socket
+import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.net.NetUtils
+import org.apache.hadoop.yarn.api._
+import org.apache.hadoop.yarn.api.records._
+import org.apache.hadoop.yarn.api.protocolrecords._
+import org.apache.hadoop.yarn.conf.YarnConfiguration
+import org.apache.hadoop.yarn.ipc.YarnRPC
+import org.apache.hadoop.yarn.util.{ConverterUtils, Records}
+import akka.actor._
+import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent}
+import akka.remote.RemoteClientShutdown
+import akka.actor.Terminated
+import akka.remote.RemoteClientDisconnected
+import org.apache.spark.{SparkContext, Logging}
+import org.apache.spark.util.{Utils, AkkaUtils}
+import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend
+import org.apache.spark.scheduler.SplitInfo
+
+class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration) extends Logging {
+
+ def this(args: ApplicationMasterArguments) = this(args, new Configuration())
+
+ private val rpc: YarnRPC = YarnRPC.create(conf)
+ private var resourceManager: AMRMProtocol = null
+ private var appAttemptId: ApplicationAttemptId = null
+ private var reporterThread: Thread = null
+ private val yarnConf: YarnConfiguration = new YarnConfiguration(conf)
+
+ private var yarnAllocator: YarnAllocationHandler = null
+ private var driverClosed:Boolean = false
+
+ val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0)._1
+ var actor: ActorRef = null
+
+ // This actor just working as a monitor to watch on Driver Actor.
+ class MonitorActor(driverUrl: String) extends Actor {
+
+ var driver: ActorRef = null
+
+ override def preStart() {
+ logInfo("Listen to driver: " + driverUrl)
+ driver = context.actorFor(driverUrl)
+ driver ! "hello"
+ context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent])
+ context.watch(driver) // Doesn't work with remote actors, but useful for testing
+ }
+
+ override def receive = {
+ case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) =>
+ logInfo("Driver terminated or disconnected! Shutting down.")
+ driverClosed = true
+ }
+ }
+
+ def run() {
+
+ appAttemptId = getApplicationAttemptId()
+ resourceManager = registerWithResourceManager()
+ val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster()
+
+ // Compute number of threads for akka
+ val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory()
+
+ if (minimumMemory > 0) {
+ val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD
+ val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0)
+
+ if (numCore > 0) {
+ // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406
+ // TODO: Uncomment when hadoop is on a version which has this fixed.
+ // args.workerCores = numCore
+ }
+ }
+
+ waitForSparkMaster()
+
+ // Allocate all containers
+ allocateWorkers()
+
+ // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout
+ // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse.
+
+ val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000)
+ // must be <= timeoutInterval/ 2.
+ // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM.
+ // so atleast 1 minute or timeoutInterval / 10 - whichever is higher.
+ val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L))
+ reporterThread = launchReporterThread(interval)
+
+ // Wait for the reporter thread to Finish.
+ reporterThread.join()
+
+ finishApplicationMaster(FinalApplicationStatus.SUCCEEDED)
+ actorSystem.shutdown()
+
+ logInfo("Exited")
+ System.exit(0)
+ }
+
+ private def getApplicationAttemptId(): ApplicationAttemptId = {
+ val envs = System.getenv()
+ val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV)
+ val containerId = ConverterUtils.toContainerId(containerIdString)
+ val appAttemptId = containerId.getApplicationAttemptId()
+ logInfo("ApplicationAttemptId: " + appAttemptId)
+ return appAttemptId
+ }
+
+ private def registerWithResourceManager(): AMRMProtocol = {
+ val rmAddress = NetUtils.createSocketAddr(yarnConf.get(
+ YarnConfiguration.RM_SCHEDULER_ADDRESS,
+ YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS))
+ logInfo("Connecting to ResourceManager at " + rmAddress)
+ return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol]
+ }
+
+ private def registerApplicationMaster(): RegisterApplicationMasterResponse = {
+ logInfo("Registering the ApplicationMaster")
+ val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest])
+ .asInstanceOf[RegisterApplicationMasterRequest]
+ appMasterRequest.setApplicationAttemptId(appAttemptId)
+ // Setting this to master host,port - so that the ApplicationReport at client has some sensible info.
+ // Users can then monitor stderr/stdout on that node if required.
+ appMasterRequest.setHost(Utils.localHostName())
+ appMasterRequest.setRpcPort(0)
+ // What do we provide here ? Might make sense to expose something sensible later ?
+ appMasterRequest.setTrackingUrl("")
+ return resourceManager.registerApplicationMaster(appMasterRequest)
+ }
+
+ private def waitForSparkMaster() {
+ logInfo("Waiting for spark driver to be reachable.")
+ var driverUp = false
+ val hostport = args.userArgs(0)
+ val (driverHost, driverPort) = Utils.parseHostPort(hostport)
+ while(!driverUp) {
+ try {
+ val socket = new Socket(driverHost, driverPort)
+ socket.close()
+ logInfo("Master now available: " + driverHost + ":" + driverPort)
+ driverUp = true
+ } catch {
+ case e: Exception =>
+ logError("Failed to connect to driver at " + driverHost + ":" + driverPort)
+ Thread.sleep(100)
+ }
+ }
+ System.setProperty("spark.driver.host", driverHost)
+ System.setProperty("spark.driver.port", driverPort.toString)
+
+ val driverUrl = "akka://spark@%s:%s/user/%s".format(
+ driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME)
+
+ actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM")
+ }
+
+
+ private def allocateWorkers() {
+
+ // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now.
+ val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map()
+
+ yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, preferredNodeLocationData)
+
+ logInfo("Allocating " + args.numWorkers + " workers.")
+ // Wait until all containers have finished
+ // TODO: This is a bit ugly. Can we make it nicer?
+ // TODO: Handle container failure
+ while(yarnAllocator.getNumWorkersRunning < args.numWorkers) {
+ yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0))
+ Thread.sleep(100)
+ }
+
+ logInfo("All workers have launched.")
+
+ }
+
+ // TODO: We might want to extend this to allocate more containers in case they die !
+ private def launchReporterThread(_sleepTime: Long): Thread = {
+ val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime
+
+ val t = new Thread {
+ override def run() {
+ while (!driverClosed) {
+ val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning
+ if (missingWorkerCount > 0) {
+ logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers")
+ yarnAllocator.allocateContainers(missingWorkerCount)
+ }
+ else sendProgress()
+ Thread.sleep(sleepTime)
+ }
+ }
+ }
+ // setting to daemon status, though this is usually not a good idea.
+ t.setDaemon(true)
+ t.start()
+ logInfo("Started progress reporter thread - sleep time : " + sleepTime)
+ return t
+ }
+
+ private def sendProgress() {
+ logDebug("Sending progress")
+ // simulated with an allocate request with no nodes requested ...
+ yarnAllocator.allocateContainers(0)
+ }
+
+ def finishApplicationMaster(status: FinalApplicationStatus) {
+
+ logInfo("finish ApplicationMaster with " + status)
+ val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest])
+ .asInstanceOf[FinishApplicationMasterRequest]
+ finishReq.setAppAttemptId(appAttemptId)
+ finishReq.setFinishApplicationStatus(status)
+ resourceManager.finishApplicationMaster(finishReq)
+ }
+
+}
+
+
+object WorkerLauncher {
+ def main(argStrings: Array[String]) {
+ val args = new ApplicationMasterArguments(argStrings)
+ new WorkerLauncher(args).run()
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
new file mode 100644
index 0000000000..63a0449e5a
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala
@@ -0,0 +1,47 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.spark._
+import org.apache.hadoop.conf.Configuration
+import org.apache.spark.deploy.yarn.YarnAllocationHandler
+import org.apache.spark.util.Utils
+
+/**
+ *
+ * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM.
+ */
+private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) {
+
+ def this(sc: SparkContext) = this(sc, new Configuration())
+
+ // By default, rack is unknown
+ override def getRackForHost(hostPort: String): Option[String] = {
+ val host = Utils.parseHostPort(hostPort)._1
+ val retval = YarnAllocationHandler.lookupRack(conf, host)
+ if (retval != null) Some(retval) else None
+ }
+
+ override def postStartHook() {
+
+ // The yarn application is running, but the worker might not yet ready
+ // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt
+ Thread.sleep(2000L)
+ logInfo("YarnClientClusterScheduler.postStartHook done")
+ }
+}
diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
new file mode 100644
index 0000000000..b206780c78
--- /dev/null
+++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala
@@ -0,0 +1,109 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.scheduler.cluster
+
+import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState}
+import org.apache.spark.{SparkException, Logging, SparkContext}
+import org.apache.spark.deploy.yarn.{Client, ClientArguments}
+
+private[spark] class YarnClientSchedulerBackend(
+ scheduler: ClusterScheduler,
+ sc: SparkContext)
+ extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem)
+ with Logging {
+
+ var client: Client = null
+ var appId: ApplicationId = null
+
+ override def start() {
+ super.start()
+
+ val defalutWorkerCores = "2"
+ val defalutWorkerMemory = "512m"
+ val defaultWorkerNumber = "1"
+
+ val userJar = System.getenv("SPARK_YARN_APP_JAR")
+ var workerCores = System.getenv("SPARK_WORKER_CORES")
+ var workerMemory = System.getenv("SPARK_WORKER_MEMORY")
+ var workerNumber = System.getenv("SPARK_WORKER_INSTANCES")
+
+ if (userJar == null)
+ throw new SparkException("env SPARK_YARN_APP_JAR is not set")
+
+ if (workerCores == null)
+ workerCores = defalutWorkerCores
+ if (workerMemory == null)
+ workerMemory = defalutWorkerMemory
+ if (workerNumber == null)
+ workerNumber = defaultWorkerNumber
+
+ val driverHost = System.getProperty("spark.driver.host")
+ val driverPort = System.getProperty("spark.driver.port")
+ val hostport = driverHost + ":" + driverPort
+
+ val argsArray = Array[String](
+ "--class", "notused",
+ "--jar", userJar,
+ "--args", hostport,
+ "--worker-memory", workerMemory,
+ "--worker-cores", workerCores,
+ "--num-workers", workerNumber,
+ "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher"
+ )
+
+ val args = new ClientArguments(argsArray)
+ client = new Client(args)
+ appId = client.runApp()
+ waitForApp()
+ }
+
+ def waitForApp() {
+
+ // TODO : need a better way to find out whether the workers are ready or not
+ // maybe by resource usage report?
+ while(true) {
+ val report = client.getApplicationReport(appId)
+
+ logInfo("Application report from ASM: \n" +
+ "\t appMasterRpcPort: " + report.getRpcPort() + "\n" +
+ "\t appStartTime: " + report.getStartTime() + "\n" +
+ "\t yarnAppState: " + report.getYarnApplicationState() + "\n"
+ )
+
+ // Ready to go, or already gone.
+ val state = report.getYarnApplicationState()
+ if (state == YarnApplicationState.RUNNING) {
+ return
+ } else if (state == YarnApplicationState.FINISHED ||
+ state == YarnApplicationState.FAILED ||
+ state == YarnApplicationState.KILLED) {
+ throw new SparkException("Yarn application already ended," +
+ "might be killed or not able to launch application master.")
+ }
+
+ Thread.sleep(1000)
+ }
+ }
+
+ override def stop() {
+ super.stop()
+ client.stop()
+ logInfo("Stoped")
+ }
+
+}