From 2a37235825cecd3f75286d11456c6e3cb13d4327 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 18 Oct 2013 00:07:49 -0700 Subject: Initial commit of adding histogram functionality to the DoubleRDDFunctions. --- .../org/apache/spark/api/java/JavaDoubleRDD.scala | 32 +++++ .../org/apache/spark/rdd/DoubleRDDFunctions.scala | 134 +++++++++++++++++++++ 2 files changed, 166 insertions(+) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 5fd1fab580..d2a2818e59 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -26,6 +26,8 @@ import org.apache.spark.storage.StorageLevel import java.lang.Double import org.apache.spark.Partitioner +import scala.collection.JavaConverters._ + class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] { override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]] @@ -158,6 +160,36 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav /** (Experimental) Approximate operation to return the sum within a timeout. */ def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) + + /** + * Compute a histogram of the data using bucketCount number of buckets evenly + * spaced between the minimum and maximum of the RDD. For example if the min + * value is 0 and the max is 100 and there are two buckets the resulting + * buckets will be [0,50) [50,100]. bucketCount must be at least 1 + * If the RDD contains infinity, NaN throws an exception + * If the elements in RDD do not vary (max == min) throws an exception + */ + def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + val result = srdd.histogram(bucketCount) + (result._1.map(scala.Double.box(_)), result._2) + } + /** + * Compute a histogram using the provided buckets. The buckets are all open + * to the left except for the last which is closed + * e.g. for the array + * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * + * Note: if your histogram is evenly spaced (e.g. [0,10,20,30]) this switches + * from an O(log n) inseration to O(1) per element. (where n = # buckets) + * buckets must be sorted and not contain any duplicates. + * buckets array must be at least two elements + * All NaN entries are treated the same. + */ + def histogram(buckets: Array[Double]): Array[Long] = { + srdd.histogram(buckets.map(_.toDouble)) + } } object JavaDoubleRDD { diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a4bec41752..776a83cefe 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator import org.apache.spark.util.StatCounter import org.apache.spark.{TaskContext, Logging} +import scala.collection.immutable.NumericRange + /** * Extra functions available on RDDs of Doubles through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. @@ -76,4 +78,136 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { val evaluator = new SumEvaluator(self.partitions.size, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } + + /** + * Compute a histogram of the data using bucketCount number of buckets evenly + * spaced between the minimum and maximum of the RDD. For example if the min + * value is 0 and the max is 100 and there are two buckets the resulting + * buckets will be [0,50) [50,100]. bucketCount must be at least 1 + * If the RDD contains infinity, NaN throws an exception + * If the elements in RDD do not vary (max == min) throws an exception + */ + def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + // Compute the minimum and the maxium + val (max: Double, min: Double) = self.mapPartitions { items => + Iterator(items.foldRight(-1/0.0, Double.NaN)((e: Double, x: Pair[Double, Double]) => + (x._1.max(e),x._2.min(e)))) + }.reduce { (maxmin1, maxmin2) => + (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) + } + if (max.isNaN() || max.isInfinity || min.isInfinity ) { + throw new UnsupportedOperationException("Histogram on either an empty RDD or RDD containing +-infinity or NaN") + } + if (max == min) { + throw new UnsupportedOperationException("Histogram with no range in elements") + } + val increment: Double = (max-min)/bucketCount.toDouble + val range = Range.Double.inclusive(min, max, increment) + val buckets: Array[Double] = range.toArray + (buckets,histogram(buckets)) + } + /** + * Compute a histogram using the provided buckets. The buckets are all open + * to the left except for the last which is closed + * e.g. for the array + * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * + * Note: if your histogram is evenly spaced (e.g. [0,10,20,30]) this switches + * from an O(log n) inseration to O(1) per element. (where n = # buckets) + * buckets must be sorted and not contain any duplicates. + * buckets array must be at least two elements + * All NaN entries are treated the same. + */ + def histogram(buckets: Array[Double]): Array[Long] = { + if (buckets.length < 2) { + throw new IllegalArgumentException("buckets array must have at least two elements") + } + // The histogramPartition function computes the partail histogram for a given + // partition. The provided bucketFunction determines which bucket in the array + // to increment or returns None if there is no bucket. This is done so we can + // specialize for uniformly distributed buckets and save the O(log n) binary + // search cost. + def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]): Iterator[Array[Long]] = { + val counters = new Array[Long](buckets.length-1) + while (iter.hasNext) { + bucketFunction(iter.next()) match { + case Some(x: Int) => {counters(x)+=1} + case _ => {} + } + } + Iterator(counters) + } + // Merge the counters. + def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = { + a1.indices.foreach(i => a1(i) += a2(i)) + a1 + } + // Basic bucket function. This works using Java's built in Array + // binary search. Takes log(size(buckets)) + def basicBucketFunction(e: Double): Option[Int] = { + val location = java.util.Arrays.binarySearch(buckets, e) + if (location < 0) { + // If the location is less than 0 then the insertion point in the array + // to keep it sorted is -location-1 + val insertionPoint = -location-1 + // If we have to insert before the first element or after the last one + // its out of bounds. + // We do this rather than buckets.lengthCompare(insertionPoint) + // because Array[Double] fails to override it (for now). + if (insertionPoint > 0 && insertionPoint < buckets.length) { + Some(insertionPoint-1) + } else { + None + } + } else if (location < buckets.length-1) { + // Exact match, just insert here + Some(location) + } else { + // Exact match to the last element + Some(location-1) + } + } + // Determine the bucket function in constant time. Requires that buckets are evenly spaced + def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + // If our input is not a number unless the increment is also NaN then we fail fast + if (e.isNaN()) { + return None + } + val bucketNumber = (e-min)/(increment) + // We do this rather than buckets.lengthCompare(bucketNumber) + // because Array[Double] fails to override it (for now). + if (bucketNumber > count || bucketNumber < 0) { + None + } else { + Some(bucketNumber.toInt.min(count-1)) + } + } + def evenlySpaced(buckets: Array[Double]): Boolean = { + val delta = buckets(1)-buckets(0) + // Technically you could have an evenly spaced bucket with NaN + // increments but then its a single bucket and this makes the + // fastBucketFunction simpler. + if (delta.isNaN() || delta.isInfinite()) { + return false + } + for (i <- 1 to buckets.length-1) { + if (buckets(i)-buckets(i-1) != delta) { + return false + } + } + true + } + // Decide which bucket function to pass to histogramPartition. We decide here + // rather than having a general function so that the decission need only be made + // once rather than once per shard + val bucketFunction = if (evenlySpaced(buckets)) { + fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + } else { + basicBucketFunction _ + } + self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + } + } -- cgit v1.2.3 From 699f7d28c0347cb516fa17f94b53d7bc50f18346 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Oct 2013 00:10:03 -0700 Subject: CR feedback --- .../org/apache/spark/api/java/JavaDoubleRDD.scala | 18 ++- .../org/apache/spark/rdd/DoubleRDDFunctions.scala | 68 +++++----- .../org/apache/spark/rdd/DoubleRDDSuite.scala | 140 ++++++++++++--------- 3 files changed, 125 insertions(+), 101 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index d2a2818e59..b002468442 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -167,12 +167,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav * value is 0 and the max is 100 and there are two buckets the resulting * buckets will be [0,50) [50,100]. bucketCount must be at least 1 * If the RDD contains infinity, NaN throws an exception - * If the elements in RDD do not vary (max == min) throws an exception + * If the elements in RDD do not vary (max == min) always returns a single bucket. */ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { val result = srdd.histogram(bucketCount) (result._1.map(scala.Double.box(_)), result._2) } + /** * Compute a histogram using the provided buckets. The buckets are all open * to the left except for the last which is closed @@ -181,14 +182,21 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * - * Note: if your histogram is evenly spaced (e.g. [0,10,20,30]) this switches - * from an O(log n) inseration to O(1) per element. (where n = # buckets) + * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * to true. * buckets must be sorted and not contain any duplicates. * buckets array must be at least two elements - * All NaN entries are treated the same. + * All NaN entries are treated the same. If you have a NaN bucket it must be + * the maximum value of the last position and all NaN entries will be counted + * in that bucket. */ def histogram(buckets: Array[Double]): Array[Long] = { - srdd.histogram(buckets.map(_.toDouble)) + srdd.histogram(buckets.map(_.toDouble), false) + } + + def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = { + srdd.histogram(buckets.map(_.toDouble), evenBuckets) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 776a83cefe..33738ee094 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -83,44 +83,50 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * Compute a histogram of the data using bucketCount number of buckets evenly * spaced between the minimum and maximum of the RDD. For example if the min * value is 0 and the max is 100 and there are two buckets the resulting - * buckets will be [0,50) [50,100]. bucketCount must be at least 1 + * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1 * If the RDD contains infinity, NaN throws an exception - * If the elements in RDD do not vary (max == min) throws an exception + * If the elements in RDD do not vary (max == min) always returns a single bucket. */ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { // Compute the minimum and the maxium val (max: Double, min: Double) = self.mapPartitions { items => Iterator(items.foldRight(-1/0.0, Double.NaN)((e: Double, x: Pair[Double, Double]) => - (x._1.max(e),x._2.min(e)))) + (x._1.max(e), x._2.min(e)))) }.reduce { (maxmin1, maxmin2) => (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) } if (max.isNaN() || max.isInfinity || min.isInfinity ) { - throw new UnsupportedOperationException("Histogram on either an empty RDD or RDD containing +-infinity or NaN") - } - if (max == min) { - throw new UnsupportedOperationException("Histogram with no range in elements") + throw new UnsupportedOperationException( + "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") } val increment: Double = (max-min)/bucketCount.toDouble - val range = Range.Double.inclusive(min, max, increment) + val range = if (increment != 0) { + Range.Double.inclusive(min, max, increment) + } else { + List(min, min) + } val buckets: Array[Double] = range.toArray - (buckets,histogram(buckets)) + (buckets, histogram(buckets, true)) } + /** * Compute a histogram using the provided buckets. The buckets are all open * to the left except for the last which is closed * e.g. for the array - * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] + * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] * e.g 1<=x<10 , 10<=x<20, 20<=x<50 - * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 * - * Note: if your histogram is evenly spaced (e.g. [0,10,20,30]) this switches - * from an O(log n) inseration to O(1) per element. (where n = # buckets) + * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * to true. * buckets must be sorted and not contain any duplicates. * buckets array must be at least two elements - * All NaN entries are treated the same. + * All NaN entries are treated the same. If you have a NaN bucket it must be + * the maximum value of the last position and all NaN entries will be counted + * in that bucket. */ - def histogram(buckets: Array[Double]): Array[Long] = { + def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = { if (buckets.length < 2) { throw new IllegalArgumentException("buckets array must have at least two elements") } @@ -129,11 +135,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { // to increment or returns None if there is no bucket. This is done so we can // specialize for uniformly distributed buckets and save the O(log n) binary // search cost. - def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]): Iterator[Array[Long]] = { - val counters = new Array[Long](buckets.length-1) + def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]): + Iterator[Array[Long]] = { + val counters = new Array[Long](buckets.length - 1) while (iter.hasNext) { bucketFunction(iter.next()) match { - case Some(x: Int) => {counters(x)+=1} + case Some(x: Int) => {counters(x) += 1} case _ => {} } } @@ -161,12 +168,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } else { None } - } else if (location < buckets.length-1) { + } else if (location < buckets.length - 1) { // Exact match, just insert here Some(location) } else { // Exact match to the last element - Some(location-1) + Some(location - 1) } } // Determine the bucket function in constant time. Requires that buckets are evenly spaced @@ -175,34 +182,19 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { if (e.isNaN()) { return None } - val bucketNumber = (e-min)/(increment) + val bucketNumber = (e - min)/(increment) // We do this rather than buckets.lengthCompare(bucketNumber) // because Array[Double] fails to override it (for now). if (bucketNumber > count || bucketNumber < 0) { None } else { - Some(bucketNumber.toInt.min(count-1)) - } - } - def evenlySpaced(buckets: Array[Double]): Boolean = { - val delta = buckets(1)-buckets(0) - // Technically you could have an evenly spaced bucket with NaN - // increments but then its a single bucket and this makes the - // fastBucketFunction simpler. - if (delta.isNaN() || delta.isInfinite()) { - return false - } - for (i <- 1 to buckets.length-1) { - if (buckets(i)-buckets(i-1) != delta) { - return false - } + Some(bucketNumber.toInt.min(count - 1)) } - true } // Decide which bucket function to pass to histogramPartition. We decide here // rather than having a general function so that the decission need only be made // once rather than once per shard - val bucketFunction = if (evenlySpaced(buckets)) { + val bucketFunction = if (evenBuckets) { fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ } else { basicBucketFunction _ diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 2ec7173511..071084485a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -34,134 +34,151 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { val rdd: RDD[Double] = sc.parallelize(Seq()) val buckets: Array[Double] = Array(0.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) val expectedHistogramResults: Array[Long] = Array(0) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksWithOutOfRangeWithOneBucket") { // Verify that if all of the elements are out of range the counts are zero - val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01)) + val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01)) val buckets: Array[Double] = Array(0.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) val expectedHistogramResults: Array[Long] = Array(0) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithOneBucket") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) val buckets: Array[Double] = Array(0.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) val expectedHistogramResults: Array[Long] = Array(4) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithOneBucketExactMatch") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) val buckets: Array[Double] = Array(1.0, 4.0) val histogramResults: Array[Long] = rdd.histogram(buckets) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) val expectedHistogramResults: Array[Long] = Array(4) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksWithOutOfRangeWithTwoBuckets") { // Verify that out of range works with two buckets - val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01)) + val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01)) val buckets: Array[Double] = Array(0.0, 5.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(0,0) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) + val expectedHistogramResults: Array[Long] = Array(0, 0) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksWithOutOfRangeWithTwoUnEvenBuckets") { // Verify that out of range works with two un even buckets - val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01)) + val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01)) val buckets: Array[Double] = Array(0.0, 4.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(0,0) + val expectedHistogramResults: Array[Long] = Array(0, 0) assert(histogramResults === expectedHistogramResults) } test("WorksInRangeWithTwoBuckets") { // Make sure that it works with two equally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,5,6)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 5, 6)) val buckets: Array[Double] = Array(0.0, 5.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(3,2) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) + val expectedHistogramResults: Array[Long] = Array(3, 2) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithTwoBucketsAndNaN") { // Make sure that it works with two equally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,5,6,Double.NaN)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN)) val buckets: Array[Double] = Array(0.0, 5.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(3,2) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) + val expectedHistogramResults: Array[Long] = Array(3, 2) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithTwoUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,5,6)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 5, 6)) val buckets: Array[Double] = Array(0.0, 5.0, 11.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(3,2) + val expectedHistogramResults: Array[Long] = Array(3, 2) assert(histogramResults === expectedHistogramResults) } test("WorksMixedRangeWithTwoUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.0,11.01)) + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01)) val buckets: Array[Double] = Array(0.0, 5.0, 11.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4,3) + val expectedHistogramResults: Array[Long] = Array(4, 3) assert(histogramResults === expectedHistogramResults) } test("WorksMixedRangeWithFourUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0) + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4,2,1,3) + val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 3) assert(histogramResults === expectedHistogramResults) } test("WorksMixedRangeWithUnevenBucketsAndNaN") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1,Double.NaN)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0) + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, Double.NaN)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4,2,1,3) + val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 3) assert(histogramResults === expectedHistogramResults) } // Make sure this works with a NaN end bucket test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1,Double.NaN)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0,Double.NaN) + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, Double.NaN)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4,2,1,2,3) + val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 2, 3) assert(histogramResults === expectedHistogramResults) } // Make sure this works with a NaN end bucket and an inifity test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1,1.0/0.0,-1.0/0.0,Double.NaN)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0,Double.NaN) + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4,2,1,2,4) + val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 2, 4) assert(histogramResults === expectedHistogramResults) } test("WorksWithOutOfRangeWithInfiniteBuckets") { // Verify that out of range works with two buckets - val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01,Double.NaN)) - val buckets: Array[Double] = Array(-1.0/0.0 ,0.0, 1.0/0.0) + val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01, Double.NaN)) + val buckets: Array[Double] = Array(-1.0/0.0 , 0.0, 1.0/0.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(1,1) + val expectedHistogramResults: Array[Long] = Array(1, 1) assert(histogramResults === expectedHistogramResults) } // Test the failure mode with an invalid bucket array test("ThrowsExceptionOnInvalidBucketArray") { val rdd: RDD[Double] = sc.parallelize(Seq(1.0)) // Empty array - intercept[IllegalArgumentException]{ + intercept[IllegalArgumentException] { val buckets: Array[Double] = Array.empty[Double] val result = rdd.histogram(buckets) } // Single element array - intercept[IllegalArgumentException] - { + intercept[IllegalArgumentException] { val buckets: Array[Double] = Array(1.0) val result = rdd.histogram(buckets) } @@ -170,25 +187,45 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { // Test automatic histogram function test("WorksWithoutBucketsBasic") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) val (histogramBuckets, histogramResults) = rdd.histogram(1) val expectedHistogramResults: Array[Long] = Array(4) - val expectedHistogramBuckets: Array[Double] = Array(1.0,4.0) + val expectedHistogramBuckets: Array[Double] = Array(1.0, 4.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test automatic histogram function with a single element + test("WorksWithoutBucketsBasicSingleElement") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd: RDD[Double] = sc.parallelize(Seq(1)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults: Array[Long] = Array(1) + val expectedHistogramBuckets: Array[Double] = Array(1.0, 1.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test automatic histogram function with a single element + test("WorksWithoutBucketsBasicNoRange") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd: RDD[Double] = sc.parallelize(Seq(1, 1, 1, 1)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults: Array[Long] = Array(4) + val expectedHistogramBuckets: Array[Double] = Array(1.0, 1.0) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } test("WorksWithoutBucketsBasicTwo") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) val (histogramBuckets, histogramResults) = rdd.histogram(2) - val expectedHistogramResults: Array[Long] = Array(2,2) - val expectedHistogramBuckets: Array[Double] = Array(1.0,2.5,4.0) + val expectedHistogramResults: Array[Long] = Array(2, 2) + val expectedHistogramBuckets: Array[Double] = Array(1.0, 2.5, 4.0) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } test("WorksWithoutBucketsWithMoreRequestedThanElements") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1,2)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2)) val (histogramBuckets, histogramResults) = rdd.histogram(10) val expectedHistogramResults: Array[Long] = Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1) @@ -197,37 +234,24 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity - intercept[UnsupportedOperationException]{ - val rdd: RDD[Double] = sc.parallelize(Seq(1,1.0/0.0)) + intercept[UnsupportedOperationException] { + val rdd: RDD[Double] = sc.parallelize(Seq(1, 1.0/0.0)) val result = rdd.histogram(1) } // NaN - intercept[UnsupportedOperationException] - { - val rdd: RDD[Double] = sc.parallelize(Seq(1,Double.NaN)) + intercept[UnsupportedOperationException] { + val rdd: RDD[Double] = sc.parallelize(Seq(1, Double.NaN)) val result = rdd.histogram(1) } // Empty - intercept[UnsupportedOperationException] - { + intercept[UnsupportedOperationException] { val rdd: RDD[Double] = sc.parallelize(Seq()) val result = rdd.histogram(1) } - // Single element - intercept[UnsupportedOperationException] - { - val rdd: RDD[Double] = sc.parallelize(Seq(1)) - val result = rdd.histogram(1) - } - // No Range - intercept[UnsupportedOperationException] - { - val rdd: RDD[Double] = sc.parallelize(Seq(1,1,1)) - val result = rdd.histogram(1) - } } } -- cgit v1.2.3 From 20b33bc4b5de1addd943c7a1e6d5d2366d9cd445 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Oct 2013 00:21:37 -0700 Subject: Remove extranious type declerations --- core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 33738ee094..02d75eccc5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -99,13 +99,13 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { throw new UnsupportedOperationException( "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") } - val increment: Double = (max-min)/bucketCount.toDouble + val increment = (max-min)/bucketCount.toDouble val range = if (increment != 0) { Range.Double.inclusive(min, max, increment) } else { List(min, min) } - val buckets: Array[Double] = range.toArray + val buckets = range.toArray (buckets, histogram(buckets, true)) } -- cgit v1.2.3 From a48d88d206fae348720ab077a624b3c57293374f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 2 Nov 2013 21:13:18 -0700 Subject: Replace magic lengths with constants in PySpark. Write the length of the accumulators section up-front rather than terminating it with a negative length. I find this easier to read. --- .../org/apache/spark/api/python/PythonRDD.scala | 26 +++++++++++++--------- python/pyspark/serializers.py | 6 +++++ python/pyspark/worker.py | 13 ++++++----- 3 files changed, 29 insertions(+), 16 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 12b4d94a56..0d5913ec60 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -132,7 +132,7 @@ private[spark] class PythonRDD[T: ClassManifest]( val obj = new Array[Byte](length) stream.readFully(obj) obj - case -3 => + case SpecialLengths.TIMING_DATA => // Timing data from worker val bootTime = stream.readLong() val initTime = stream.readLong() @@ -143,24 +143,24 @@ private[spark] class PythonRDD[T: ClassManifest]( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) read - case -2 => + case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) throw new PythonException(new String(obj)) - case -1 => + case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still - // read some accumulator updates; let's do that, breaking when we - // get a negative length record. - var len2 = stream.readInt() - while (len2 >= 0) { - val update = new Array[Byte](len2) + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) stream.readFully(update) accumulator += Collections.singletonList(update) - len2 = stream.readInt() + } - new Array[Byte](0) + Array.empty[Byte] } } catch { case eof: EOFException => { @@ -197,6 +197,12 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } +private object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 +} + private[spark] object PythonRDD { /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 54fed1c9c7..fbc280fd37 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -19,6 +19,12 @@ import struct import cPickle +class SpecialLengths(object): + END_OF_DATA_SECTION = -1 + PYTHON_EXCEPTION_THROWN = -2 + TIMING_DATA = -3 + + class Batch(object): """ Used to store multiple RDD entries as a single Java object. diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d63c2aaef7..7696df9d1c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,7 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file + read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \ + SpecialLengths def load_obj(infile): @@ -39,7 +40,7 @@ def load_obj(infile): def report_times(outfile, boot, init, finish): - write_int(-3, outfile) + write_int(SpecialLengths.TIMING_DATA, outfile) write_long(1000 * boot, outfile) write_long(1000 * init, outfile) write_long(1000 * finish, outfile) @@ -82,16 +83,16 @@ def main(infile, outfile): for obj in func(split_index, iterator): write_with_length(dumps(obj), outfile) except Exception as e: - write_int(-2, outfile) + write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) write_with_length(traceback.format_exc(), outfile) sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output - write_int(-1, outfile) - for aid, accum in _accumulatorRegistry.items(): + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + write_int(len(_accumulatorRegistry), outfile) + for (aid, accum) in _accumulatorRegistry.items(): write_with_length(dump_pickle((aid, accum._value)), outfile) - write_int(-1, outfile) if __name__ == '__main__': -- cgit v1.2.3 From 7d68a81a8ed5f49fefb3bd0fa0b9d3835cc7d86e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 Nov 2013 11:03:02 -0800 Subject: Remove Pickle-wrapping of Java objects in PySpark. If we support custom serializers, the Python worker will know what type of input to expect, so we won't need to wrap Tuple2 and Strings into pickled tuples and strings. --- .../org/apache/spark/api/python/PythonRDD.scala | 106 ++++++++------------- python/pyspark/context.py | 10 +- python/pyspark/rdd.py | 11 ++- python/pyspark/serializers.py | 18 ++++ python/pyspark/worker.py | 14 ++- 5 files changed, 78 insertions(+), 81 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0d5913ec60..eb0b0db0cc 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -75,7 +75,7 @@ private[spark] class PythonRDD[T: ClassManifest]( // Partition index dataOut.writeInt(split.index) // sparkFilesDir - PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) + dataOut.writeUTF(SparkFiles.getRootDirectory) // Broadcast variables dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { @@ -85,9 +85,7 @@ private[spark] class PythonRDD[T: ClassManifest]( } // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.length) - for (f <- pythonIncludes) { - PythonRDD.writeAsPickle(f, dataOut) - } + pythonIncludes.foreach(dataOut.writeUTF) dataOut.flush() // Serialized user code for (elem <- command) { @@ -96,7 +94,7 @@ private[spark] class PythonRDD[T: ClassManifest]( printOut.flush() // Data values for (elem <- parent.iterator(split, context)) { - PythonRDD.writeAsPickle(elem, dataOut) + PythonRDD.writeToStream(elem, dataOut) } dataOut.flush() printOut.flush() @@ -205,60 +203,7 @@ private object SpecialLengths { private[spark] object PythonRDD { - /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ - def stripPickle(arr: Array[Byte]) : Array[Byte] = { - arr.slice(2, arr.length - 1) - } - - /** - * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. - * The data format is a 32-bit integer representing the pickled object's length (in bytes), - * followed by the pickled data. - * - * Pickle module: - * - * http://docs.python.org/2/library/pickle.html - * - * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: - * - * http://hg.python.org/cpython/file/2.6/Lib/pickle.py - * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py - * - * @param elem the object to write - * @param dOut a data output stream - */ - def writeAsPickle(elem: Any, dOut: DataOutputStream) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { - val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] - val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t._1)) - dOut.write(PythonRDD.stripPickle(t._2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. - val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new SparkException("Unexpected RDD type") - } - } - - def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) val objs = new collection.mutable.ArrayBuffer[Array[Byte]] @@ -276,15 +221,46 @@ private[spark] object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + def writeStringAsPickle(elem: String, dOut: DataOutputStream) { + val s = elem.getBytes("UTF-8") + val length = 2 + 1 + 4 + s.length + 1 + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(Pickle.BINUNICODE) + dOut.writeInt(Integer.reverseBytes(s.length)) + dOut.write(s) + dOut.writeByte(Pickle.STOP) + } + + def writeToStream(elem: Any, dataOut: DataOutputStream) { + elem match { + case bytes: Array[Byte] => + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + case pair: (Array[Byte], Array[Byte]) => + dataOut.writeInt(pair._1.length) + dataOut.write(pair._1) + dataOut.writeInt(pair._2.length) + dataOut.write(pair._2) + case str: String => + // Until we've implemented full custom serializer support, we need to return + // strings as Pickles to properly support union() and cartesian(): + writeStringAsPickle(str, dataOut) + case other => + throw new SparkException("Unexpected element type " + other.getClass) + } + } + + def writeToFile[T](items: java.util.Iterator[T], filename: String) { import scala.collection.JavaConverters._ - writeIteratorToPickleFile(items.asScala, filename) + writeToFile(items.asScala, filename) } - def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { + def writeToFile[T](items: Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { - writeAsPickle(item, file) + writeToStream(item, file) } file.close() } @@ -300,10 +276,6 @@ private object Pickle { val TWO: Byte = 0x02.toByte val BINUNICODE: Byte = 'X' val STOP: Byte = '.' - val TUPLE2: Byte = 0x86.toByte - val EMPTY_LIST: Byte = ']' - val MARK: Byte = '(' - val APPENDS: Byte = 'e' } private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a7ca8bc888..0fec1a6bf6 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -42,7 +42,7 @@ class SparkContext(object): _gateway = None _jvm = None - _writeIteratorToPickleFile = None + _writeToFile = None _takePartition = None _next_accum_id = 0 _active_spark_context = None @@ -125,8 +125,8 @@ class SparkContext(object): if not SparkContext._gateway: SparkContext._gateway = launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeIteratorToPickleFile = \ - SparkContext._jvm.PythonRDD.writeIteratorToPickleFile + SparkContext._writeToFile = \ + SparkContext._jvm.PythonRDD.writeToFile SparkContext._takePartition = \ SparkContext._jvm.PythonRDD.takePartition @@ -190,8 +190,8 @@ class SparkContext(object): for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile - jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile + jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7019fb8bee..d3c4d13a1e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -54,6 +54,7 @@ class RDD(object): self.is_checkpointed = False self.ctx = ctx self._partitionFunc = None + self._stage_input_is_pairs = False @property def context(self): @@ -344,6 +345,7 @@ class RDD(object): yield pair else: yield pair + java_cartesian._stage_input_is_pairs = True return java_cartesian.flatMap(unpack_batches) def groupBy(self, f, numPartitions=None): @@ -391,8 +393,8 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - picklesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(picklesInJava)) + bytesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): # Transferring lots of data through Py4J can be slow because @@ -400,7 +402,7 @@ class RDD(object): # file and read it back. tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) tempFile.close() - self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) + self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: for item in read_from_pickle_file(tempFile): @@ -941,6 +943,7 @@ class PipelinedRDD(RDD): self.func = func self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd + self._stage_input_is_pairs = prev._stage_input_is_pairs self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx @@ -959,7 +962,7 @@ class PipelinedRDD(RDD): def batched_func(split, iterator): return batched(oldfunc(split, iterator), batchSize) func = batched_func - cmds = [func, self._bypass_serializer] + cmds = [func, self._bypass_serializer, self._stage_input_is_pairs] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fbc280fd37..fd02e1ee8f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -93,6 +93,14 @@ def write_with_length(obj, stream): stream.write(obj) +def read_mutf8(stream): + """ + Read a string written with Java's DataOutputStream.writeUTF() method. + """ + length = struct.unpack('>H', stream.read(2))[0] + return stream.read(length).decode('utf8') + + def read_with_length(stream): length = read_int(stream) obj = stream.read(length) @@ -112,3 +120,13 @@ def read_from_pickle_file(stream): yield obj except EOFError: return + + +def read_pairs_from_pickle_file(stream): + try: + while True: + a = load_pickle(read_with_length(stream)) + b = load_pickle(read_with_length(stream)) + yield (a, b) + except EOFError: + return \ No newline at end of file diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7696df9d1c..4e64557fc4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,8 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \ - SpecialLengths + read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \ + SpecialLengths, read_mutf8, read_pairs_from_pickle_file def load_obj(infile): @@ -53,7 +53,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = load_pickle(read_with_length(infile)) + spark_files_dir = read_mutf8(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -68,17 +68,21 @@ def main(infile, outfile): sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile)))) + sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile))) # now load function func = load_obj(infile) bypassSerializer = load_obj(infile) + stageInputIsPairs = load_obj(infile) if bypassSerializer: dumps = lambda x: x else: dumps = dump_pickle init_time = time.time() - iterator = read_from_pickle_file(infile) + if stageInputIsPairs: + iterator = read_pairs_from_pickle_file(infile) + else: + iterator = read_from_pickle_file(infile) try: for obj in func(split_index, iterator): write_with_length(dumps(obj), outfile) -- cgit v1.2.3 From ef85a51f85c9720bc091367a0d4f80e7ed6b9778 Mon Sep 17 00:00:00 2001 From: Russell Cardullo Date: Fri, 8 Nov 2013 16:36:03 -0800 Subject: Add graphite sink for metrics This adds a metrics sink for graphite. The sink must be configured with the host and port of a graphite node and optionally may be configured with a prefix that will be prepended to all metrics that are sent to graphite. --- conf/metrics.properties.template | 8 +++ core/pom.xml | 4 ++ .../apache/spark/metrics/sink/GraphiteSink.scala | 82 ++++++++++++++++++++++ docs/monitoring.md | 1 + project/SparkBuild.scala | 1 + 5 files changed, 96 insertions(+) create mode 100644 core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala (limited to 'core/src/main/scala/org') diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template index ae10f615d1..1c3d94e1b0 100644 --- a/conf/metrics.properties.template +++ b/conf/metrics.properties.template @@ -80,6 +80,14 @@ # /metrics/aplications/json # App information # /metrics/master/json # Master information +# org.apache.spark.metrics.sink.GraphiteSink +# Name: Default: Description: +# host NONE Hostname of Graphite server +# port NONE Port of Graphite server +# period 10 Poll period +# unit seconds Units of poll period +# prefix EMPTY STRING Prefix to prepend to metric name + ## Examples # Enable JmxSink for all instances by class name #*.sink.jmx.class=org.apache.spark.metrics.sink.JmxSink diff --git a/core/pom.xml b/core/pom.xml index 8621d257e5..6af229c71d 100644 --- a/core/pom.xml +++ b/core/pom.xml @@ -158,6 +158,10 @@ com.codahale.metrics metrics-ganglia + + com.codahale.metrics + metrics-graphite + org.apache.derby derby diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala new file mode 100644 index 0000000000..eb1315e6de --- /dev/null +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.metrics.sink + +import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.graphite.{GraphiteReporter, Graphite} + +import java.util.Properties +import java.util.concurrent.TimeUnit +import java.net.InetSocketAddress + +import org.apache.spark.metrics.MetricsSystem + +class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink { + val GRAPHITE_DEFAULT_PERIOD = 10 + val GRAPHITE_DEFAULT_UNIT = "SECONDS" + val GRAPHITE_DEFAULT_PREFIX = "" + + val GRAPHITE_KEY_HOST = "host" + val GRAPHITE_KEY_PORT = "port" + val GRAPHITE_KEY_PERIOD = "period" + val GRAPHITE_KEY_UNIT = "unit" + val GRAPHITE_KEY_PREFIX = "prefix" + + def propertyToOption(prop: String) = Option(property.getProperty(prop)) + + if (!propertyToOption(GRAPHITE_KEY_HOST).isDefined) { + throw new Exception("Graphite sink requires 'host' property.") + } + + if (!propertyToOption(GRAPHITE_KEY_PORT).isDefined) { + throw new Exception("Graphite sink requires 'port' property.") + } + + val host = propertyToOption(GRAPHITE_KEY_HOST).get + val port = propertyToOption(GRAPHITE_KEY_PORT).get.toInt + + val pollPeriod = Option(property.getProperty(GRAPHITE_KEY_PERIOD)) match { + case Some(s) => s.toInt + case None => GRAPHITE_DEFAULT_PERIOD + } + + val pollUnit = Option(property.getProperty(GRAPHITE_KEY_UNIT)) match { + case Some(s) => TimeUnit.valueOf(s.toUpperCase()) + case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT) + } + + val prefix = propertyToOption(GRAPHITE_KEY_PREFIX).getOrElse(GRAPHITE_DEFAULT_PREFIX) + + MetricsSystem.checkMinimalPollingPeriod(pollUnit, pollPeriod) + + val graphite: Graphite = new Graphite(new InetSocketAddress(host, port)) + + val reporter: GraphiteReporter = GraphiteReporter.forRegistry(registry) + .convertDurationsTo(TimeUnit.MILLISECONDS) + .convertRatesTo(TimeUnit.SECONDS) + .prefixedWith(prefix) + .build(graphite) + + override def start() { + reporter.start(pollPeriod, pollUnit) + } + + override def stop() { + reporter.stop() + } +} diff --git a/docs/monitoring.md b/docs/monitoring.md index 5f456b999b..5ed0474477 100644 --- a/docs/monitoring.md +++ b/docs/monitoring.md @@ -50,6 +50,7 @@ Each instance can report to zero or more _sinks_. Sinks are contained in the * `GangliaSink`: Sends metrics to a Ganglia node or multicast group. * `JmxSink`: Registers metrics for viewing in a JXM console. * `MetricsServlet`: Adds a servlet within the existing Spark UI to serve metrics data as JSON data. +* `GraphiteSink`: Sends metrics to a Graphite node. The syntax of the metrics configuration file is defined in an example configuration file, `$SPARK_HOME/conf/metrics.conf.template`. diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala index 45fd30a7c8..0bc2ca8d08 100644 --- a/project/SparkBuild.scala +++ b/project/SparkBuild.scala @@ -229,6 +229,7 @@ object SparkBuild extends Build { "com.codahale.metrics" % "metrics-jvm" % "3.0.0", "com.codahale.metrics" % "metrics-json" % "3.0.0", "com.codahale.metrics" % "metrics-ganglia" % "3.0.0", + "com.codahale.metrics" % "metrics-graphite" % "3.0.0", "com.twitter" % "chill_2.9.3" % "0.3.1", "com.twitter" % "chill-java" % "0.3.1" ) -- cgit v1.2.3 From cbb7f04aef2220ece93dea9f3fa98b5db5f270d6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 5 Nov 2013 17:52:39 -0800 Subject: Add custom serializer support to PySpark. For now, this only adds MarshalSerializer, but it lays the groundwork for other supporting custom serializers. Many of these mechanisms can also be used to support deserialization of different data formats sent by Java, such as data encoded by MsgPack. This also fixes a bug in SparkContext.union(). --- .../org/apache/spark/api/python/PythonRDD.scala | 23 +- python/epydoc.conf | 2 +- python/pyspark/accumulators.py | 6 +- python/pyspark/context.py | 61 ++-- python/pyspark/rdd.py | 86 +++--- python/pyspark/serializers.py | 310 ++++++++++++++++----- python/pyspark/tests.py | 3 +- python/pyspark/worker.py | 41 ++- python/run-tests | 1 + 9 files changed, 363 insertions(+), 170 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index eb0b0db0cc..ef9bf4db9b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -221,18 +221,6 @@ private[spark] object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def writeStringAsPickle(elem: String, dOut: DataOutputStream) { - val s = elem.getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } - def writeToStream(elem: Any, dataOut: DataOutputStream) { elem match { case bytes: Array[Byte] => @@ -244,9 +232,7 @@ private[spark] object PythonRDD { dataOut.writeInt(pair._2.length) dataOut.write(pair._2) case str: String => - // Until we've implemented full custom serializer support, we need to return - // strings as Pickles to properly support union() and cartesian(): - writeStringAsPickle(str, dataOut) + dataOut.writeUTF(str) case other => throw new SparkException("Unexpected element type " + other.getClass) } @@ -271,13 +257,6 @@ private[spark] object PythonRDD { } } -private object Pickle { - val PROTO: Byte = 0x80.toByte - val TWO: Byte = 0x02.toByte - val BINUNICODE: Byte = 'X' - val STOP: Byte = '.' -} - private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/python/epydoc.conf b/python/epydoc.conf index 1d0d002d36..0b42e729f8 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -32,6 +32,6 @@ target: docs/ private: no -exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers +exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test pyspark.rddsampler pyspark.daemon diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index da3d96689a..2204e9c9ca 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -90,9 +90,11 @@ import struct import SocketServer import threading from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import read_int, read_with_length, load_pickle +from pyspark.serializers import read_int, PickleSerializer +pickleSer = PickleSerializer() + # Holds accumulators registered on the current machine, keyed by ID. This is then used to send # the local accumulator updates back to the driver program at the end of a task. _accumulatorRegistry = {} @@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): from pyspark.accumulators import _accumulatorRegistry num_updates = read_int(self.rfile) for _ in range(num_updates): - (aid, update) = load_pickle(read_with_length(self.rfile)) + (aid, update) = pickleSer._read_with_length(self.rfile) _accumulatorRegistry[aid] += update # Write a byte in acknowledgement self.wfile.write(struct.pack("!b", 1)) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 0fec1a6bf6..6bb1c6c3a1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway -from pyspark.serializers import dump_pickle, write_with_length, batched +from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD @@ -51,7 +51,7 @@ class SparkContext(object): def __init__(self, master, jobName, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024): + environment=None, batchSize=1024, serializer=PickleSerializer()): """ Create a new SparkContext. @@ -67,6 +67,7 @@ class SparkContext(object): @param batchSize: The number of Python objects represented as a single Java object. Set 1 to disable batching or -1 to use an unlimited batch size. + @param serializer: The serializer for RDDs. >>> from pyspark.context import SparkContext @@ -83,7 +84,13 @@ class SparkContext(object): self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J self.environment = environment or {} - self.batchSize = batchSize # -1 represents a unlimited batch size + self._batchSize = batchSize # -1 represents an unlimited batch size + self._unbatched_serializer = serializer + if batchSize == 1: + self.serializer = self._unbatched_serializer + else: + self.serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) # Create the Java SparkContext through Py4J empty_string_array = self._gateway.new_array(self._jvm.String, 0) @@ -184,15 +191,17 @@ class SparkContext(object): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self.batchSize) + batchSize = min(len(c) // numSlices, self._batchSize) if batchSize > 1: - c = batched(c, batchSize) - for x in c: - write_with_length(dump_pickle(x), tempFile) + serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) + else: + serializer = self._unbatched_serializer + serializer.dump_stream(c, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) - return RDD(jrdd, self) + return RDD(jrdd, self, serializer) def textFile(self, name, minSplits=None): """ @@ -201,21 +210,39 @@ class SparkContext(object): RDD of Strings. """ minSplits = minSplits or min(self.defaultParallelism, 2) - jrdd = self._jsc.textFile(name, minSplits) - return RDD(jrdd, self) + return RDD(self._jsc.textFile(name, minSplits), self, + MUTF8Deserializer()) - def _checkpointFile(self, name): + def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) - return RDD(jrdd, self) + return RDD(jrdd, self, input_deserializer) def union(self, rdds): """ Build the union of a list of RDDs. + + This supports unions() of RDDs with different serialized formats, + although this forces them to be reserialized using the default + serializer: + + >>> path = os.path.join(tempdir, "union-text.txt") + >>> with open(path, "w") as testFile: + ... testFile.write("Hello") + >>> textFile = sc.textFile(path) + >>> textFile.collect() + [u'Hello'] + >>> parallelized = sc.parallelize(["World!"]) + >>> sorted(sc.union([textFile, parallelized]).collect()) + [u'Hello', 'World!'] """ + first_jrdd_deserializer = rdds[0]._jrdd_deserializer + if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): + rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self.gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self) + rest = ListConverter().convert(rest, self._gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self, + rdds[0]._jrdd_deserializer) def broadcast(self, value): """ @@ -223,7 +250,9 @@ class SparkContext(object): object for reading it in distributed functions. The variable will be sent to each cluster only once. """ - jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) + pickleSer = PickleSerializer() + pickled = pickleSer._dumps(value) + jbroadcast = self._jsc.broadcast(bytearray(pickled)) return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) @@ -235,7 +264,7 @@ class SparkContext(object): and floating-point numbers if you do not provide one. For other types, a custom AccumulatorParam can be used. """ - if accum_param == None: + if accum_param is None: if isinstance(value, int): accum_param = accumulators.INT_ACCUMULATOR_PARAM elif isinstance(value, float): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d3c4d13a1e..6691c30519 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -18,7 +18,7 @@ from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from itertools import chain, ifilter, imap, product +from itertools import chain, ifilter, imap import operator import os import sys @@ -28,8 +28,8 @@ from tempfile import NamedTemporaryFile from threading import Thread from pyspark import cloudpickle -from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ - read_from_pickle_file, pack_long +from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ + BatchedSerializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -48,13 +48,12 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx): + def __init__(self, jrdd, ctx, jrdd_deserializer): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False self.ctx = ctx - self._partitionFunc = None - self._stage_input_is_pairs = False + self._jrdd_deserializer = jrdd_deserializer @property def context(self): @@ -248,7 +247,23 @@ class RDD(object): >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] """ - return RDD(self._jrdd.union(other._jrdd), self.ctx) + if self._jrdd_deserializer == other._jrdd_deserializer: + rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, + self._jrdd_deserializer) + return rdd + else: + # These RDDs contain data in different serialized formats, so we + # must normalize them to the default serializer. + self_copy = self._reserialize() + other_copy = other._reserialize() + return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, + self.ctx.serializer) + + def _reserialize(self): + if self._jrdd_deserializer == self.ctx.serializer: + return self + else: + return self.map(lambda x: x, preservesPartitioning=True) def __add__(self, other): """ @@ -335,18 +350,9 @@ class RDD(object): [(1, 1), (1, 2), (2, 1), (2, 2)] """ # Due to batching, we can't use the Java cartesian method. - java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) - def unpack_batches(pair): - (x, y) = pair - if type(x) == Batch or type(y) == Batch: - xs = x.items if type(x) == Batch else [x] - ys = y.items if type(y) == Batch else [y] - for pair in product(xs, ys): - yield pair - else: - yield pair - java_cartesian._stage_input_is_pairs = True - return java_cartesian.flatMap(unpack_batches) + deserializer = CartesianDeserializer(self._jrdd_deserializer, + other._jrdd_deserializer) + return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) def groupBy(self, f, numPartitions=None): """ @@ -405,7 +411,7 @@ class RDD(object): self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: - for item in read_from_pickle_file(tempFile): + for item in self._jrdd_deserializer.load_stream(tempFile): yield item os.unlink(tempFile.name) @@ -573,7 +579,7 @@ class RDD(object): items = [] for partition in range(mapped._jrdd.splits().size()): iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition) - items.extend(self._collect_iterator_through_file(iterator)) + items.extend(mapped._collect_iterator_through_file(iterator)) if len(items) >= num: break return items[:num] @@ -737,6 +743,7 @@ class RDD(object): # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numPartitions) objects # to Java. Each object is a (splitNumber, [objects]) pair. + outputSerializer = self.ctx._unbatched_serializer def add_shuffle_key(split, iterator): buckets = defaultdict(list) @@ -745,14 +752,14 @@ class RDD(object): buckets[partitionFunc(k) % numPartitions].append((k, v)) for (split, items) in buckets.iteritems(): yield pack_long(split) - yield dump_pickle(Batch(items)) + yield outputSerializer._dumps(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() - rdd = RDD(jrdd, self.ctx) + rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) # This is required so that id(partitionFunc) remains unique, even if # partitionFunc is a lambda: rdd._partitionFunc = partitionFunc @@ -789,7 +796,8 @@ class RDD(object): numPartitions = self.ctx.defaultParallelism def combineLocally(iterator): combiners = {} - for (k, v) in iterator: + for x in iterator: + (k, v) = x if k not in combiners: combiners[k] = createCombiner(v) else: @@ -931,38 +939,38 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): + if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): + # This transformation is the first in its stage: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self._prev_jrdd_deserializer = prev._jrdd_deserializer + else: prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning - self._prev_jrdd = prev._prev_jrdd - else: - self.func = func - self.preservesPartitioning = preservesPartitioning - self._prev_jrdd = prev._jrdd - self._stage_input_is_pairs = prev._stage_input_is_pairs + self._prev_jrdd = prev._prev_jrdd # maintain the pipeline + self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False @property def _jrdd(self): if self._jrdd_val: return self._jrdd_val - func = self.func - if not self._bypass_serializer and self.ctx.batchSize != 1: - oldfunc = self.func - batchSize = self.ctx.batchSize - def batched_func(split, iterator): - return batched(oldfunc(split, iterator), batchSize) - func = batched_func - cmds = [func, self._bypass_serializer, self._stage_input_is_pairs] + if self._bypass_serializer: + serializer = NoOpSerializer() + else: + serializer = self.ctx.serializer + cmds = [self.func, self._prev_jrdd_deserializer, serializer] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fd02e1ee8f..4fb444443f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -15,8 +15,58 @@ # limitations under the License. # -import struct +""" +PySpark supports custom serializers for transferring data; this can improve +performance. + +By default, PySpark uses L{PickleSerializer} to serialize objects using Python's +C{cPickle} serializer, which can serialize nearly any Python object. +Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be +faster. + +The serializer is chosen when creating L{SparkContext}: + +>>> from pyspark.context import SparkContext +>>> from pyspark.serializers import MarshalSerializer +>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) +>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) +[0, 2, 4, 6, 8, 10, 12, 14, 16, 18] +>>> sc.stop() + +By default, PySpark serialize objects in batches; the batch size can be +controlled through SparkContext's C{batchSize} parameter +(the default size is 1024 objects): + +>>> sc = SparkContext('local', 'test', batchSize=2) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) + +Behind the scenes, this creates a JavaRDD with four partitions, each of +which contains two batches of two objects: + +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +8L +>>> sc.stop() + +A batch size of -1 uses an unlimited batch size, and a size of 1 disables +batching: + +>>> sc = SparkContext('local', 'test', batchSize=1) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +16L +""" + import cPickle +from itertools import chain, izip, product +import marshal +import struct + + +__all__ = ["PickleSerializer", "MarshalSerializer"] class SpecialLengths(object): @@ -25,41 +75,206 @@ class SpecialLengths(object): TIMING_DATA = -3 -class Batch(object): +class Serializer(object): + + def dump_stream(self, iterator, stream): + """ + Serialize an iterator of objects to the output stream. + """ + raise NotImplementedError + + def load_stream(self, stream): + """ + Return an iterator of deserialized objects from the input stream. + """ + raise NotImplementedError + + + def _load_stream_without_unbatching(self, stream): + return self.load_stream(stream) + + # Note: our notion of "equality" is that output generated by + # equal serializers can be deserialized using the same serializer. + + # This default implementation handles the simple cases; + # subclasses should override __eq__ as appropriate. + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + +class FramedSerializer(Serializer): + """ + Serializer that writes objects as a stream of (length, data) pairs, + where C{length} is a 32-bit integer and data is C{length} bytes. + """ + + def dump_stream(self, iterator, stream): + for obj in iterator: + self._write_with_length(obj, stream) + + def load_stream(self, stream): + while True: + try: + yield self._read_with_length(stream) + except EOFError: + return + + def _write_with_length(self, obj, stream): + serialized = self._dumps(obj) + write_int(len(serialized), stream) + stream.write(serialized) + + def _read_with_length(self, stream): + length = read_int(stream) + obj = stream.read(length) + if obj == "": + raise EOFError + return self._loads(obj) + + def _dumps(self, obj): + """ + Serialize an object into a byte array. + When batching is used, this will be called with an array of objects. + """ + raise NotImplementedError + + def _loads(self, obj): + """ + Deserialize an object from a byte array. + """ + raise NotImplementedError + + +class BatchedSerializer(Serializer): + """ + Serializes a stream of objects in batches by calling its wrapped + Serializer with streams of objects. + """ + + UNLIMITED_BATCH_SIZE = -1 + + def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): + self.serializer = serializer + self.batchSize = batchSize + + def _batched(self, iterator): + if self.batchSize == self.UNLIMITED_BATCH_SIZE: + yield list(iterator) + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == self.batchSize: + yield items + items = [] + count = 0 + if items: + yield items + + def dump_stream(self, iterator, stream): + if isinstance(iterator, basestring): + iterator = [iterator] + self.serializer.dump_stream(self._batched(iterator), stream) + + def load_stream(self, stream): + return chain.from_iterable(self._load_stream_without_unbatching(stream)) + + def _load_stream_without_unbatching(self, stream): + return self.serializer.load_stream(stream) + + def __eq__(self, other): + return isinstance(other, BatchedSerializer) and \ + other.serializer == self.serializer + + def __str__(self): + return "BatchedSerializer<%s>" % str(self.serializer) + + +class CartesianDeserializer(FramedSerializer): """ - Used to store multiple RDD entries as a single Java object. + Deserializes the JavaRDD cartesian() of two PythonRDDs. + """ + + def __init__(self, key_ser, val_ser): + self.key_ser = key_ser + self.val_ser = val_ser + + def load_stream(self, stream): + key_stream = self.key_ser._load_stream_without_unbatching(stream) + val_stream = self.val_ser._load_stream_without_unbatching(stream) + key_is_batched = isinstance(self.key_ser, BatchedSerializer) + val_is_batched = isinstance(self.val_ser, BatchedSerializer) + for (keys, vals) in izip(key_stream, val_stream): + keys = keys if key_is_batched else [keys] + vals = vals if val_is_batched else [vals] + for pair in product(keys, vals): + yield pair + + def __eq__(self, other): + return isinstance(other, CartesianDeserializer) and \ + self.key_ser == other.key_ser and self.val_ser == other.val_ser + + def __str__(self): + return "CartesianDeserializer<%s, %s>" % \ + (str(self.key_ser), str(self.val_ser)) + + +class NoOpSerializer(FramedSerializer): + + def _loads(self, obj): return obj + def _dumps(self, obj): return obj + + +class PickleSerializer(FramedSerializer): + """ + Serializes objects using Python's cPickle serializer: + + http://docs.python.org/2/library/pickle.html + + This serializer supports nearly any Python object, but may + not be as fast as more specialized serializers. + """ + + def _dumps(self, obj): return cPickle.dumps(obj, 2) + _loads = cPickle.loads + - This relieves us from having to explicitly track whether an RDD - is stored as batches of objects and avoids problems when processing - the union() of batched and unbatched RDDs (e.g. the union() of textFile() - with another RDD). +class MarshalSerializer(FramedSerializer): """ - def __init__(self, items): - self.items = items + Serializes objects using Python's Marshal serializer: + http://docs.python.org/2/library/marshal.html -def batched(iterator, batchSize): - if batchSize == -1: # unlimited batch size - yield Batch(list(iterator)) - else: - items = [] - count = 0 - for item in iterator: - items.append(item) - count += 1 - if count == batchSize: - yield Batch(items) - items = [] - count = 0 - if items: - yield Batch(items) + This serializer is faster than PickleSerializer but supports fewer datatypes. + """ + + _dumps = marshal.dumps + _loads = marshal.loads -def dump_pickle(obj): - return cPickle.dumps(obj, 2) +class MUTF8Deserializer(Serializer): + """ + Deserializes streams written by Java's DataOutputStream.writeUTF(). + """ + def _loads(self, stream): + length = struct.unpack('>H', stream.read(2))[0] + return stream.read(length).decode('utf8') -load_pickle = cPickle.loads + def load_stream(self, stream): + while True: + try: + yield self._loads(stream) + except struct.error: + return + except EOFError: + return def read_long(stream): @@ -90,43 +305,4 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) - stream.write(obj) - - -def read_mutf8(stream): - """ - Read a string written with Java's DataOutputStream.writeUTF() method. - """ - length = struct.unpack('>H', stream.read(2))[0] - return stream.read(length).decode('utf8') - - -def read_with_length(stream): - length = read_int(stream) - obj = stream.read(length) - if obj == "": - raise EOFError - return obj - - -def read_from_pickle_file(stream): - try: - while True: - obj = load_pickle(read_with_length(stream)) - if type(obj) == Batch: # We don't care about inheritance - for item in obj.items: - yield item - else: - yield obj - except EOFError: - return - - -def read_pairs_from_pickle_file(stream): - try: - while True: - a = load_pickle(read_with_length(stream)) - b = load_pickle(read_with_length(stream)) - yield (a, b) - except EOFError: - return \ No newline at end of file + stream.write(obj) \ No newline at end of file diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 29d6a128f6..621e1cb58c 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase): time.sleep(1) # 1 second self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) - recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), + flatMappedRDD._jrdd_deserializer) self.assertEquals([1, 2, 3, 4], recovered.collect()) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4e64557fc4..5b16d5db7e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,13 +30,17 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles -from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \ - SpecialLengths, read_mutf8, read_pairs_from_pickle_file +from pyspark.serializers import write_with_length, write_int, read_long, \ + write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer + + +pickleSer = PickleSerializer() +mutf8_deserializer = MUTF8Deserializer() def load_obj(infile): - return load_pickle(standard_b64decode(infile.readline().strip())) + decoded = standard_b64decode(infile.readline().strip()) + return pickleSer._loads(decoded) def report_times(outfile, boot, init, finish): @@ -53,7 +57,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = read_mutf8(infile) + spark_files_dir = mutf8_deserializer._loads(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -61,31 +65,24 @@ def main(infile, outfile): num_broadcast_variables = read_int(infile) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + value = pickleSer._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile))) + filename = mutf8_deserializer._loads(infile) + sys.path.append(os.path.join(spark_files_dir, filename)) - # now load function + # Load this stage's function and serializer: func = load_obj(infile) - bypassSerializer = load_obj(infile) - stageInputIsPairs = load_obj(infile) - if bypassSerializer: - dumps = lambda x: x - else: - dumps = dump_pickle + deserializer = load_obj(infile) + serializer = load_obj(infile) init_time = time.time() - if stageInputIsPairs: - iterator = read_pairs_from_pickle_file(infile) - else: - iterator = read_from_pickle_file(infile) try: - for obj in func(split_index, iterator): - write_with_length(dumps(obj), outfile) + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) except Exception as e: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) write_with_length(traceback.format_exc(), outfile) @@ -96,7 +93,7 @@ def main(infile, outfile): write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) write_int(len(_accumulatorRegistry), outfile) for (aid, accum) in _accumulatorRegistry.items(): - write_with_length(dump_pickle((aid, accum._value)), outfile) + pickleSer._write_with_length((aid, accum._value), outfile) if __name__ == '__main__': diff --git a/python/run-tests b/python/run-tests index cbc554ea9d..d4dad672d2 100755 --- a/python/run-tests +++ b/python/run-tests @@ -37,6 +37,7 @@ run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "-m doctest pyspark/broadcast.py" run_test "-m doctest pyspark/accumulators.py" +run_test "-m doctest pyspark/serializers.py" run_test "pyspark/tests.py" if [[ $FAILED != 0 ]]; then -- cgit v1.2.3 From ffa5bedf46fbc89ad5c5658f3b423dfff49b70f0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 Nov 2013 12:58:28 -0800 Subject: Send PySpark commands as bytes insetad of strings. --- .../org/apache/spark/api/python/PythonRDD.scala | 24 ++++------------------ python/pyspark/rdd.py | 12 +++++------ python/pyspark/serializers.py | 5 +++++ python/pyspark/worker.py | 12 ++--------- 4 files changed, 17 insertions(+), 36 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ef9bf4db9b..132e4fb0d2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -27,13 +27,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PipedRDD import org.apache.spark.util.Utils private[spark] class PythonRDD[T: ClassManifest]( parent: RDD[T], - command: Seq[String], + command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], preservePartitoning: Boolean, @@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest]( val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. by spaces) - def this(parent: RDD[T], command: String, envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec, - broadcastVars, accumulator) - override def getPartitions = parent.partitions override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get @@ -71,7 +59,6 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) - val printOut = new PrintWriter(stream) // Partition index dataOut.writeInt(split.index) // sparkFilesDir @@ -87,17 +74,14 @@ private[spark] class PythonRDD[T: ClassManifest]( dataOut.writeInt(pythonIncludes.length) pythonIncludes.foreach(dataOut.writeUTF) dataOut.flush() - // Serialized user code - for (elem <- command) { - printOut.println(elem) - } - printOut.flush() + // Serialized command: + dataOut.writeInt(command.length) + dataOut.write(command) // Data values for (elem <- parent.iterator(split, context)) { PythonRDD.writeToStream(elem, dataOut) } dataOut.flush() - printOut.flush() worker.shutdownOutput() } catch { case e: IOException => diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6691c30519..062f44f81e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -27,9 +27,8 @@ from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread -from pyspark import cloudpickle from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ - BatchedSerializer, pack_long + BatchedSerializer, CloudPickleSerializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -970,8 +969,8 @@ class PipelinedRDD(RDD): serializer = NoOpSerializer() else: serializer = self.ctx.serializer - cmds = [self.func, self._prev_jrdd_deserializer, serializer] - pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) + command = (self.func, self._prev_jrdd_deserializer, serializer) + pickled_command = CloudPickleSerializer()._dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) @@ -982,8 +981,9 @@ class PipelinedRDD(RDD): includes = ListConverter().convert(self.ctx._python_includes, self.ctx._gateway._gateway_client) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator, class_manifest) + bytearray(pickled_command), env, includes, self.preservesPartitioning, + self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator, + class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 4fb444443f..b23804b33c 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -64,6 +64,7 @@ import cPickle from itertools import chain, izip, product import marshal import struct +from pyspark import cloudpickle __all__ = ["PickleSerializer", "MarshalSerializer"] @@ -244,6 +245,10 @@ class PickleSerializer(FramedSerializer): def _dumps(self, obj): return cPickle.dumps(obj, 2) _loads = cPickle.loads +class CloudPickleSerializer(PickleSerializer): + + def _dumps(self, obj): return cloudpickle.dumps(obj, 2) + class MarshalSerializer(FramedSerializer): """ diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5b16d5db7e..2751f1239e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,7 +23,6 @@ import sys import time import socket import traceback -from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. from pyspark.accumulators import _accumulatorRegistry @@ -38,11 +37,6 @@ pickleSer = PickleSerializer() mutf8_deserializer = MUTF8Deserializer() -def load_obj(infile): - decoded = standard_b64decode(infile.readline().strip()) - return pickleSer._loads(decoded) - - def report_times(outfile, boot, init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) write_long(1000 * boot, outfile) @@ -75,10 +69,8 @@ def main(infile, outfile): filename = mutf8_deserializer._loads(infile) sys.path.append(os.path.join(spark_files_dir, filename)) - # Load this stage's function and serializer: - func = load_obj(infile) - deserializer = load_obj(infile) - serializer = load_obj(infile) + command = pickleSer._read_with_length(infile) + (func, deserializer, serializer) = command init_time = time.time() try: iterator = deserializer.load_stream(infile) -- cgit v1.2.3 From c33f802044c02025c3e4530add1b82586156cbb0 Mon Sep 17 00:00:00 2001 From: Henry Saputra Date: Fri, 15 Nov 2013 10:32:20 -0800 Subject: Simple cleanup on Spark's Scala code while testing core and yarn modules: -) Remove some of unused imports as I found them -) Remove ";" in the imports statements -) Remove () at the end of method call like size that does not have size effect. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 1 - .../main/scala/org/apache/spark/deploy/LocalSparkCluster.scala | 6 +++--- .../org/apache/spark/executor/CoarseGrainedExecutorBackend.scala | 2 +- .../src/main/scala/org/apache/spark/executor/ExecutorSource.scala | 2 -- .../scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 8 +++----- yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- .../apache/spark/deploy/yarn/ClientDistributedCacheManager.scala | 2 +- .../spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala | 2 +- 8 files changed, 10 insertions(+), 15 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index d884095671..8525844d10 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -24,7 +24,6 @@ import java.util.concurrent.atomic.AtomicInteger import scala.collection.Map import scala.collection.generic.Growable -import scala.collection.JavaConverters._ import scala.collection.mutable.ArrayBuffer import scala.collection.mutable.HashMap diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 308a2bfa22..a724900943 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -17,12 +17,12 @@ package org.apache.spark.deploy -import akka.actor.{ActorRef, Props, Actor, ActorSystem, Terminated} +import akka.actor.ActorSystem import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master -import org.apache.spark.util.{Utils, AkkaUtils} -import org.apache.spark.{Logging} +import org.apache.spark.util.Utils +import org.apache.spark.Logging import scala.collection.mutable.ArrayBuffer diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index caee6b01ab..8332631838 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import akka.actor.{ActorRef, Actor, Props, Terminated} import akka.remote.{RemoteClientLifeCycleEvent, RemoteClientShutdown, RemoteClientDisconnected} -import org.apache.spark.{Logging, SparkEnv} +import org.apache.spark.Logging import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{Utils, AkkaUtils} diff --git a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala index 34ed9c8f73..97176e4f5b 100644 --- a/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala +++ b/core/src/main/scala/org/apache/spark/executor/ExecutorSource.scala @@ -20,8 +20,6 @@ package org.apache.spark.executor import com.codahale.metrics.{Gauge, MetricRegistry} import org.apache.hadoop.fs.FileSystem -import org.apache.hadoop.hdfs.DistributedFileSystem -import org.apache.hadoop.fs.LocalFileSystem import scala.collection.JavaConversions._ diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 4302ef4cda..0e47bd7a10 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -17,9 +17,8 @@ package org.apache.spark.deploy.yarn -import java.io.IOException; +import java.io.IOException import java.net.Socket -import java.security.PrivilegedExceptionAction import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} import org.apache.hadoop.conf.Configuration @@ -34,7 +33,6 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{SparkContext, Logging} import org.apache.spark.util.Utils -import org.apache.hadoop.security.UserGroupInformation import scala.collection.JavaConversions._ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging { @@ -186,8 +184,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e var successed = false try { // Copy - var mainArgs: Array[String] = new Array[String](args.userArgs.size()) - args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size()) + var mainArgs: Array[String] = new Array[String](args.userArgs.size) + args.userArgs.copyToArray(mainArgs, 0, args.userArgs.size) mainMethod.invoke(null, mainArgs) // some job script has "System.exit(0)" at the end, for example SparkPi, SparkLR // userThread will stop here unless it has uncaught exception thrown out diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 4e0e060ddc..c38bdd14ec 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileContext, FileStatus, FileSystem, Path, FileUtil} -import org.apache.hadoop.fs.permission.FsPermission; +import org.apache.hadoop.fs.permission.FsPermission import org.apache.hadoop.mapred.Master import org.apache.hadoop.net.NetUtils import org.apache.hadoop.io.DataOutputBuffer diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 07686fefd7..674c8f8112 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.yarn -import java.net.URI; +import java.net.URI import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.FileStatus diff --git a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala index c0a2af0c6f..2941356bc5 100644 --- a/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala +++ b/yarn/src/test/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManagerSuite.scala @@ -17,7 +17,7 @@ package org.apache.spark.deploy.yarn -import java.net.URI; +import java.net.URI import org.scalatest.FunSuite import org.scalatest.mock.MockitoSugar -- cgit v1.2.3 From b60839e56a335d0e30578d1c4cad5b0319d565df Mon Sep 17 00:00:00 2001 From: BlackNiuza Date: Sun, 17 Nov 2013 21:38:57 +0800 Subject: correct number of tasks in ExecutorsUI --- .../scala/org/apache/spark/ui/exec/ExecutorsUI.scala | 19 ++++++++++--------- 1 file changed, 10 insertions(+), 9 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index 42e9be6e19..ba198b211d 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -100,15 +100,16 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { } def getExecInfo(a: Int): Seq[String] = { - val execId = sc.getExecutorStorageStatus(a).blockManagerId.executorId - val hostPort = sc.getExecutorStorageStatus(a).blockManagerId.hostPort - val rddBlocks = sc.getExecutorStorageStatus(a).blocks.size.toString - val memUsed = sc.getExecutorStorageStatus(a).memUsed().toString - val maxMem = sc.getExecutorStorageStatus(a).maxMem.toString - val diskUsed = sc.getExecutorStorageStatus(a).diskUsed().toString - val activeTasks = listener.executorToTasksActive.get(a.toString).map(l => l.size).getOrElse(0) - val failedTasks = listener.executorToTasksFailed.getOrElse(a.toString, 0) - val completedTasks = listener.executorToTasksComplete.getOrElse(a.toString, 0) + val status = sc.getExecutorStorageStatus(a) + val execId = status.blockManagerId.executorId + val hostPort = status.blockManagerId.hostPort + val rddBlocks = status.blocks.size.toString + val memUsed = status.memUsed().toString + val maxMem = status.maxMem.toString + val diskUsed = status.diskUsed().toString + val activeTasks = listener.executorToTasksActive.getOrElse(execId, Seq[Long]()).size + val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) + val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) val totalTasks = activeTasks + failedTasks + completedTasks Seq( -- cgit v1.2.3 From c30979c7d6009936853e731bfde38ec9d04ea347 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 17 Nov 2013 17:09:40 -0800 Subject: Slightly enhanced PrimitiveVector: 1. Added trim() method 2. Added size method. 3. Renamed getUnderlyingArray to array. 4. Minor documentation update. --- .../spark/util/collection/PrimitiveVector.scala | 40 ++++++++++++++-------- 1 file changed, 26 insertions(+), 14 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala index 369519c559..54a5569b3d 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala @@ -17,35 +17,47 @@ package org.apache.spark.util.collection -/** Provides a simple, non-threadsafe, array-backed vector that can store primitives. */ +/** + * An append-only, non-threadsafe, array-backed vector that is optimized for primitive types. + */ private[spark] class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialSize: Int = 64) { - private var numElements = 0 - private var array: Array[V] = _ + private var _numElements = 0 + private var _array: Array[V] = _ // NB: This must be separate from the declaration, otherwise the specialized parent class - // will get its own array with the same initial size. TODO: Figure out why... - array = new Array[V](initialSize) + // will get its own array with the same initial size. + _array = new Array[V](initialSize) def apply(index: Int): V = { - require(index < numElements) - array(index) + require(index < _numElements) + _array(index) } def +=(value: V) { - if (numElements == array.length) { resize(array.length * 2) } - array(numElements) = value - numElements += 1 + if (_numElements == _array.length) { + resize(_array.length * 2) + } + _array(_numElements) = value + _numElements += 1 } - def length = numElements + def capacity: Int = _array.length + + def length: Int = _numElements + + def size: Int = _numElements + + /** Get the underlying array backing this vector. */ + def array: Array[V] = _array - def getUnderlyingArray = array + /** Trims this vector so that the capacity is equal to the size. */ + def trim(): Unit = resize(size) /** Resizes the array, dropping elements if the total length decreases. */ def resize(newLength: Int) { val newArray = new Array[V](newLength) - array.copyToArray(newArray) - array = newArray + _array.copyToArray(newArray) + _array = newArray } } -- cgit v1.2.3 From ecfbaf24426948a9c09225190e71bc1148a9944b Mon Sep 17 00:00:00 2001 From: BlackNiuza Date: Mon, 18 Nov 2013 09:51:40 +0800 Subject: rename "a" to "statusId" --- core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index ba198b211d..26245a6540 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -76,7 +76,7 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { } - val execInfo = for (b <- 0 until storageStatusList.size) yield getExecInfo(b) + val execInfo = for (statusId <- 0 until storageStatusList.size) yield getExecInfo(statusId) val execTable = UIUtils.listingTable(execHead, execRow, execInfo) val content = @@ -99,8 +99,8 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { UIUtils.headerSparkPage(content, sc, "Executors (" + execInfo.size + ")", Executors) } - def getExecInfo(a: Int): Seq[String] = { - val status = sc.getExecutorStorageStatus(a) + def getExecInfo(statusId: Int): Seq[String] = { + val status = sc.getExecutorStorageStatus(statusId) val execId = status.blockManagerId.executorId val hostPort = status.blockManagerId.hostPort val rddBlocks = status.blocks.size.toString -- cgit v1.2.3 From 16a2286d6d0e692e0d2e2d568a3c72c053f5047a Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 17 Nov 2013 17:52:02 -0800 Subject: Return the vector itself for trim and resize method in PrimitiveVector. --- .../scala/org/apache/spark/util/collection/PrimitiveVector.scala | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala index 54a5569b3d..b4fcc9229b 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala @@ -48,16 +48,17 @@ class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialS def size: Int = _numElements - /** Get the underlying array backing this vector. */ + /** Gets the underlying array backing this vector. */ def array: Array[V] = _array /** Trims this vector so that the capacity is equal to the size. */ - def trim(): Unit = resize(size) + def trim(): PrimitiveVector[V] = resize(size) /** Resizes the array, dropping elements if the total length decreases. */ - def resize(newLength: Int) { + def resize(newLength: Int): PrimitiveVector[V] = { val newArray = new Array[V](newLength) _array.copyToArray(newArray) _array = newArray + this } } -- cgit v1.2.3 From 85763f4942afc095595dc32c853d077bdbf49644 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Sun, 17 Nov 2013 17:59:18 -0800 Subject: Add PrimitiveVectorSuite and fix bug in resize() --- .../spark/util/collection/PrimitiveVector.scala | 3 + .../util/collection/PrimitiveVectorSuite.scala | 117 +++++++++++++++++++++ 2 files changed, 120 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala index b4fcc9229b..20554f0aab 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/PrimitiveVector.scala @@ -59,6 +59,9 @@ class PrimitiveVector[@specialized(Long, Int, Double) V: ClassManifest](initialS val newArray = new Array[V](newLength) _array.copyToArray(newArray) _array = newArray + if (newLength < _numElements) { + _numElements = newLength + } this } } diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala new file mode 100644 index 0000000000..970dade628 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveVectorSuite.scala @@ -0,0 +1,117 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util.collection + +import org.scalatest.FunSuite + +import org.apache.spark.util.SizeEstimator + +class PrimitiveVectorSuite extends FunSuite { + + test("primitive value") { + val vector = new PrimitiveVector[Int] + + for (i <- 0 until 1000) { + vector += i + assert(vector(i) === i) + } + + assert(vector.size === 1000) + assert(vector.size == vector.length) + intercept[IllegalArgumentException] { + vector(1000) + } + + for (i <- 0 until 1000) { + assert(vector(i) == i) + } + } + + test("non-primitive value") { + val vector = new PrimitiveVector[String] + + for (i <- 0 until 1000) { + vector += i.toString + assert(vector(i) === i.toString) + } + + assert(vector.size === 1000) + assert(vector.size == vector.length) + intercept[IllegalArgumentException] { + vector(1000) + } + + for (i <- 0 until 1000) { + assert(vector(i) == i.toString) + } + } + + test("ideal growth") { + val vector = new PrimitiveVector[Long](initialSize = 1) + vector += 1 + for (i <- 1 until 1024) { + vector += i + assert(vector.size === i + 1) + assert(vector.capacity === Integer.highestOneBit(i) * 2) + } + assert(vector.capacity === 1024) + vector += 1024 + assert(vector.capacity === 2048) + } + + test("ideal size") { + val vector = new PrimitiveVector[Long](8192) + for (i <- 0 until 8192) { + vector += i + } + assert(vector.size === 8192) + assert(vector.capacity === 8192) + val actualSize = SizeEstimator.estimate(vector) + val expectedSize = 8192 * 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + // Due to specialization wonkiness, we need to ensure we don't have 2 copies of the array. + assert(actualSize < expectedSize * 1.1) + } + + test("resizing") { + val vector = new PrimitiveVector[Long] + for (i <- 0 until 4097) { + vector += i + } + assert(vector.size === 4097) + assert(vector.capacity === 8192) + vector.trim() + assert(vector.size === 4097) + assert(vector.capacity === 4097) + vector.resize(5000) + assert(vector.size === 4097) + assert(vector.capacity === 5000) + vector.resize(4000) + assert(vector.size === 4000) + assert(vector.capacity === 4000) + vector.resize(5000) + assert(vector.size === 4000) + assert(vector.capacity === 5000) + for (i <- 0 until 4000) { + assert(vector(i) == i) + } + intercept[IllegalArgumentException] { + vector(4000) + } + } +} -- cgit v1.2.3 From eda05fa43953f601d14853f3416e99e012a1bbba Mon Sep 17 00:00:00 2001 From: "shiyun.wxm" Date: Mon, 18 Nov 2013 13:31:14 +0800 Subject: use HashSet.empty[Long] instead of Seq[Long] --- core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala index 26245a6540..e596690bc3 100644 --- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsUI.scala @@ -107,7 +107,7 @@ private[spark] class ExecutorsUI(val sc: SparkContext) { val memUsed = status.memUsed().toString val maxMem = status.maxMem.toString val diskUsed = status.diskUsed().toString - val activeTasks = listener.executorToTasksActive.getOrElse(execId, Seq[Long]()).size + val activeTasks = listener.executorToTasksActive.getOrElse(execId, HashSet.empty[Long]).size val failedTasks = listener.executorToTasksFailed.getOrElse(execId, 0) val completedTasks = listener.executorToTasksComplete.getOrElse(execId, 0) val totalTasks = activeTasks + failedTasks + completedTasks -- cgit v1.2.3 From 1360f62d15170bd295ceaba85f39401fd8109e51 Mon Sep 17 00:00:00 2001 From: Russell Cardullo Date: Mon, 18 Nov 2013 08:37:09 -0800 Subject: Cleanup GraphiteSink.scala based on feedback * Reorder imports according to the style guide * Consistently use propertyToOption in all places --- .../scala/org/apache/spark/metrics/sink/GraphiteSink.scala | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala index eb1315e6de..cdcfec8ca7 100644 --- a/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala +++ b/core/src/main/scala/org/apache/spark/metrics/sink/GraphiteSink.scala @@ -17,13 +17,13 @@ package org.apache.spark.metrics.sink -import com.codahale.metrics.MetricRegistry -import com.codahale.metrics.graphite.{GraphiteReporter, Graphite} - import java.util.Properties import java.util.concurrent.TimeUnit import java.net.InetSocketAddress +import com.codahale.metrics.MetricRegistry +import com.codahale.metrics.graphite.{GraphiteReporter, Graphite} + import org.apache.spark.metrics.MetricsSystem class GraphiteSink(val property: Properties, val registry: MetricRegistry) extends Sink { @@ -50,12 +50,12 @@ class GraphiteSink(val property: Properties, val registry: MetricRegistry) exten val host = propertyToOption(GRAPHITE_KEY_HOST).get val port = propertyToOption(GRAPHITE_KEY_PORT).get.toInt - val pollPeriod = Option(property.getProperty(GRAPHITE_KEY_PERIOD)) match { + val pollPeriod = propertyToOption(GRAPHITE_KEY_PERIOD) match { case Some(s) => s.toInt case None => GRAPHITE_DEFAULT_PERIOD } - val pollUnit = Option(property.getProperty(GRAPHITE_KEY_UNIT)) match { + val pollUnit = propertyToOption(GRAPHITE_KEY_UNIT) match { case Some(s) => TimeUnit.valueOf(s.toUpperCase()) case None => TimeUnit.valueOf(GRAPHITE_DEFAULT_UNIT) } -- cgit v1.2.3 From 09bdfe3b163559fdcf8771b52ffbe2542883c912 Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Mon, 18 Nov 2013 15:21:43 -0500 Subject: XORShift RNG with unit tests and benchmark To run unit test, start SBT console and type: compile test-only org.apache.spark.util.XORShiftRandomSuite To run benchmark, type: project core console Once the Scala console starts, type: org.apache.spark.util.XORShiftRandom.benchmark(100000000) --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 2 +- .../main/scala/org/apache/spark/util/Utils.scala | 35 +++++++++- .../org/apache/spark/util/XORShiftRandom.scala | 63 ++++++++++++++++++ .../apache/spark/util/XORShiftRandomSuite.scala | 76 ++++++++++++++++++++++ .../org/apache/spark/mllib/clustering/KMeans.scala | 2 +- 5 files changed, 175 insertions(+), 3 deletions(-) create mode 100644 core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala create mode 100644 core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6e88be6f6a..dd9c32f253 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -17,7 +17,7 @@ package org.apache.spark.rdd -import java.util.Random +import org.apache.spark.util.{XORShiftRandom => Random} import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fe932d8ede..2df7108d31 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -818,9 +818,42 @@ private[spark] object Utils extends Logging { hashAbs } - /** Returns a copy of the system properties that is thread-safe to iterator over. */ + /* Returns a copy of the system properties that is thread-safe to iterator over. */ def getSystemProperties(): Map[String, String] = { return System.getProperties().clone() .asInstanceOf[java.util.Properties].toMap[String, String] } + + /* Used for performance tersting along with the intToTimesInt() and timeIt methods + * It uses a while loop instead of a for comprehension since the JIT will + * optimize the while loop better than the "for" closure + * e.g. + * import org.apache.spark.util.Utils.{TimesInt, intToTimesInt, timeIt} + * import java.util.Random + * val rand = new Random() + * timeIt(rand.nextDouble, 10000000) + */ + class TimesInt(i: Int) { + def times(f: => Unit) = { + var x = 1 + while (x <= i) { + f + x += 1 + } + } + } + + /* Used in conjunction with TimesInt since it's Scala 2.9.3 + * instead of 2.10 and we don't have implicit classes */ + implicit def intToTimesInt(i: Int) = new TimesInt(i) + + /* See TimesInt for use example */ + def timeIt(f: => Unit, iters: Int): Long = { + + val start = System.currentTimeMillis + iters.times(f) + System.currentTimeMillis - start + + } + } diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala new file mode 100644 index 0000000000..3c189c1b69 --- /dev/null +++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala @@ -0,0 +1,63 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.{Random => JavaRandom} +import Utils.{TimesInt, intToTimesInt, timeIt} + +class XORShiftRandom(init: Long) extends JavaRandom(init) { + + def this() = this(System.nanoTime) + + var seed = init + + // we need to just override next - this will be called by nextInt, nextDouble, + // nextGaussian, nextLong, etc. + override protected def next(bits: Int): Int = { + + var nextSeed = seed ^ (seed << 21) + nextSeed ^= (nextSeed >>> 35) + nextSeed ^= (nextSeed << 4) + seed = nextSeed + (nextSeed & ((1L << bits) -1)).asInstanceOf[Int] + } +} + +object XORShiftRandom { + + def benchmark(numIters: Int) = { + + val seed = 1L + val million = 1e6.toInt + val javaRand = new JavaRandom(seed) + val xorRand = new XORShiftRandom(seed) + + // warm up the JIT + million.times { + javaRand.nextInt + xorRand.nextInt + } + + /* Return results as a map instead of just printing to screen + in case the user wants to do something with them */ + Map("javaTime" -> timeIt(javaRand.nextInt, numIters), + "xorTime" -> timeIt(xorRand.nextInt, numIters)) + + } + +} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala new file mode 100644 index 0000000000..1691cb4f01 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala @@ -0,0 +1,76 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.util + +import java.util.Random +import org.scalatest.FlatSpec +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.Utils.{TimesInt, intToTimesInt, timeIt} + +class XORShiftRandomSuite extends FunSuite with ShouldMatchers { + + def fixture = new { + val seed = 1L + val xorRand = new XORShiftRandom(seed) + val hundMil = 1e8.toInt + } + + /* + * This test is based on a chi-squared test for randomness. The values are hard-coded + * so as not to create Spark's dependency on apache.commons.math3 just to call one + * method for calculating the exact p-value for a given number of random numbers + * and bins. In case one would want to move to a full-fledged test based on + * apache.commons.math3, the relevant class is here: + * org.apache.commons.math3.stat.inference.ChiSquareTest + */ + test ("XORShift generates valid random numbers") { + + val f = fixture + + val numBins = 10 + // create 10 bins + val bins = Array.fill(numBins)(0) + + // populate bins based on modulus of the random number + f.hundMil.times(bins(math.abs(f.xorRand.nextInt) % 10) += 1) + + /* since the seed is deterministic, until the algorithm is changed, we know the result will be + * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, + * 10000790, 10002286, 9998699), so the test will never fail at the prespecified (5%) + * significance level. However, should the RNG implementation change, the test should still + * pass at the same significance level. The chi-squared test done in R gave the following + * results: + * > chisq.test(c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, + * 10000790, 10002286, 9998699)) + * Chi-squared test for given probabilities + * data: c(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, 10000790, + * 10002286, 9998699) + * X-squared = 11.975, df = 9, p-value = 0.2147 + * Note that the p-value was ~0.22. The test will fail if alpha < 0.05, which for 100 million + * random numbers + * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared + * is greater than or equal to that number. + */ + val binSize = f.hundMil/numBins + val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum + xSquared should be < (16.9196) + + } + +} \ No newline at end of file diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index edbf77dbcc..56bcb6c82a 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -18,7 +18,7 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer -import scala.util.Random +import org.apache.spark.util.{XORShiftRandom => Random} import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ -- cgit v1.2.3 From 99cfe89c688ee1499d2723d8ea909651995abe86 Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Mon, 18 Nov 2013 22:00:36 -0500 Subject: Updates to reflect pull request code review --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 +- .../main/scala/org/apache/spark/util/Utils.scala | 43 +++++++---------- .../org/apache/spark/util/XORShiftRandom.scala | 55 +++++++++++++++++----- .../apache/spark/util/XORShiftRandomSuite.scala | 10 ++-- .../org/apache/spark/mllib/clustering/KMeans.scala | 5 +- 5 files changed, 69 insertions(+), 48 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index dd9c32f253..e738bfbdc2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -17,8 +17,6 @@ package org.apache.spark.rdd -import org.apache.spark.util.{XORShiftRandom => Random} - import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer @@ -38,7 +36,7 @@ import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{Utils, BoundedPriorityQueue} +import org.apache.spark.util.{Utils, BoundedPriorityQueue, XORShiftRandom => Random} import org.apache.spark.SparkContext._ import org.apache.spark._ diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index 2df7108d31..b98a81053d 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -818,42 +818,33 @@ private[spark] object Utils extends Logging { hashAbs } - /* Returns a copy of the system properties that is thread-safe to iterator over. */ + /** Returns a copy of the system properties that is thread-safe to iterator over. */ def getSystemProperties(): Map[String, String] = { return System.getProperties().clone() .asInstanceOf[java.util.Properties].toMap[String, String] } - /* Used for performance tersting along with the intToTimesInt() and timeIt methods - * It uses a while loop instead of a for comprehension since the JIT will - * optimize the while loop better than the "for" closure - * e.g. - * import org.apache.spark.util.Utils.{TimesInt, intToTimesInt, timeIt} - * import java.util.Random - * val rand = new Random() - * timeIt(rand.nextDouble, 10000000) + /** + * Method executed for repeating a task for side effects. + * Unlike a for comprehension, it permits JVM JIT optimization */ - class TimesInt(i: Int) { - def times(f: => Unit) = { - var x = 1 - while (x <= i) { - f - x += 1 + def times(numIters: Int)(f: => Unit): Unit = { + var i = 0 + while (i < numIters) { + f + i += 1 } - } } - - /* Used in conjunction with TimesInt since it's Scala 2.9.3 - * instead of 2.10 and we don't have implicit classes */ - implicit def intToTimesInt(i: Int) = new TimesInt(i) - - /* See TimesInt for use example */ - def timeIt(f: => Unit, iters: Int): Long = { + /** + * Timing method based on iterations that permit JVM JIT optimization. + * @param numIters number of iterations + * @param f function to be executed + */ + def timeIt(numIters: Int)(f: => Unit): Long = { val start = System.currentTimeMillis - iters.times(f) + times(numIters)(f) System.currentTimeMillis - start - } - + } diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala index 3c189c1b69..d443595c24 100644 --- a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala @@ -18,18 +18,28 @@ package org.apache.spark.util import java.util.{Random => JavaRandom} -import Utils.{TimesInt, intToTimesInt, timeIt} +import org.apache.spark.util.Utils.timeIt +/** + * This class implements a XORShift random number generator algorithm + * Source: + * Marsaglia, G. (2003). Xorshift RNGs. Journal of Statistical Software, Vol. 8, Issue 14. + * @see Paper + * This implementation is approximately 3.5 times faster than + * {@link java.util.Random java.util.Random}, partly because of the algorithm, but also due + * to renouncing thread safety. JDK's implementation uses an AtomicLong seed, this class + * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG + * for each thread. + */ class XORShiftRandom(init: Long) extends JavaRandom(init) { def this() = this(System.nanoTime) - var seed = init + private var seed = init // we need to just override next - this will be called by nextInt, nextDouble, // nextGaussian, nextLong, etc. - override protected def next(bits: Int): Int = { - + override protected def next(bits: Int): Int = { var nextSeed = seed ^ (seed << 21) nextSeed ^= (nextSeed >>> 35) nextSeed ^= (nextSeed << 4) @@ -38,25 +48,46 @@ class XORShiftRandom(init: Long) extends JavaRandom(init) { } } +/** Contains benchmark method and main method to run benchmark of the RNG */ object XORShiftRandom { + /** + * Main method for running benchmark + * @param args takes one argument - the number of random numbers to generate + */ + def main(args: Array[String]): Unit = { + if (args.length != 1) { + println("Benchmark of XORShiftRandom vis-a-vis java.util.Random") + println("Usage: XORShiftRandom number_of_random_numbers_to_generate") + System.exit(1) + } + println(benchmark(args(0).toInt)) + } + + /** + * @param numIters Number of random numbers to generate while running the benchmark + * @return Map of execution times for {@link java.util.Random java.util.Random} + * and XORShift + */ def benchmark(numIters: Int) = { val seed = 1L val million = 1e6.toInt val javaRand = new JavaRandom(seed) val xorRand = new XORShiftRandom(seed) - - // warm up the JIT - million.times { - javaRand.nextInt - xorRand.nextInt + + // this is just to warm up the JIT - we're not timing anything + timeIt(1e6.toInt) { + javaRand.nextInt() + xorRand.nextInt() } + val iters = timeIt(numIters)(_) + /* Return results as a map instead of just printing to screen - in case the user wants to do something with them */ - Map("javaTime" -> timeIt(javaRand.nextInt, numIters), - "xorTime" -> timeIt(xorRand.nextInt, numIters)) + in case the user wants to do something with them */ + Map("javaTime" -> iters {javaRand.nextInt()}, + "xorTime" -> iters {xorRand.nextInt()}) } diff --git a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala index 1691cb4f01..b78367b6ca 100644 --- a/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/XORShiftRandomSuite.scala @@ -21,7 +21,7 @@ import java.util.Random import org.scalatest.FlatSpec import org.scalatest.FunSuite import org.scalatest.matchers.ShouldMatchers -import org.apache.spark.util.Utils.{TimesInt, intToTimesInt, timeIt} +import org.apache.spark.util.Utils.times class XORShiftRandomSuite extends FunSuite with ShouldMatchers { @@ -48,7 +48,7 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers { val bins = Array.fill(numBins)(0) // populate bins based on modulus of the random number - f.hundMil.times(bins(math.abs(f.xorRand.nextInt) % 10) += 1) + times(f.hundMil) {bins(math.abs(f.xorRand.nextInt) % 10) += 1} /* since the seed is deterministic, until the algorithm is changed, we know the result will be * exactly this: Array(10004908, 9993136, 9994600, 10000744, 10000091, 10002474, 10002272, @@ -67,9 +67,9 @@ class XORShiftRandomSuite extends FunSuite with ShouldMatchers { * and 10 bins will happen at X-squared of ~16.9196. So, the test will fail if X-squared * is greater than or equal to that number. */ - val binSize = f.hundMil/numBins - val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum - xSquared should be < (16.9196) + val binSize = f.hundMil/numBins + val xSquared = bins.map(x => math.pow((binSize - x), 2)/binSize).sum + xSquared should be < (16.9196) } diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index 56bcb6c82a..f09ea9e2f7 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -18,15 +18,16 @@ package org.apache.spark.mllib.clustering import scala.collection.mutable.ArrayBuffer -import org.apache.spark.util.{XORShiftRandom => Random} + +import org.jblas.DoubleMatrix import org.apache.spark.SparkContext import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.util.MLUtils +import org.apache.spark.util.{XORShiftRandom => Random} -import org.jblas.DoubleMatrix /** -- cgit v1.2.3 From 7de180fd13fda2e5d4486dfca9e2a9997ec7f4d0 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 18 Nov 2013 20:05:05 -0800 Subject: Remove explicit boxing --- core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index b002468442..70f7f01d2b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -169,9 +169,9 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav * If the RDD contains infinity, NaN throws an exception * If the elements in RDD do not vary (max == min) always returns a single bucket. */ - def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = { val result = srdd.histogram(bucketCount) - (result._1.map(scala.Double.box(_)), result._2) + (result._1, result._2) } /** -- cgit v1.2.3 From 13b9bf494b0d1d0e65dc357efe832763127aefd2 Mon Sep 17 00:00:00 2001 From: Matthew Taylor Date: Mon, 18 Nov 2013 06:41:21 +0000 Subject: PartitionPruningRDD is using index from parent --- .../org/apache/spark/rdd/PartitionPruningRDD.scala | 6 +- .../apache/spark/PartitionPruningRDDSuite.scala | 70 ++++++++++++++++++---- 2 files changed, 63 insertions(+), 13 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 165cd412fc..2738a00894 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -34,10 +34,12 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo @transient val partitions: Array[Partition] = rdd.partitions.zipWithIndex - .filter(s => partitionFilterFunc(s._2)) + .filter(s => partitionFilterFunc(s._2)).map(_._1).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } - override def getParents(partitionId: Int) = List(partitions(partitionId).index) + override def getParents(partitionId: Int) = { + List(partitions(partitionId).asInstanceOf[PartitionPruningRDDPartition].parentSplit.index) + } } diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala index 21f16ef2c6..28e71e835f 100644 --- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala @@ -19,27 +19,75 @@ package org.apache.spark import org.scalatest.FunSuite import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{RDD, PartitionPruningRDD} +import org.apache.spark.rdd.{PartitionPruningRDDPartition, RDD, PartitionPruningRDD} class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { + test("Pruned Partitions inherit locality prefs correctly") { - class TestPartition(i: Int) extends Partition { - def index = i - } + val rdd = new RDD[Int](sc, Nil) { override protected def getPartitions = { Array[Partition]( - new TestPartition(1), - new TestPartition(2), - new TestPartition(3)) + new TestPartition(0, 1), + new TestPartition(1, 1), + new TestPartition(2, 1)) + } + + def compute(split: Partition, context: TaskContext) = { + Iterator() } - def compute(split: Partition, context: TaskContext) = {Iterator()} } - val prunedRDD = PartitionPruningRDD.create(rdd, {x => if (x==2) true else false}) - val p = prunedRDD.partitions(0) - assert(p.index == 2) + val prunedRDD = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) assert(prunedRDD.partitions.length == 1) + val p = prunedRDD.partitions(0) + assert(p.index == 0) + assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2) } + + + test("Pruned Partitions can be merged ") { + + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(0, 4), + new TestPartition(1, 5), + new TestPartition(2, 6)) + } + + def compute(split: Partition, context: TaskContext) = { + List(split.asInstanceOf[TestPartition].testValue).iterator + } + } + val prunedRDD1 = PartitionPruningRDD.create(rdd, { + x => if (x == 0) true else false + }) + + val prunedRDD2 = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) + + val merged = prunedRDD1 ++ prunedRDD2 + + assert(merged.count() == 2) + val take = merged.take(2) + + assert(take.apply(0) == 4) + + assert(take.apply(1) == 6) + + + } + } + +class TestPartition(i: Int, value: Int) extends Partition with Serializable { + def index = i + + def testValue = this.value + +} \ No newline at end of file -- cgit v1.2.3 From f639b65eabcc8666b74af8f13a37c5fdf7e0185f Mon Sep 17 00:00:00 2001 From: Matthew Taylor Date: Tue, 19 Nov 2013 10:48:48 +0000 Subject: PartitionPruningRDD is using index from parent(review changes) --- .../org/apache/spark/rdd/PartitionPruningRDD.scala | 4 +- .../apache/spark/PartitionPruningRDDSuite.scala | 93 ---------------------- .../spark/rdd/PartitionPruningRDDSuite.scala | 86 ++++++++++++++++++++ 3 files changed, 88 insertions(+), 95 deletions(-) delete mode 100644 core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala create mode 100644 core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala index 2738a00894..574dd4233f 100644 --- a/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/PartitionPruningRDD.scala @@ -33,8 +33,8 @@ class PruneDependency[T](rdd: RDD[T], @transient partitionFilterFunc: Int => Boo extends NarrowDependency[T](rdd) { @transient - val partitions: Array[Partition] = rdd.partitions.zipWithIndex - .filter(s => partitionFilterFunc(s._2)).map(_._1).zipWithIndex + val partitions: Array[Partition] = rdd.partitions + .filter(s => partitionFilterFunc(s.index)).zipWithIndex .map { case(split, idx) => new PartitionPruningRDDPartition(idx, split) : Partition } override def getParents(partitionId: Int) = { diff --git a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala deleted file mode 100644 index 28e71e835f..0000000000 --- a/core/src/test/scala/org/apache/spark/PartitionPruningRDDSuite.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark - -import org.scalatest.FunSuite -import org.apache.spark.SparkContext._ -import org.apache.spark.rdd.{PartitionPruningRDDPartition, RDD, PartitionPruningRDD} - - -class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { - - - test("Pruned Partitions inherit locality prefs correctly") { - - val rdd = new RDD[Int](sc, Nil) { - override protected def getPartitions = { - Array[Partition]( - new TestPartition(0, 1), - new TestPartition(1, 1), - new TestPartition(2, 1)) - } - - def compute(split: Partition, context: TaskContext) = { - Iterator() - } - } - val prunedRDD = PartitionPruningRDD.create(rdd, { - x => if (x == 2) true else false - }) - assert(prunedRDD.partitions.length == 1) - val p = prunedRDD.partitions(0) - assert(p.index == 0) - assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2) - } - - - test("Pruned Partitions can be merged ") { - - val rdd = new RDD[Int](sc, Nil) { - override protected def getPartitions = { - Array[Partition]( - new TestPartition(0, 4), - new TestPartition(1, 5), - new TestPartition(2, 6)) - } - - def compute(split: Partition, context: TaskContext) = { - List(split.asInstanceOf[TestPartition].testValue).iterator - } - } - val prunedRDD1 = PartitionPruningRDD.create(rdd, { - x => if (x == 0) true else false - }) - - val prunedRDD2 = PartitionPruningRDD.create(rdd, { - x => if (x == 2) true else false - }) - - val merged = prunedRDD1 ++ prunedRDD2 - - assert(merged.count() == 2) - val take = merged.take(2) - - assert(take.apply(0) == 4) - - assert(take.apply(1) == 6) - - - } - -} - -class TestPartition(i: Int, value: Int) extends Partition with Serializable { - def index = i - - def testValue = this.value - -} \ No newline at end of file diff --git a/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala new file mode 100644 index 0000000000..53a7b7c44d --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/PartitionPruningRDDSuite.scala @@ -0,0 +1,86 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.rdd + +import org.scalatest.FunSuite +import org.apache.spark.{TaskContext, Partition, SharedSparkContext} + + +class PartitionPruningRDDSuite extends FunSuite with SharedSparkContext { + + + test("Pruned Partitions inherit locality prefs correctly") { + + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(0, 1), + new TestPartition(1, 1), + new TestPartition(2, 1)) + } + + def compute(split: Partition, context: TaskContext) = { + Iterator() + } + } + val prunedRDD = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) + assert(prunedRDD.partitions.length == 1) + val p = prunedRDD.partitions(0) + assert(p.index == 0) + assert(p.asInstanceOf[PartitionPruningRDDPartition].parentSplit.index == 2) + } + + + test("Pruned Partitions can be unioned ") { + + val rdd = new RDD[Int](sc, Nil) { + override protected def getPartitions = { + Array[Partition]( + new TestPartition(0, 4), + new TestPartition(1, 5), + new TestPartition(2, 6)) + } + + def compute(split: Partition, context: TaskContext) = { + List(split.asInstanceOf[TestPartition].testValue).iterator + } + } + val prunedRDD1 = PartitionPruningRDD.create(rdd, { + x => if (x == 0) true else false + }) + + val prunedRDD2 = PartitionPruningRDD.create(rdd, { + x => if (x == 2) true else false + }) + + val merged = prunedRDD1 ++ prunedRDD2 + assert(merged.count() == 2) + val take = merged.take(2) + assert(take.apply(0) == 4) + assert(take.apply(1) == 6) + } +} + +class TestPartition(i: Int, value: Int) extends Partition with Serializable { + def index = i + + def testValue = this.value + +} -- cgit v1.2.3 From 9c934b640f76b17097f2cae87fef30b05ce854b7 Mon Sep 17 00:00:00 2001 From: Henry Saputra Date: Tue, 19 Nov 2013 10:19:03 -0800 Subject: Remove the semicolons at the end of Scala code to make it more pure Scala code. Also remove unused imports as I found them along the way. Remove return statements when returning value in the Scala code. Passing compile and tests. --- .../apache/spark/deploy/FaultToleranceTest.scala | 28 +++++++++++----------- .../scala/org/apache/spark/rdd/CartesianRDD.scala | 2 +- .../scala/org/apache/spark/LocalSparkContext.scala | 2 +- .../scala/org/apache/spark/PartitioningSuite.scala | 10 ++++---- .../scala/org/apache/spark/examples/LocalALS.scala | 2 +- .../scala/org/apache/spark/examples/SparkTC.scala | 2 +- .../spark/streaming/examples/ActorWordCount.scala | 2 +- .../spark/streaming/examples/MQTTWordCount.scala | 4 ++-- .../org/apache/spark/streaming/Checkpoint.scala | 2 +- .../streaming/api/java/JavaStreamingContext.scala | 7 +++--- .../streaming/dstream/FlumeInputDStream.scala | 4 ++-- .../apache/spark/streaming/InputStreamsSuite.scala | 4 ++-- .../org/apache/spark/streaming/TestSuiteBase.scala | 2 +- .../spark/deploy/yarn/ApplicationMaster.scala | 4 +++- .../org/apache/spark/deploy/yarn/Client.scala | 27 ++++++++++----------- .../yarn/ClientDistributedCacheManager.scala | 2 +- .../apache/spark/deploy/yarn/WorkerRunnable.scala | 9 ++++--- .../spark/deploy/yarn/YarnSparkHadoopUtil.scala | 5 +--- 18 files changed, 56 insertions(+), 62 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala index 668032a3a2..0aa8852649 100644 --- a/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala +++ b/core/src/main/scala/org/apache/spark/deploy/FaultToleranceTest.scala @@ -1,19 +1,19 @@ /* * - * * Licensed to the Apache Software Foundation (ASF) under one or more - * * contributor license agreements. See the NOTICE file distributed with - * * this work for additional information regarding copyright ownership. - * * The ASF licenses this file to You under the Apache License, Version 2.0 - * * (the "License"); you may not use this file except in compliance with - * * the License. You may obtain a copy of the License at - * * - * * http://www.apache.org/licenses/LICENSE-2.0 - * * - * * Unless required by applicable law or agreed to in writing, software - * * distributed under the License is distributed on an "AS IS" BASIS, - * * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * * See the License for the specific language governing permissions and - * * limitations under the License. + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. * */ diff --git a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala index 9b0c882481..0de22f0e06 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CartesianRDD.scala @@ -70,7 +70,7 @@ class CartesianRDD[T: ClassManifest, U:ClassManifest]( override def compute(split: Partition, context: TaskContext) = { val currSplit = split.asInstanceOf[CartesianPartition] for (x <- rdd1.iterator(currSplit.s1, context); - y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) + y <- rdd2.iterator(currSplit.s2, context)) yield (x, y) } override def getDependencies: Seq[Dependency[_]] = List( diff --git a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala index 459e257d79..8dd5786da6 100644 --- a/core/src/test/scala/org/apache/spark/LocalSparkContext.scala +++ b/core/src/test/scala/org/apache/spark/LocalSparkContext.scala @@ -30,7 +30,7 @@ trait LocalSparkContext extends BeforeAndAfterEach with BeforeAndAfterAll { self @transient var sc: SparkContext = _ override def beforeAll() { - InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()); + InternalLoggerFactory.setDefaultFactory(new Slf4JLoggerFactory()) super.beforeAll() } diff --git a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala index 7d938917f2..1374d01774 100644 --- a/core/src/test/scala/org/apache/spark/PartitioningSuite.scala +++ b/core/src/test/scala/org/apache/spark/PartitioningSuite.scala @@ -142,11 +142,11 @@ class PartitioningSuite extends FunSuite with SharedSparkContext { .filter(_ >= 0.0) // Run the partitions, including the consecutive empty ones, through StatCounter - val stats: StatCounter = rdd.stats(); - assert(abs(6.0 - stats.sum) < 0.01); - assert(abs(6.0/2 - rdd.mean) < 0.01); - assert(abs(1.0 - rdd.variance) < 0.01); - assert(abs(1.0 - rdd.stdev) < 0.01); + val stats: StatCounter = rdd.stats() + assert(abs(6.0 - stats.sum) < 0.01) + assert(abs(6.0/2 - rdd.mean) < 0.01) + assert(abs(1.0 - rdd.variance) < 0.01) + assert(abs(1.0 - rdd.stdev) < 0.01) // Add other tests here for classes that should be able to handle empty partitions correctly } diff --git a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala index 4af45b2b4a..83db8b9e26 100644 --- a/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala +++ b/examples/src/main/scala/org/apache/spark/examples/LocalALS.scala @@ -120,7 +120,7 @@ object LocalALS { System.exit(1) } } - printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS); + printf("Running with M=%d, U=%d, F=%d, iters=%d\n", M, U, F, ITERATIONS) val R = generateR() diff --git a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala index 5a7a9d1bd8..8543ce0e32 100644 --- a/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala +++ b/examples/src/main/scala/org/apache/spark/examples/SparkTC.scala @@ -65,7 +65,7 @@ object SparkTC { oldCount = nextCount // Perform the join, obtaining an RDD of (y, (z, x)) pairs, // then project the result to obtain the new (x, z) paths. - tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache(); + tc = tc.union(tc.join(edges).map(x => (x._2._2, x._2._1))).distinct().cache() nextCount = tc.count() } while (nextCount != oldCount) diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala index cd3423a07b..af52b7e9a1 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/ActorWordCount.scala @@ -120,7 +120,7 @@ object FeederActor { println("Feeder started as:" + feeder) - actorSystem.awaitTermination(); + actorSystem.awaitTermination() } } diff --git a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala index af698a01d5..ff332a0282 100644 --- a/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala +++ b/examples/src/main/scala/org/apache/spark/streaming/examples/MQTTWordCount.scala @@ -54,12 +54,12 @@ object MQTTPublisher { client.connect() - val msgtopic: MqttTopic = client.getTopic(topic); + val msgtopic: MqttTopic = client.getTopic(topic) val msg: String = "hello mqtt demo for spark streaming" while (true) { val message: MqttMessage = new MqttMessage(String.valueOf(msg).getBytes()) - msgtopic.publish(message); + msgtopic.publish(message) println("Published data. topic: " + msgtopic.getName() + " Message: " + message) } client.disconnect() diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index bb9febad38..78a2c07204 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -94,7 +94,7 @@ class CheckpointWriter(checkpointDir: String) extends Logging { fs.delete(file, false) fs.rename(writeFile, file) - val finishTime = System.currentTimeMillis(); + val finishTime = System.currentTimeMillis() logInfo("Checkpoint for time " + checkpointTime + " saved to file '" + file + "', took " + bytes.length + " bytes and " + (finishTime - startTime) + " milliseconds") return diff --git a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala index cf30b541e1..7f9dab0ef9 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/api/java/JavaStreamingContext.scala @@ -17,7 +17,7 @@ package org.apache.spark.streaming.api.java -import java.lang.{Long => JLong, Integer => JInt} +import java.lang.{Integer => JInt} import java.io.InputStream import java.util.{Map => JMap, List => JList} @@ -33,10 +33,9 @@ import twitter4j.auth.Authorization import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.api.java.function.{Function => JFunction, Function2 => JFunction2} -import org.apache.spark.api.java.{JavaPairRDD, JavaRDDLike, JavaSparkContext, JavaRDD} +import org.apache.spark.api.java.{JavaPairRDD, JavaSparkContext, JavaRDD} import org.apache.spark.streaming._ import org.apache.spark.streaming.dstream._ -import org.apache.spark.streaming.receivers.{ActorReceiver, ReceiverSupervisorStrategy} /** * A StreamingContext is the main entry point for Spark Streaming functionality. Besides the basic @@ -311,7 +310,7 @@ class JavaStreamingContext(val ssc: StreamingContext) { implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[V]] implicit val cmf: ClassManifest[F] = implicitly[ClassManifest[AnyRef]].asInstanceOf[ClassManifest[F]] - ssc.fileStream[K, V, F](directory); + ssc.fileStream[K, V, F](directory) } /** diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala index 18de772946..a0189eca04 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/FlumeInputDStream.scala @@ -137,8 +137,8 @@ class FlumeReceiver( protected override def onStart() { val responder = new SpecificResponder( - classOf[AvroSourceProtocol], new FlumeEventServer(this)); - val server = new NettyServer(responder, new InetSocketAddress(host, port)); + classOf[AvroSourceProtocol], new FlumeEventServer(this)) + val server = new NettyServer(responder, new InetSocketAddress(host, port)) blockGenerator.start() server.start() logInfo("Flume receiver started") diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index a559db468a..7dc82decef 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -124,9 +124,9 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { val clock = ssc.scheduler.clock.asInstanceOf[ManualClock] val input = Seq(1, 2, 3, 4, 5) Thread.sleep(1000) - val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort)); + val transceiver = new NettyTransceiver(new InetSocketAddress("localhost", testPort)) val client = SpecificRequestor.getClient( - classOf[AvroSourceProtocol], transceiver); + classOf[AvroSourceProtocol], transceiver) for (i <- 0 until input.size) { val event = new AvroFlumeEvent diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index be140699c2..8c8c359e6e 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -251,7 +251,7 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { Thread.sleep(500) // Give some time for the forgetting old RDDs to complete } catch { - case e: Exception => e.printStackTrace(); throw e; + case e: Exception => {e.printStackTrace(); throw e} } finally { ssc.stop() } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 0e47bd7a10..3f6e151d89 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -21,6 +21,7 @@ import java.io.IOException import java.net.Socket import java.util.concurrent.CopyOnWriteArrayList import java.util.concurrent.atomic.{AtomicInteger, AtomicReference} + import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.hadoop.net.NetUtils @@ -33,6 +34,7 @@ import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} import org.apache.spark.{SparkContext, Logging} import org.apache.spark.util.Utils + import scala.collection.JavaConversions._ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging { @@ -63,7 +65,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30) appAttemptId = getApplicationAttemptId() - isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts; + isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts resourceManager = registerWithResourceManager() // Workaround until hadoop moves to something which has diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index c38bdd14ec..038de30de5 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -17,14 +17,13 @@ package org.apache.spark.deploy.yarn -import java.net.{InetAddress, InetSocketAddress, UnknownHostException, URI} +import java.net.{InetAddress, UnknownHostException, URI} import java.nio.ByteBuffer import org.apache.hadoop.conf.Configuration import org.apache.hadoop.fs.{FileContext, FileStatus, FileSystem, Path, FileUtil} import org.apache.hadoop.fs.permission.FsPermission import org.apache.hadoop.mapred.Master -import org.apache.hadoop.net.NetUtils import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ @@ -40,9 +39,7 @@ import scala.collection.mutable.HashMap import scala.collection.mutable.Map import scala.collection.JavaConversions._ -import org.apache.spark.Logging -import org.apache.spark.util.Utils -import org.apache.spark.deploy.SparkHadoopUtil +import org.apache.spark.Logging class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging { @@ -105,7 +102,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl // if we have requested more then the clusters max for a single resource then exit. if (args.workerMemory > maxMem) { - logError("the worker size is to large to run on this cluster " + args.workerMemory); + logError("the worker size is to large to run on this cluster " + args.workerMemory) System.exit(1) } val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD @@ -142,8 +139,8 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl var dstHost = dstUri.getHost() if ((srcHost != null) && (dstHost != null)) { try { - srcHost = InetAddress.getByName(srcHost).getCanonicalHostName(); - dstHost = InetAddress.getByName(dstHost).getCanonicalHostName(); + srcHost = InetAddress.getByName(srcHost).getCanonicalHostName() + dstHost = InetAddress.getByName(dstHost).getCanonicalHostName() } catch { case e: UnknownHostException => return false @@ -160,7 +157,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl if (srcUri.getPort() != dstUri.getPort()) { return false } - return true; + return true } /** @@ -172,13 +169,13 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl replication: Short, setPerms: Boolean = false): Path = { val fs = FileSystem.get(conf) - val remoteFs = originalPath.getFileSystem(conf); + val remoteFs = originalPath.getFileSystem(conf) var newPath = originalPath if (! compareFs(remoteFs, fs)) { newPath = new Path(dstDir, originalPath.getName()) logInfo("Uploading " + originalPath + " to " + newPath) - FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf); - fs.setReplication(newPath, replication); + FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf) + fs.setReplication(newPath, replication) if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION)) } // resolve any symlinks in the URI path so using a "current" symlink @@ -196,7 +193,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl // Add them as local resources to the AM val fs = FileSystem.get(conf) - val delegTokenRenewer = Master.getMasterPrincipal(conf); + val delegTokenRenewer = Master.getMasterPrincipal(conf) if (UserGroupInformation.isSecurityEnabled()) { if (delegTokenRenewer == null || delegTokenRenewer.length() == 0) { logError("Can't get Master Kerberos principal for use as renewer") @@ -208,7 +205,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl if (UserGroupInformation.isSecurityEnabled()) { val dstFs = dst.getFileSystem(conf) - dstFs.addDelegationTokens(delegTokenRenewer, credentials); + dstFs.addDelegationTokens(delegTokenRenewer, credentials) } val localResources = HashMap[String, LocalResource]() FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) @@ -273,7 +270,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } } - UserGroupInformation.getCurrentUser().addCredentials(credentials); + UserGroupInformation.getCurrentUser().addCredentials(credentials) return localResources } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 674c8f8112..268ab950e8 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -197,7 +197,7 @@ class ClientDistributedCacheManager() extends Logging { */ def checkPermissionOfOther(fs: FileSystem, path: Path, action: FsAction, statCache: Map[URI, FileStatus]): Boolean = { - val status = getFileStatus(fs, path.toUri(), statCache); + val status = getFileStatus(fs, path.toUri(), statCache) val perms = status.getPermission() val otherAction = perms.getOtherAction() if (otherAction.implies(action)) { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala index 7a66532254..fb966b4784 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import java.security.PrivilegedExceptionAction import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.fs.{FileStatus, FileSystem, Path} +import org.apache.hadoop.fs.Path import org.apache.hadoop.io.DataOutputBuffer import org.apache.hadoop.net.NetUtils import org.apache.hadoop.security.UserGroupInformation @@ -38,7 +38,6 @@ import scala.collection.JavaConversions._ import scala.collection.mutable.HashMap import org.apache.spark.Logging -import org.apache.spark.util.Utils class WorkerRunnable(container: Container, conf: Configuration, masterAddress: String, slaveId: String, hostname: String, workerMemory: Int, workerCores: Int) @@ -204,8 +203,8 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S // use doAs and remoteUser here so we can add the container token and not // pollute the current users credentials with all of the individual container tokens - val user = UserGroupInformation.createRemoteUser(container.getId().toString()); - val containerToken = container.getContainerToken(); + val user = UserGroupInformation.createRemoteUser(container.getId().toString()) + val containerToken = container.getContainerToken() if (containerToken != null) { user.addToken(ProtoUtils.convertFromProtoFormat(containerToken, cmAddress)) } @@ -217,7 +216,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S cmAddress, conf).asInstanceOf[ContainerManager] } }); - return proxy; + proxy } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala index ca2f1e2565..2ba2366ead 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnSparkHadoopUtil.scala @@ -18,13 +18,10 @@ package org.apache.spark.deploy.yarn import org.apache.spark.deploy.SparkHadoopUtil -import collection.mutable.HashMap import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.yarn.api.ApplicationConstants.Environment -import java.security.PrivilegedExceptionAction /** * Contains util methods to interact with Hadoop from spark. @@ -40,7 +37,7 @@ class YarnSparkHadoopUtil extends SparkHadoopUtil { // add any user credentials to the job conf which are necessary for running on a secure Hadoop cluster override def addCredentials(conf: JobConf) { - val jobCreds = conf.getCredentials(); + val jobCreds = conf.getCredentials() jobCreds.mergeAll(UserGroupInformation.getCurrentUser().getCredentials()) } } -- cgit v1.2.3 From 4093e9393aef95793f2d1d77fd0bbe80c8bb8d11 Mon Sep 17 00:00:00 2001 From: tgravescs Date: Tue, 19 Nov 2013 12:39:26 -0600 Subject: Impove Spark on Yarn Error handling --- .../cluster/CoarseGrainedSchedulerBackend.scala | 1 + .../scheduler/cluster/SimrSchedulerBackend.scala | 1 - docs/running-on-yarn.md | 2 ++ .../spark/deploy/yarn/ApplicationMaster.scala | 39 ++++++++++++++-------- .../org/apache/spark/deploy/yarn/Client.scala | 32 +++++++++++------- .../spark/deploy/yarn/YarnAllocationHandler.scala | 16 ++++++--- 6 files changed, 61 insertions(+), 30 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index a45bee536c..d0ba5bf55d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -199,6 +199,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac } override def stop() { + stopExecutors() try { if (driverActor != null) { val future = driverActor.ask(StopDriver)(timeout) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index 0ea35e2b7a..e000531a26 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -62,7 +62,6 @@ private[spark] class SimrSchedulerBackend( val conf = new Configuration() val fs = FileSystem.get(conf) fs.delete(new Path(driverFilePath), false) - super.stopExecutors() super.stop() } } diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 6fd1d0d150..4056e9c15d 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -37,6 +37,8 @@ System Properties: * 'spark.yarn.applicationMaster.waitTries', property to set the number of times the ApplicationMaster waits for the the spark master and then also the number of tries it waits for the Spark Context to be intialized. Default is 10. * 'spark.yarn.submit.file.replication', the HDFS replication level for the files uploaded into HDFS for the application. These include things like the spark jar, the app jar, and any distributed cache files/archives. * 'spark.yarn.preserve.staging.files', set to true to preserve the staged files(spark jar, app jar, distributed cache files) at the end of the job rather then delete them. +* 'spark.yarn.scheduler.heartbeat.interval-ms', the interval in ms in which the Spark application master heartbeats into the YARN ResourceManager. Default is 5 seconds. +* 'spark.yarn.max.worker.failures', the maximum number of worker failures before failing the application. Default is the number of workers requested times 2 with minimum of 3. # Launching Spark on YARN diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 0e47bd7a10..89b00415da 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -52,7 +52,9 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES, YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES) private var isLastAMRetry: Boolean = true - + // default to numWorkers * 2, with minimum of 3 + private val maxNumWorkerFailures = System.getProperty("spark.yarn.max.worker.failures", + math.max(args.numWorkers * 2, 3).toString()).toInt def run() { // setup the directories so things go to yarn approved directories rather @@ -225,12 +227,13 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e if (null != sparkContext) { uiAddress = sparkContext.ui.appUIAddress - this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, - sparkContext.preferredNodeLocationData) + this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, + appAttemptId, args, sparkContext.preferredNodeLocationData) } else { logWarning("Unable to retrieve sparkContext inspite of waiting for " + count * waitTime + - ", numTries = " + numTries) - this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args) + ", numTries = " + numTries) + this.yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, + appAttemptId, args) } } } finally { @@ -249,8 +252,11 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e while(yarnAllocator.getNumWorkersRunning < args.numWorkers && // If user thread exists, then quit ! userThread.isAlive) { - - this.yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) + if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of worker failures reached") + } + yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) ApplicationMaster.incrementAllocatorLoop(1) Thread.sleep(100) } @@ -266,21 +272,27 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) - // must be <= timeoutInterval/ 2. - // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. - // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. - val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + + // we want to be reasonably responsive without causing too many requests to RM. + val schedulerInterval = + System.getProperty("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + + // must be <= timeoutInterval / 2. + val interval = math.min(timeoutInterval / 2, schedulerInterval) launchReporterThread(interval) } } - // TODO: We might want to extend this to allocate more containers in case they die ! private def launchReporterThread(_sleepTime: Long): Thread = { val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime val t = new Thread { override def run() { while (userThread.isAlive) { + if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) { + finishApplicationMaster(FinalApplicationStatus.FAILED, + "max number of worker failures reached") + } val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning if (missingWorkerCount > 0) { logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers") @@ -319,7 +331,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e } */ - def finishApplicationMaster(status: FinalApplicationStatus) { + def finishApplicationMaster(status: FinalApplicationStatus, diagnostics: String = "") { synchronized { if (isFinished) { @@ -333,6 +345,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e .asInstanceOf[FinishApplicationMasterRequest] finishReq.setAppAttemptId(appAttemptId) finishReq.setFinishApplicationStatus(status) + finishReq.setDiagnostics(diagnostics) // set tracking url to empty since we don't have a history server finishReq.setTrackingUrl("") resourceManager.finishApplicationMaster(finishReq) diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index c38bdd14ec..1078d5b826 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -60,6 +60,8 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) def run() { + validateArgs() + init(yarnConf) start() logClusterResourceDetails() @@ -84,6 +86,23 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl System.exit(0) } + def validateArgs() = { + Map((System.getenv("SPARK_JAR") == null) -> "Error: You must set SPARK_JAR environment variable!", + (args.userJar == null) -> "Error: You must specify a user jar!", + (args.userClass == null) -> "Error: You must specify a user class!", + (args.numWorkers <= 0) -> "Error: You must specify atleast 1 worker!", + (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> + ("Error: AM memory size must be greater then: " + YarnAllocationHandler.MEMORY_OVERHEAD), + (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> + ("Error: Worker memory size must be greater then: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString())) + .foreach { case(cond, errStr) => + if (cond) { + logError(errStr) + args.printUsageAndExit(1) + } + } + } + def getAppStagingDir(appId: ApplicationId): String = { SPARK_STAGING + Path.SEPARATOR + appId.toString() + Path.SEPARATOR } @@ -97,7 +116,6 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl ", queueMaxCapacity=" + queueInfo.getMaximumCapacity + ", queueApplicationCount=" + queueInfo.getApplications.size + ", queueChildQueueCount=" + queueInfo.getChildQueues.size) } - def verifyClusterResources(app: GetNewApplicationResponse) = { val maxMem = app.getMaximumResourceCapability().getMemory() @@ -215,11 +233,6 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - if (System.getenv("SPARK_JAR") == null || args.userJar == null) { - logError("Error: You must set SPARK_JAR environment variable and specify a user jar!") - System.exit(1) - } - Map(Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar, Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF")) .foreach { case(destName, _localPath) => @@ -334,7 +347,6 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl JAVA_OPTS += " -Djava.io.tmpdir=" + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " " - // Commenting it out for now - so that people can refer to the properties if required. Remove it once cpuset version is pushed out. // The context is, default gc for server class machines end up using all cores to do gc - hence if there are multiple containers in same // node, spark gc effects all other containers performance (which can also be other spark containers) @@ -360,11 +372,6 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl javaCommand = Environment.JAVA_HOME.$() + "/bin/java" } - if (args.userClass == null) { - logError("Error: You must specify a user class!") - System.exit(1) - } - val commands = List[String](javaCommand + " -server " + JAVA_OPTS + @@ -442,6 +449,7 @@ object Client { System.setProperty("SPARK_YARN_MODE", "true") val args = new ClientArguments(argStrings) + new Client(args).run } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 25da9aa917..507a0743fd 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -72,9 +72,11 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM // Used to generate a unique id per worker private val workerIdCounter = new AtomicInteger() private val lastResponseId = new AtomicInteger() + private val numWorkersFailed = new AtomicInteger() def getNumWorkersRunning: Int = numWorkersRunning.intValue + def getNumWorkersFailed: Int = numWorkersFailed.intValue def isResourceConstraintSatisfied(container: Container): Boolean = { container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) @@ -253,8 +255,16 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM else { // simply decrement count - next iteration of ReporterThread will take care of allocating ! numWorkersRunning.decrementAndGet() - logInfo("Container completed ? nodeId: " + containerId + ", state " + completedContainer.getState + - " httpaddress: " + completedContainer.getDiagnostics) + logInfo("Container completed not by us ? nodeId: " + containerId + ", state " + completedContainer.getState + + " httpaddress: " + completedContainer.getDiagnostics + " exit status: " + completedContainer.getExitStatus()) + + // Hadoop 2.2.X added a ContainerExitStatus we should switch to use + // there are some exit status' we shouldn't necessarily count against us, but for + // now I think its ok as none of the containers are expected to exit + if (completedContainer.getExitStatus() != 0) { + logInfo("Container marked as failed: " + containerId) + numWorkersFailed.incrementAndGet() + } } allocatedHostToContainersMap.synchronized { @@ -378,8 +388,6 @@ private[yarn] class YarnAllocationHandler(val conf: Configuration, val resourceM val releasedContainerList = createReleasedContainerList() req.addAllReleases(releasedContainerList) - - if (numWorkers > 0) { logInfo("Allocating " + numWorkers + " worker containers with " + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + " of memory each.") } -- cgit v1.2.3 From 10be58f251b5e883295bd46383c0a9758555f8fc Mon Sep 17 00:00:00 2001 From: Henry Saputra Date: Tue, 19 Nov 2013 16:56:23 -0800 Subject: Another set of changes to remove unnecessary semicolon (;) from Scala code. Passed the sbt/sbt compile and test --- .../src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala | 2 +- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 2 +- streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala | 4 +++- .../main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala | 2 +- yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala | 2 +- .../org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala | 2 +- yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala | 4 ++-- 7 files changed, 10 insertions(+), 8 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index 481ff8c3e0..b1e1576dad 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -76,7 +76,7 @@ private[spark] object ShuffleCopier extends Logging { extends FileClientHandler with Logging { override def handle(ctx: ChannelHandlerContext, in: ByteBuf, header: FileHeader) { - logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)"); + logDebug("Received Block: " + header.blockId + " (" + header.fileLen + "B)") resultCollectCallBack(header.blockId, header.fileLen.toLong, in.readBytes(header.fileLen)) } diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index c1c7aa70e6..fbd822867f 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -133,7 +133,7 @@ private[spark] class StagePage(parent: JobProgressUI) { summary ++

Summary Metrics for {numCompleted} Completed Tasks

++
{summaryTable.getOrElse("No tasks have reported metrics yet.")}
++ -

Tasks

++ taskTable; +

Tasks

++ taskTable headerSparkPage(content, parent.sc, "Details for Stage %d".format(stageId), Stages) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index 78a2c07204..9271914eb5 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -124,7 +124,9 @@ class CheckpointWriter(checkpointDir: String) extends Logging { def stop() { synchronized { - if (stopped) return ; + if (stopped) { + return + } stopped = true } executor.shutdown() diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 3f6e151d89..997a6dc1ec 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -195,7 +195,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e successed = true } finally { logDebug("finishing main") - isLastAMRetry = true; + isLastAMRetry = true if (successed) { ApplicationMaster.this.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) } else { diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 038de30de5..49a8cfde81 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -351,7 +351,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } // Command for the ApplicationMaster - var javaCommand = "java"; + var javaCommand = "java" val javaHome = System.getenv("JAVA_HOME") if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) { javaCommand = Environment.JAVA_HOME.$() + "/bin/java" diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala index 268ab950e8..5f159b073f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientDistributedCacheManager.scala @@ -201,7 +201,7 @@ class ClientDistributedCacheManager() extends Logging { val perms = status.getPermission() val otherAction = perms.getOtherAction() if (otherAction.implies(action)) { - return true; + return true } return false } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala index fb966b4784..a4d6e1d87d 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala @@ -107,7 +107,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S credentials.writeTokenStorageToStream(dob) ctx.setContainerTokens(ByteBuffer.wrap(dob.getData())) - var javaCommand = "java"; + var javaCommand = "java" val javaHome = System.getenv("JAVA_HOME") if ((javaHome != null && !javaHome.isEmpty()) || env.isDefinedAt("JAVA_HOME")) { javaCommand = Environment.JAVA_HOME.$() + "/bin/java" @@ -215,7 +215,7 @@ class WorkerRunnable(container: Container, conf: Configuration, masterAddress: S return rpc.getProxy(classOf[ContainerManager], cmAddress, conf).asInstanceOf[ContainerManager] } - }); + }) proxy } -- cgit v1.2.3 From bcc6ed30bf7189ebf0226f212b4e39830b830b6e Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Tue, 19 Nov 2013 20:50:38 -0500 Subject: Formatting and scoping (private[spark]) updates --- core/src/main/scala/org/apache/spark/util/Utils.scala | 2 +- core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index b98a81053d..a79e64e810 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -833,7 +833,7 @@ private[spark] object Utils extends Logging { while (i < numIters) { f i += 1 - } + } } /** diff --git a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala index d443595c24..e9907e6c85 100644 --- a/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala +++ b/core/src/main/scala/org/apache/spark/util/XORShiftRandom.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.Utils.timeIt * uses a regular Long. We can forgo thread safety since we use a new instance of the RNG * for each thread. */ -class XORShiftRandom(init: Long) extends JavaRandom(init) { +private[spark] class XORShiftRandom(init: Long) extends JavaRandom(init) { def this() = this(System.nanoTime) @@ -49,7 +49,7 @@ class XORShiftRandom(init: Long) extends JavaRandom(init) { } /** Contains benchmark method and main method to run benchmark of the RNG */ -object XORShiftRandom { +private[spark] object XORShiftRandom { /** * Main method for running benchmark -- cgit v1.2.3 From 22724659db8d711492f58c90d530be2f4a5b3de9 Mon Sep 17 00:00:00 2001 From: Marek Kolodziej Date: Wed, 20 Nov 2013 07:03:36 -0500 Subject: Make XORShiftRandom explicit in KMeans and roll it back for RDD --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 +++- .../src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala | 8 ++++---- 2 files changed, 7 insertions(+), 5 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index e738bfbdc2..6e88be6f6a 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -17,6 +17,8 @@ package org.apache.spark.rdd +import java.util.Random + import scala.collection.Map import scala.collection.JavaConversions.mapAsScalaMap import scala.collection.mutable.ArrayBuffer @@ -36,7 +38,7 @@ import org.apache.spark.partial.CountEvaluator import org.apache.spark.partial.GroupedCountEvaluator import org.apache.spark.partial.PartialResult import org.apache.spark.storage.StorageLevel -import org.apache.spark.util.{Utils, BoundedPriorityQueue, XORShiftRandom => Random} +import org.apache.spark.util.{Utils, BoundedPriorityQueue} import org.apache.spark.SparkContext._ import org.apache.spark._ diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala index f09ea9e2f7..0dee9399a8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/KMeans.scala @@ -26,7 +26,7 @@ import org.apache.spark.SparkContext._ import org.apache.spark.rdd.RDD import org.apache.spark.Logging import org.apache.spark.mllib.util.MLUtils -import org.apache.spark.util.{XORShiftRandom => Random} +import org.apache.spark.util.XORShiftRandom @@ -196,7 +196,7 @@ class KMeans private ( */ private def initRandom(data: RDD[Array[Double]]): Array[ClusterCenters] = { // Sample all the cluster centers in one pass to avoid repeated scans - val sample = data.takeSample(true, runs * k, new Random().nextInt()).toSeq + val sample = data.takeSample(true, runs * k, new XORShiftRandom().nextInt()).toSeq Array.tabulate(runs)(r => sample.slice(r * k, (r + 1) * k).toArray) } @@ -211,7 +211,7 @@ class KMeans private ( */ private def initKMeansParallel(data: RDD[Array[Double]]): Array[ClusterCenters] = { // Initialize each run's center to a random point - val seed = new Random().nextInt() + val seed = new XORShiftRandom().nextInt() val sample = data.takeSample(true, runs, seed).toSeq val centers = Array.tabulate(runs)(r => ArrayBuffer(sample(r))) @@ -223,7 +223,7 @@ class KMeans private ( for (r <- 0 until runs) yield (r, KMeans.pointCost(centerArrays(r), point)) }.reduceByKey(_ + _).collectAsMap() val chosen = data.mapPartitionsWithIndex { (index, points) => - val rand = new Random(seed ^ (step << 16) ^ index) + val rand = new XORShiftRandom(seed ^ (step << 16) ^ index) for { p <- points r <- 0 until runs -- cgit v1.2.3 From fc78f67da2fd28744e8119e28f4bb8a29926b3ad Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Thu, 21 Nov 2013 16:54:23 -0800 Subject: Added logging of scheduler delays to UI --- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 33 ++++++++++++++++++++-- 1 file changed, 31 insertions(+), 2 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index fbd822867f..fc8c334cb5 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -60,11 +60,13 @@ private[spark] class StagePage(parent: JobProgressUI) { var activeTime = 0L listener.stageIdToTasksActive(stageId).foreach(activeTime += _.timeRunning(now)) + val finishedTasks = listener.stageIdToTaskInfos(stageId).filter(_._1.finished) + val summary =
  • - CPU time: + Total duration across all tasks: {parent.formatDuration(listener.stageIdToTime.getOrElse(stageId, 0L) + activeTime)}
  • {if (hasShuffleRead) @@ -104,6 +106,30 @@ private[spark] class StagePage(parent: JobProgressUI) { val serviceQuantiles = "Duration" +: Distribution(serviceTimes).get.getQuantiles().map( ms => parent.formatDuration(ms.toLong)) + val gettingResultTimes = validTasks.map{case (info, metrics, exception) => + if (info.gettingResultTime > 0) { + (info.finishTime - info.gettingResultTime).toDouble + } else { + 0.0 + } + } + val gettingResultQuantiles = ("Time spent fetching task results" +: + Distribution(gettingResultTimes).get.getQuantiles().map( + millis => parent.formatDuration(millis.toLong))) + // The scheduler delay includes the network delay to send the task to the worker + // machine and to send back the result (but not the time to fetch the task result, + // if it needed to be fetched from the block manager on the worker). + val schedulerDelays = validTasks.map{case (info, metrics, exception) => + if (info.gettingResultTime > 0) { + (info.gettingResultTime - info.launchTime).toDouble + } else { + (info.finishTime - info.launchTime).toDouble + } + } + val schedulerDelayQuantiles = ("Scheduler delay" +: + Distribution(schedulerDelays).get.getQuantiles().map( + millis => parent.formatDuration(millis.toLong))) + def getQuantileCols(data: Seq[Double]) = Distribution(data).get.getQuantiles().map(d => Utils.bytesToString(d.toLong)) @@ -119,7 +145,10 @@ private[spark] class StagePage(parent: JobProgressUI) { } val shuffleWriteQuantiles = "Shuffle Write" +: getQuantileCols(shuffleWriteSizes) - val listings: Seq[Seq[String]] = Seq(serviceQuantiles, + val listings: Seq[Seq[String]] = Seq( + serviceQuantiles, + gettingResultQuantiles, + schedulerDelayQuantiles, if (hasShuffleRead) shuffleReadQuantiles else Nil, if (hasShuffleWrite) shuffleWriteQuantiles else Nil) -- cgit v1.2.3 From 53b94ef2f5179bdbebe70883b2593b569518e77e Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Thu, 21 Nov 2013 17:17:06 -0800 Subject: TimeTrackingOutputStream should pass on calls to close() and flush(). Without this fix you get a huge number of open shuffles after running shuffles. --- core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala | 2 ++ 1 file changed, 2 insertions(+) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala index 32d2dd0694..0a32df7c89 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockObjectWriter.scala @@ -101,6 +101,8 @@ class DiskBlockObjectWriter( def write(i: Int): Unit = callWithTiming(out.write(i)) override def write(b: Array[Byte]) = callWithTiming(out.write(b)) override def write(b: Array[Byte], off: Int, len: Int) = callWithTiming(out.write(b, off, len)) + override def close() = out.close() + override def flush() = out.flush() } private val syncWrites = System.getProperty("spark.shuffle.sync", "false").toBoolean -- cgit v1.2.3 From ab3cefde5349d0de85b23b49feef493ff0b2d1ed Mon Sep 17 00:00:00 2001 From: Raymond Liu Date: Wed, 23 Oct 2013 09:42:25 +0800 Subject: Add YarnClientClusterScheduler and Backend. With this scheduler, the user application is launched locally, While the executor will be launched by YARN on remote nodes. This enables spark-shell to run upon YARN. --- .../main/scala/org/apache/spark/SparkContext.scala | 25 +++ docs/running-on-yarn.md | 27 ++- .../org/apache/spark/deploy/yarn/Client.scala | 13 +- .../apache/spark/deploy/yarn/ClientArguments.scala | 40 ++-- .../apache/spark/deploy/yarn/WorkerLauncher.scala | 246 +++++++++++++++++++++ .../cluster/YarnClientClusterScheduler.scala | 47 ++++ .../cluster/YarnClientSchedulerBackend.scala | 109 +++++++++ 7 files changed, 484 insertions(+), 23 deletions(-) create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala create mode 100644 yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala create mode 100644 yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 42b2985b50..3a80241daa 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -226,6 +226,31 @@ class SparkContext( scheduler.initialize(backend) scheduler + case "yarn-client" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(this).asInstanceOf[ClusterScheduler] + + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + val backend = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext]) + cons.newInstance(scheduler, this).asInstanceOf[CoarseGrainedSchedulerBackend] + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + scheduler.initialize(backend) + scheduler + case MESOS_REGEX(mesosUrl) => MesosNativeLibrary.load() val scheduler = new ClusterScheduler(this) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4056e9c15d..68fd6c2ab1 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -45,6 +45,10 @@ System Properties: Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster. This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager. +There are two scheduler mode that can be used to launch spark application on YARN. + +## Launch spark application by YARN Client with yarn-standalone mode. + The command to launch the YARN Client is as follows: SPARK_JAR= ./spark-class org.apache.spark.deploy.yarn.Client \ @@ -52,6 +56,7 @@ The command to launch the YARN Client is as follows: --class \ --args \ --num-workers \ + --master-class --master-memory \ --worker-memory \ --worker-cores \ @@ -85,11 +90,29 @@ For example: $ cat $YARN_APP_LOGS_DIR/$YARN_APP_ID/container*_000001/stdout Pi is roughly 3.13794 -The above starts a YARN Client programs which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running. +The above starts a YARN Client programs which start the default Application Master. Then SparkPi will be run as a child thread of Application Master, YARN Client will periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running. + +With this mode, your application is actually run on the remote machine where the Application Master is run upon. Thus application that involve local interaction will not work well, e.g. spark-shell. + +## Launch spark application with yarn-client mode. + +With yarn-client mode, the application will be launched locally. Just like running application or spark-shell on Local / Mesos / Standalone mode. The launch method is also the similar with them, just make sure that when you need to specify a master url, use "yarn-client" instead. And you also need to export the env value for SPARK_JAR and SPARK_YARN_APP_JAR + +In order to tune worker core/number/memory etc. You need to export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY, SPARK_WORKER_INSTANCES e.g. by ./conf/spark-env.sh + +For example: + + SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \ + SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ + ./run-example org.apache.spark.examples.SparkPi yarn-client + + + SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \ + SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ + MASTER=yarn-client ./spark-shell # Important Notes -- When your application instantiates a Spark context it must use a special "yarn-standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above. - We do not requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed. - The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored. - The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 94e353af2e..bb73f6d337 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -54,9 +54,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl // staging directory is private! -> rwx-------- val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short) // app files are world-wide readable and owner writable -> rw-r--r-- - val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) + val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) - def run() { + // for client user who want to monitor app status by itself. + def runApp() = { validateArgs() init(yarnConf) @@ -78,7 +79,11 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName()) submitApp(appContext) - + appId + } + + def run() { + val appId = runApp() monitorApplication(appId) System.exit(0) } @@ -372,7 +377,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val commands = List[String](javaCommand + " -server " + JAVA_OPTS + - " org.apache.spark.deploy.yarn.ApplicationMaster" + + " " + args.amClass + " --class " + args.userClass + " --jar " + args.userJar + userArgsToString(args) + diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 852dbd7dab..b9dbc3fb87 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -35,6 +35,7 @@ class ClientArguments(val args: Array[String]) { var numWorkers = 2 var amQueue = System.getProperty("QUEUE", "default") var amMemory: Int = 512 + var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster" var appName: String = "Spark" // TODO var inputFormatInfo: List[InputFormatInfo] = null @@ -62,18 +63,22 @@ class ClientArguments(val args: Array[String]) { userArgsBuffer += value args = tail - case ("--master-memory") :: MemoryParam(value) :: tail => - amMemory = value + case ("--master-class") :: value :: tail => + amClass = value args = tail - case ("--num-workers") :: IntParam(value) :: tail => - numWorkers = value + case ("--master-memory") :: MemoryParam(value) :: tail => + amMemory = value args = tail case ("--worker-memory") :: MemoryParam(value) :: tail => workerMemory = value args = tail + case ("--num-workers") :: IntParam(value) :: tail => + numWorkers = value + args = tail + case ("--worker-cores") :: IntParam(value) :: tail => workerCores = value args = tail @@ -119,19 +124,20 @@ class ClientArguments(val args: Array[String]) { System.err.println( "Usage: org.apache.spark.deploy.yarn.Client [options] \n" + "Options:\n" + - " --jar JAR_PATH Path to your application's JAR file (required)\n" + - " --class CLASS_NAME Name of your application's main class (required)\n" + - " --args ARGS Arguments to be passed to your application's main class.\n" + - " Mutliple invocations are possible, each will be passed in order.\n" + - " --num-workers NUM Number of workers to start (Default: 2)\n" + - " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" + - " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" + - " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" + - " --name NAME The name of your application (Default: Spark)\n" + - " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" + - " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" + - " --files files Comma separated list of files to be distributed with the job.\n" + - " --archives archives Comma separated list of archives to be distributed with the job." + " --jar JAR_PATH Path to your application's JAR file (required)\n" + + " --class CLASS_NAME Name of your application's main class (required)\n" + + " --args ARGS Arguments to be passed to your application's main class.\n" + + " Mutliple invocations are possible, each will be passed in order.\n" + + " --num-workers NUM Number of workers to start (Default: 2)\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" + + " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" + + " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" + + " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" + + " --name NAME The name of your application (Default: Spark)\n" + + " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" + + " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" + + " --files files Comma separated list of files to be distributed with the job.\n" + + " --archives archives Comma separated list of archives to be distributed with the job." ) System.exit(exitCode) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala new file mode 100644 index 0000000000..421a83c87a --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.deploy.yarn + +import java.net.Socket +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import akka.actor._ +import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} +import akka.remote.RemoteClientShutdown +import akka.actor.Terminated +import akka.remote.RemoteClientDisconnected +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.SplitInfo + +class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration) extends Logging { + + def this(args: ApplicationMasterArguments) = this(args, new Configuration()) + + private val rpc: YarnRPC = YarnRPC.create(conf) + private var resourceManager: AMRMProtocol = null + private var appAttemptId: ApplicationAttemptId = null + private var reporterThread: Thread = null + private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + private var yarnAllocator: YarnAllocationHandler = null + private var driverClosed:Boolean = false + + val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0)._1 + var actor: ActorRef = null + + // This actor just working as a monitor to watch on Driver Actor. + class MonitorActor(driverUrl: String) extends Actor { + + var driver: ActorRef = null + + override def preStart() { + logInfo("Listen to driver: " + driverUrl) + driver = context.actorFor(driverUrl) + driver ! "hello" + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.watch(driver) // Doesn't work with remote actors, but useful for testing + } + + override def receive = { + case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => + logInfo("Driver terminated or disconnected! Shutting down.") + driverClosed = true + } + } + + def run() { + + appAttemptId = getApplicationAttemptId() + resourceManager = registerWithResourceManager() + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + + // Compute number of threads for akka + val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() + + if (minimumMemory > 0) { + val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) + + if (numCore > 0) { + // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 + // TODO: Uncomment when hadoop is on a version which has this fixed. + // args.workerCores = numCore + } + } + + waitForSparkMaster() + + // Allocate all containers + allocateWorkers() + + // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout + // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. + + val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + // must be <= timeoutInterval/ 2. + // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. + // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. + val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + reporterThread = launchReporterThread(interval) + + // Wait for the reporter thread to Finish. + reporterThread.join() + + finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) + actorSystem.shutdown() + + logInfo("Exited") + System.exit(0) + } + + private def getApplicationAttemptId(): ApplicationAttemptId = { + val envs = System.getenv() + val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) + return appAttemptId + } + + private def registerWithResourceManager(): AMRMProtocol = { + val rmAddress = NetUtils.createSocketAddr(yarnConf.get( + YarnConfiguration.RM_SCHEDULER_ADDRESS, + YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) + logInfo("Connecting to ResourceManager at " + rmAddress) + return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] + } + + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { + logInfo("Registering the ApplicationMaster") + val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) + .asInstanceOf[RegisterApplicationMasterRequest] + appMasterRequest.setApplicationAttemptId(appAttemptId) + // Setting this to master host,port - so that the ApplicationReport at client has some sensible info. + // Users can then monitor stderr/stdout on that node if required. + appMasterRequest.setHost(Utils.localHostName()) + appMasterRequest.setRpcPort(0) + // What do we provide here ? Might make sense to expose something sensible later ? + appMasterRequest.setTrackingUrl("") + return resourceManager.registerApplicationMaster(appMasterRequest) + } + + private def waitForSparkMaster() { + logInfo("Waiting for spark driver to be reachable.") + var driverUp = false + val hostport = args.userArgs(0) + val (driverHost, driverPort) = Utils.parseHostPort(hostport) + while(!driverUp) { + try { + val socket = new Socket(driverHost, driverPort) + socket.close() + logInfo("Master now available: " + driverHost + ":" + driverPort) + driverUp = true + } catch { + case e: Exception => + logError("Failed to connect to driver at " + driverHost + ":" + driverPort) + Thread.sleep(100) + } + } + System.setProperty("spark.driver.host", driverHost) + System.setProperty("spark.driver.port", driverPort.toString) + + val driverUrl = "akka://spark@%s:%s/user/%s".format( + driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME) + + actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") + } + + + private def allocateWorkers() { + + // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. + val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map() + + yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, preferredNodeLocationData) + + logInfo("Allocating " + args.numWorkers + " workers.") + // Wait until all containers have finished + // TODO: This is a bit ugly. Can we make it nicer? + // TODO: Handle container failure + while(yarnAllocator.getNumWorkersRunning < args.numWorkers) { + yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) + Thread.sleep(100) + } + + logInfo("All workers have launched.") + + } + + // TODO: We might want to extend this to allocate more containers in case they die ! + private def launchReporterThread(_sleepTime: Long): Thread = { + val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime + + val t = new Thread { + override def run() { + while (!driverClosed) { + val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning + if (missingWorkerCount > 0) { + logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers") + yarnAllocator.allocateContainers(missingWorkerCount) + } + else sendProgress() + Thread.sleep(sleepTime) + } + } + } + // setting to daemon status, though this is usually not a good idea. + t.setDaemon(true) + t.start() + logInfo("Started progress reporter thread - sleep time : " + sleepTime) + return t + } + + private def sendProgress() { + logDebug("Sending progress") + // simulated with an allocate request with no nodes requested ... + yarnAllocator.allocateContainers(0) + } + + def finishApplicationMaster(status: FinalApplicationStatus) { + + logInfo("finish ApplicationMaster with " + status) + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(appAttemptId) + finishReq.setFinishApplicationStatus(status) + resourceManager.finishApplicationMaster(finishReq) + } + +} + + +object WorkerLauncher { + def main(argStrings: Array[String]) { + val args = new ApplicationMasterArguments(argStrings) + new WorkerLauncher(args).run() + } +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala new file mode 100644 index 0000000000..63a0449e5a --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark._ +import org.apache.hadoop.conf.Configuration +import org.apache.spark.deploy.yarn.YarnAllocationHandler +import org.apache.spark.util.Utils + +/** + * + * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM. + */ +private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { + + def this(sc: SparkContext) = this(sc, new Configuration()) + + // By default, rack is unknown + override def getRackForHost(hostPort: String): Option[String] = { + val host = Utils.parseHostPort(hostPort)._1 + val retval = YarnAllocationHandler.lookupRack(conf, host) + if (retval != null) Some(retval) else None + } + + override def postStartHook() { + + // The yarn application is running, but the worker might not yet ready + // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt + Thread.sleep(2000L) + logInfo("YarnClientClusterScheduler.postStartHook done") + } +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala new file mode 100644 index 0000000000..b206780c78 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} +import org.apache.spark.{SparkException, Logging, SparkContext} +import org.apache.spark.deploy.yarn.{Client, ClientArguments} + +private[spark] class YarnClientSchedulerBackend( + scheduler: ClusterScheduler, + sc: SparkContext) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + with Logging { + + var client: Client = null + var appId: ApplicationId = null + + override def start() { + super.start() + + val defalutWorkerCores = "2" + val defalutWorkerMemory = "512m" + val defaultWorkerNumber = "1" + + val userJar = System.getenv("SPARK_YARN_APP_JAR") + var workerCores = System.getenv("SPARK_WORKER_CORES") + var workerMemory = System.getenv("SPARK_WORKER_MEMORY") + var workerNumber = System.getenv("SPARK_WORKER_INSTANCES") + + if (userJar == null) + throw new SparkException("env SPARK_YARN_APP_JAR is not set") + + if (workerCores == null) + workerCores = defalutWorkerCores + if (workerMemory == null) + workerMemory = defalutWorkerMemory + if (workerNumber == null) + workerNumber = defaultWorkerNumber + + val driverHost = System.getProperty("spark.driver.host") + val driverPort = System.getProperty("spark.driver.port") + val hostport = driverHost + ":" + driverPort + + val argsArray = Array[String]( + "--class", "notused", + "--jar", userJar, + "--args", hostport, + "--worker-memory", workerMemory, + "--worker-cores", workerCores, + "--num-workers", workerNumber, + "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher" + ) + + val args = new ClientArguments(argsArray) + client = new Client(args) + appId = client.runApp() + waitForApp() + } + + def waitForApp() { + + // TODO : need a better way to find out whether the workers are ready or not + // maybe by resource usage report? + while(true) { + val report = client.getApplicationReport(appId) + + logInfo("Application report from ASM: \n" + + "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + + "\t appStartTime: " + report.getStartTime() + "\n" + + "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + ) + + // Ready to go, or already gone. + val state = report.getYarnApplicationState() + if (state == YarnApplicationState.RUNNING) { + return + } else if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + throw new SparkException("Yarn application already ended," + + "might be killed or not able to launch application master.") + } + + Thread.sleep(1000) + } + } + + override def stop() { + super.stop() + client.stop() + logInfo("Stoped") + } + +} -- cgit v1.2.3 From ccea38b759c81abea27bc0a51157a31d369839b5 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 21 Nov 2013 21:36:08 -0800 Subject: Fix 'timeWriting' stat for shuffle files Due to concurrent git branches, changes from shuffle file consolidation patch caused the shuffle write timing patch to no longer actually measure the time, since it requires time be measured after the stream has been closed. --- .../main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 1dc71a0428..0f2deb4bcb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -167,6 +167,7 @@ private[spark] class ShuffleMapTask( var totalTime = 0L val compressedSizes: Array[Byte] = shuffle.writers.map { writer: BlockObjectWriter => writer.commit() + writer.close() val size = writer.fileSegment().length totalBytes += size totalTime += writer.timeWriting() @@ -184,14 +185,16 @@ private[spark] class ShuffleMapTask( } catch { case e: Exception => // If there is an exception from running the task, revert the partial writes // and throw the exception upstream to Spark. - if (shuffle != null) { - shuffle.writers.foreach(_.revertPartialWrites()) + if (shuffle != null && shuffle.writers != null) { + for (writer <- shuffle.writers) { + writer.revertPartialWrites() + writer.close() + } } throw e } finally { // Release the writers back to the shuffle block manager. if (shuffle != null && shuffle.writers != null) { - shuffle.writers.foreach(_.close()) shuffle.releaseWriters(success) } // Execute the callbacks on task completion. -- cgit v1.2.3 From c1507afc6ca161608a83967cdebe1404051658d3 Mon Sep 17 00:00:00 2001 From: Ankur Dave Date: Sat, 23 Nov 2013 02:32:37 -0800 Subject: Support preservesPartitioning in RDD.zipPartitions --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 21 ++++++++++++++++++--- .../org/apache/spark/rdd/ZippedPartitionsRDD.scala | 21 ++++++++++++++------- 2 files changed, 32 insertions(+), 10 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 6e88be6f6a..7623c44d88 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -545,20 +545,35 @@ abstract class RDD[T: ClassManifest]( * *same number of partitions*, but does *not* require them to have the same number * of elements in each partition. */ + def zipPartitions[B: ClassManifest, V: ClassManifest] + (rdd2: RDD[B], preservesPartitioning: Boolean) + (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, preservesPartitioning) + def zipPartitions[B: ClassManifest, V: ClassManifest] (rdd2: RDD[B]) (f: (Iterator[T], Iterator[B]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2) + new ZippedPartitionsRDD2(sc, sc.clean(f), this, rdd2, false) + + def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest] + (rdd2: RDD[B], rdd3: RDD[C], preservesPartitioning: Boolean) + (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, preservesPartitioning) def zipPartitions[B: ClassManifest, C: ClassManifest, V: ClassManifest] (rdd2: RDD[B], rdd3: RDD[C]) (f: (Iterator[T], Iterator[B], Iterator[C]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3) + new ZippedPartitionsRDD3(sc, sc.clean(f), this, rdd2, rdd3, false) + + def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest] + (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D], preservesPartitioning: Boolean) + (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, preservesPartitioning) def zipPartitions[B: ClassManifest, C: ClassManifest, D: ClassManifest, V: ClassManifest] (rdd2: RDD[B], rdd3: RDD[C], rdd4: RDD[D]) (f: (Iterator[T], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V]): RDD[V] = - new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4) + new ZippedPartitionsRDD4(sc, sc.clean(f), this, rdd2, rdd3, rdd4, false) // Actions (launch a job to return a value to the user program) diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index 31e6fd519d..faeb316664 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -39,9 +39,13 @@ private[spark] class ZippedPartitionsPartition( abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( sc: SparkContext, - var rdds: Seq[RDD[_]]) + var rdds: Seq[RDD[_]], + preservesPartitioning: Boolean = false) extends RDD[V](sc, rdds.map(x => new OneToOneDependency(x))) { + override val partitioner = + if (preservesPartitioning) firstParent[Any].partitioner else None + override def getPartitions: Array[Partition] = { val sizes = rdds.map(x => x.partitions.size) if (!sizes.forall(x => x == sizes(0))) { @@ -76,8 +80,9 @@ class ZippedPartitionsRDD2[A: ClassManifest, B: ClassManifest, V: ClassManifest] sc: SparkContext, f: (Iterator[A], Iterator[B]) => Iterator[V], var rdd1: RDD[A], - var rdd2: RDD[B]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2)) { + var rdd2: RDD[B], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions @@ -97,8 +102,9 @@ class ZippedPartitionsRDD3 f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V], var rdd1: RDD[A], var rdd2: RDD[B], - var rdd3: RDD[C]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3)) { + var rdd3: RDD[C], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions @@ -122,8 +128,9 @@ class ZippedPartitionsRDD4 var rdd1: RDD[A], var rdd2: RDD[B], var rdd3: RDD[C], - var rdd4: RDD[D]) - extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4)) { + var rdd4: RDD[D], + preservesPartitioning: Boolean = false) + extends ZippedPartitionsBaseRDD[V](sc, List(rdd1, rdd2, rdd3, rdd4), preservesPartitioning) { override def compute(s: Partition, context: TaskContext): Iterator[V] = { val partitions = s.asInstanceOf[ZippedPartitionsPartition].partitions -- cgit v1.2.3 From 4f1c3fa5d7e6fe509b1cea550eaa213a185ec964 Mon Sep 17 00:00:00 2001 From: Harvey Feng Date: Sat, 23 Nov 2013 17:07:19 -0800 Subject: Hadoop 2.2 YARN API migration for `SPARK_HOME/new-yarn` --- .../main/scala/org/apache/spark/SparkContext.scala | 2 +- .../spark/deploy/yarn/ApplicationMaster.scala | 155 +++--- .../org/apache/spark/deploy/yarn/Client.scala | 163 +++--- .../apache/spark/deploy/yarn/ClientArguments.scala | 19 +- .../apache/spark/deploy/yarn/WorkerRunnable.scala | 48 +- .../spark/deploy/yarn/YarnAllocationHandler.scala | 570 +++++++++++---------- 6 files changed, 468 insertions(+), 489 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 42b2985b50..fad54683bc 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -81,7 +81,7 @@ class SparkContext( val sparkHome: String = null, val jars: Seq[String] = Nil, val environment: Map[String, String] = Map(), - // This is used only by yarn for now, but should be relevant to other cluster types (mesos, etc) + // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc) // too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set // of data-local splits on host val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 9c43a7287d..eeeca3ea8a 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -30,8 +30,10 @@ import org.apache.hadoop.net.NetUtils import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.util.ShutdownHookManager import org.apache.hadoop.yarn.api._ -import org.apache.hadoop.yarn.api.records._ import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{ConverterUtils, Records} @@ -45,55 +47,43 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e def this(args: ApplicationMasterArguments) = this(args, new Configuration()) private var rpc: YarnRPC = YarnRPC.create(conf) - private var resourceManager: AMRMProtocol = _ + private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) private var appAttemptId: ApplicationAttemptId = _ private var userThread: Thread = _ - private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) private val fs = FileSystem.get(yarnConf) private var yarnAllocator: YarnAllocationHandler = _ private var isFinished: Boolean = false private var uiAddress: String = _ - private val maxAppAttempts: Int = conf.getInt(YarnConfiguration.RM_AM_MAX_RETRIES, - YarnConfiguration.DEFAULT_RM_AM_MAX_RETRIES) + private val maxAppAttempts: Int = conf.getInt( + YarnConfiguration.RM_AM_MAX_ATTEMPTS, YarnConfiguration.DEFAULT_RM_AM_MAX_ATTEMPTS) private var isLastAMRetry: Boolean = true - // default to numWorkers * 2, with minimum of 3 + private var amClient: AMRMClient[ContainerRequest] = _ + + // Default to numWorkers * 2, with minimum of 3 private val maxNumWorkerFailures = System.getProperty("spark.yarn.max.worker.failures", math.max(args.numWorkers * 2, 3).toString()).toInt def run() { - // Setup the directories so things go to yarn approved directories rather - // then user specified and /tmp. + // Setup the directories so things go to YARN approved directories rather + // than user specified and /tmp. System.setProperty("spark.local.dir", getLocalDirs()) - // Use priority 30 as its higher then HDFS. Its same priority as MapReduce is using. + // Use priority 30 as it's higher then HDFS. It's same priority as MapReduce is using. ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30) - + appAttemptId = getApplicationAttemptId() isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts - resourceManager = registerWithResourceManager() + amClient = AMRMClient.createAMRMClient() + amClient.init(yarnConf) + amClient.start() // Workaround until hadoop moves to something which has // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line) - // ignore result. - // This does not, unfortunately, always work reliably ... but alleviates the bug a lot of times - // Hence args.workerCores = numCore disabled above. Any better option? - - // Compute number of threads for akka - //val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() - //if (minimumMemory > 0) { - // val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD - // val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) - - // if (numCore > 0) { - // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 - // TODO: Uncomment when hadoop is on a version which has this fixed. - // args.workerCores = numCore - // } - //} // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) ApplicationMaster.register(this) + // Start the user's JAR userThread = startUserClass() @@ -103,12 +93,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e waitForSparkContextInitialized() - // Do this after spark master is up and SparkContext is created so that we can register UI Url + // Do this after Spark master is up and SparkContext is created so that we can register UI Url. val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() - + // Allocate all containers allocateWorkers() - + // Wait for the user class to Finish userThread.join() @@ -132,41 +122,24 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() - val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) + val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name()) val containerId = ConverterUtils.toContainerId(containerIdString) val appAttemptId = containerId.getApplicationAttemptId() logInfo("ApplicationAttemptId: " + appAttemptId) appAttemptId } - private def registerWithResourceManager(): AMRMProtocol = { - val rmAddress = NetUtils.createSocketAddr(yarnConf.get( - YarnConfiguration.RM_SCHEDULER_ADDRESS, - YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) - logInfo("Connecting to ResourceManager at " + rmAddress) - rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] - } - private def registerApplicationMaster(): RegisterApplicationMasterResponse = { logInfo("Registering the ApplicationMaster") - val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) - .asInstanceOf[RegisterApplicationMasterRequest] - appMasterRequest.setApplicationAttemptId(appAttemptId) - // Setting this to master host,port - so that the ApplicationReport at client has some - // sensible info. - // Users can then monitor stderr/stdout on that node if required. - appMasterRequest.setHost(Utils.localHostName()) - appMasterRequest.setRpcPort(0) - appMasterRequest.setTrackingUrl(uiAddress) - resourceManager.registerApplicationMaster(appMasterRequest) + amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) } private def waitForSparkMaster() { - logInfo("Waiting for spark driver to be reachable.") + logInfo("Waiting for Spark driver to be reachable.") var driverUp = false var tries = 0 val numTries = System.getProperty("spark.yarn.applicationMaster.waitTries", "10").toInt - while(!driverUp && tries < numTries) { + while (!driverUp && tries < numTries) { val driverHost = System.getProperty("spark.driver.host") val driverPort = System.getProperty("spark.driver.port") try { @@ -176,8 +149,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e driverUp = true } catch { case e: Exception => { - logWarning("Failed to connect to driver at %s:%s, retrying ..."). - format(driverHost, driverPort) + logWarning("Failed to connect to driver at %s:%s, retrying ...". + format(driverHost, driverPort)) Thread.sleep(100) tries = tries + 1 } @@ -218,44 +191,44 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e t } - // this need to happen before allocateWorkers + // This need to happen before allocateWorkers() private def waitForSparkContextInitialized() { - logInfo("Waiting for spark context initialization") + logInfo("Waiting for Spark context initialization") try { var sparkContext: SparkContext = null ApplicationMaster.sparkContextRef.synchronized { - var count = 0 + var numTries = 0 val waitTime = 10000L - val numTries = System.getProperty("spark.yarn.ApplicationMaster.waitTries", "10").toInt - while (ApplicationMaster.sparkContextRef.get() == null && count < numTries) { - logInfo("Waiting for spark context initialization ... " + count) - count = count + 1 + val maxNumTries = System.getProperty("spark.yarn.ApplicationMaster.waitTries", "10").toInt + while (ApplicationMaster.sparkContextRef.get() == null && numTries < maxNumTries) { + logInfo("Waiting for Spark context initialization ... " + numTries) + numTries = numTries + 1 ApplicationMaster.sparkContextRef.wait(waitTime) } sparkContext = ApplicationMaster.sparkContextRef.get() - assert(sparkContext != null || count >= numTries) + assert(sparkContext != null || numTries >= maxNumTries) - if (null != sparkContext) { + if (sparkContext != null) { uiAddress = sparkContext.ui.appUIAddress this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, - resourceManager, + amClient, appAttemptId, args, - sparkContext.preferredNodeLocationData) + sparkContext.preferredNodeLocationData) } else { - logWarning("Unable to retrieve sparkContext inspite of waiting for %d, numTries = %d". - format(count * waitTime, numTries)) + logWarning("Unable to retrieve SparkContext inspite of waiting for %d, maxNumTries = %d". + format(numTries * waitTime, maxNumTries)) this.yarnAllocator = YarnAllocationHandler.newAllocator( yarnConf, - resourceManager, + amClient, appAttemptId, args) } } } finally { - // in case of exceptions, etc - ensure that count is atleast ALLOCATOR_LOOP_WAIT_COUNT : - // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks + // In case of exceptions, etc - ensure that count is at least ALLOCATOR_LOOP_WAIT_COUNT : + // so that the loop (in ApplicationMaster.sparkContextInitialized) breaks. ApplicationMaster.incrementAllocatorLoop(ApplicationMaster.ALLOCATOR_LOOP_WAIT_COUNT) } } @@ -266,15 +239,14 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // Wait until all containers have finished // TODO: This is a bit ugly. Can we make it nicer? // TODO: Handle container failure - - // Exists the loop if the user thread exits. + yarnAllocator.addResourceRequests(args.numWorkers) + // Exits the loop if the user thread exits. while (yarnAllocator.getNumWorkersRunning < args.numWorkers && userThread.isAlive) { if (yarnAllocator.getNumWorkersFailed >= maxNumWorkerFailures) { finishApplicationMaster(FinalApplicationStatus.FAILED, "max number of worker failures reached") } - yarnAllocator.allocateContainers( - math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) + yarnAllocator.allocateResources() ApplicationMaster.incrementAllocatorLoop(1) Thread.sleep(100) } @@ -287,7 +259,6 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // Launch a progress reporter thread, else the app will get killed after expiration // (def: 10mins) timeout. - // TODO(harvey): Verify the timeout if (userThread.isAlive) { // Ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapses. val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) @@ -313,13 +284,14 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e finishApplicationMaster(FinalApplicationStatus.FAILED, "max number of worker failures reached") } - val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning + val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning - + yarnAllocator.getNumPendingAllocate if (missingWorkerCount > 0) { logInfo("Allocating %d containers to make up for (potentially) lost containers". format(missingWorkerCount)) - yarnAllocator.allocateContainers(missingWorkerCount) + yarnAllocator.addResourceRequests(missingWorkerCount) } - else sendProgress() + sendProgress() Thread.sleep(sleepTime) } } @@ -333,8 +305,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e private def sendProgress() { logDebug("Sending progress") - // Simulated with an allocate request with no nodes requested ... - yarnAllocator.allocateContainers(0) + // Simulated with an allocate request with no nodes requested. + yarnAllocator.allocateResources() } /* @@ -361,14 +333,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e } logInfo("finishApplicationMaster with " + status) - val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) - .asInstanceOf[FinishApplicationMasterRequest] - finishReq.setAppAttemptId(appAttemptId) - finishReq.setFinishApplicationStatus(status) - finishReq.setDiagnostics(diagnostics) - // Set tracking url to empty since we don't have a history server. - finishReq.setTrackingUrl("") - resourceManager.finishApplicationMaster(finishReq) + // Set tracking URL to empty since we don't have a history server. + amClient.unregisterApplicationMaster(status, "" /* appMessage */, "" /* appTrackingUrl */) } /** @@ -412,6 +378,14 @@ object ApplicationMaster { // TODO: Currently, task to container is computed once (TaskSetManager) - which need not be // optimal as more containers are available. Might need to handle this better. private val ALLOCATOR_LOOP_WAIT_COUNT = 30 + + private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() + + val sparkContextRef: AtomicReference[SparkContext] = + new AtomicReference[SparkContext](null /* initialValue */) + + val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0) + def incrementAllocatorLoop(by: Int) { val count = yarnAllocatorLoop.getAndAdd(by) if (count >= ALLOCATOR_LOOP_WAIT_COUNT) { @@ -422,16 +396,11 @@ object ApplicationMaster { } } - private val applicationMasters = new CopyOnWriteArrayList[ApplicationMaster]() - def register(master: ApplicationMaster) { applicationMasters.add(master) } - val sparkContextRef: AtomicReference[SparkContext] = - new AtomicReference[SparkContext](null /* initialValue */) - val yarnAllocatorLoop: AtomicInteger = new AtomicInteger(0) - + // TODO(harvey): See whether this should be discarded - it isn't used anywhere atm... def sparkContextInitialized(sc: SparkContext): Boolean = { var modified = false sparkContextRef.synchronized { diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 86310f32d5..ee90086729 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -35,7 +35,7 @@ import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.protocolrecords._ import org.apache.hadoop.yarn.api.records._ -import org.apache.hadoop.yarn.client.YarnClientImpl +import org.apache.hadoop.yarn.client.api.impl.YarnClientImpl import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, Records} @@ -45,10 +45,13 @@ import org.apache.spark.util.Utils import org.apache.spark.deploy.SparkHadoopUtil +/** + * The entry point (starting in Client#main() and Client#run()) for launching Spark on YARN. The + * Client submits an application to the global ResourceManager to launch Spark's ApplicationMaster, + * which will launch a Spark master process and negotiate resources throughout its duration. + */ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl with Logging { - def this(args: ClientArguments) = this(new Configuration(), args) - var rpc: YarnRPC = YarnRPC.create(conf) val yarnConf: YarnConfiguration = new YarnConfiguration(conf) val credentials = UserGroupInformation.getCurrentUser().getCredentials() @@ -56,48 +59,68 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl private val distCacheMgr = new ClientDistributedCacheManager() // Staging directory is private! -> rwx-------- - val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short) + val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700: Short) // App files are world-wide readable and owner writable -> rw-r--r-- - val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) + val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644: Short) + + def this(args: ClientArguments) = this(new Configuration(), args) def run() { validateArgs() + // Initialize and start the client service. init(yarnConf) start() + + // Log details about this YARN cluster (e.g, the number of slave machines/NodeManagers). logClusterResourceDetails() - val newApp = super.getNewApplication() - val appId = newApp.getApplicationId() + // Prepare to submit a request to the ResourcManager (specifically its ApplicationsManager (ASM) + // interface). - verifyClusterResources(newApp) - val appContext = createApplicationSubmissionContext(appId) + // Get a new client application. + val newApp = super.createApplication() + val newAppResponse = newApp.getNewApplicationResponse() + val appId = newAppResponse.getApplicationId() + + verifyClusterResources(newAppResponse) + + // Set up resource and environment variables. val appStagingDir = getAppStagingDir(appId) val localResources = prepareLocalResources(appStagingDir) - val env = setupLaunchEnv(localResources, appStagingDir) - val amContainer = createContainerLaunchContext(newApp, localResources, env) + val launchEnv = setupLaunchEnv(localResources, appStagingDir) + val amContainer = createContainerLaunchContext(newAppResponse, localResources, launchEnv) + // Set up an application submission context. + val appContext = newApp.getApplicationSubmissionContext() + appContext.setApplicationName(args.appName) appContext.setQueue(args.amQueue) appContext.setAMContainerSpec(amContainer) - appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName()) - submitApp(appContext) + // Memory for the ApplicationMaster. + val memoryResource = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] + memoryResource.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + appContext.setResource(memoryResource) + // Finally, submit and monitor the application. + submitApp(appContext) monitorApplication(appId) + System.exit(0) } + // TODO(harvey): This could just go in ClientArguments. def validateArgs() = { Map( (System.getenv("SPARK_JAR") == null) -> "Error: You must set SPARK_JAR environment variable!", (args.userJar == null) -> "Error: You must specify a user jar!", (args.userClass == null) -> "Error: You must specify a user class!", (args.numWorkers <= 0) -> "Error: You must specify atleast 1 worker!", - (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> "Error: AM memory size must be + - greater then: " + YarnAllocationHandler.MEMORY_OVERHEAD, - (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> "Error: Worker memory size + - must be greater then: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString - .foreach { case(cond, errStr) => + (args.amMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: AM memory size must be" + + "greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD), + (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Worker memory size" + + "must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD.toString) + ).foreach { case(cond, errStr) => if (cond) { logError(errStr) args.printUsageAndExit(1) @@ -111,17 +134,17 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl def logClusterResourceDetails() { val clusterMetrics: YarnClusterMetrics = super.getYarnClusterMetrics - logInfo("Got Cluster metric info from ASM, numNodeManagers = " + + logInfo("Got Cluster metric info from ApplicationsManager (ASM), number of NodeManagers: " + clusterMetrics.getNumNodeManagers) val queueInfo: QueueInfo = super.getQueueInfo(args.amQueue) - logInfo("""Queue info ... queueName = %s, queueCurrentCapacity = %s, queueMaxCapacity = %s, + logInfo("""Queue info ... queueName: %s, queueCurrentCapacity: %s, queueMaxCapacity: %s, queueApplicationCount = %s, queueChildQueueCount = %s""".format( queueInfo.getQueueName, queueInfo.getCurrentCapacity, queueInfo.getMaximumCapacity, queueInfo.getApplications.size, - queueInfo.getChildQueues.size) + queueInfo.getChildQueues.size)) } def verifyClusterResources(app: GetNewApplicationResponse) = { @@ -130,25 +153,19 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl // If we have requested more then the clusters max for a single resource then exit. if (args.workerMemory > maxMem) { - logError("the worker size is to large to run on this cluster " + args.workerMemory) + logError("Required worker memory (%d MB), is above the max threshold (%d MB) of this cluster.". + format(args.workerMemory, maxMem)) System.exit(1) } val amMem = args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD if (amMem > maxMem) { - logError("AM size is to large to run on this cluster " + amMem) + logError("Required AM memory (%d) is above the max threshold (%d) of this cluster". + format(args.amMemory, maxMem)) System.exit(1) } // We could add checks to make sure the entire cluster has enough resources but that involves - // getting all the node reports and computing ourselves - } - - def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = { - logInfo("Setting up application submission context for ASM") - val appContext = Records.newRecord(classOf[ApplicationSubmissionContext]) - appContext.setApplicationId(appId) - appContext.setApplicationName(args.appName) - return appContext + // getting all the node reports and computing ourselves. } /** See if two file systems are the same or not. */ @@ -213,7 +230,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl def prepareLocalResources(appStagingDir: String): HashMap[String, LocalResource] = { logInfo("Preparing Local resources") // Upload Spark and the application JAR to the remote file system if necessary. Add them as - // local resources to the AM. + // local resources to the application master. val fs = FileSystem.get(conf) val delegTokenRenewer = Master.getMasterPrincipal(conf) @@ -230,18 +247,20 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val dstFs = dst.getFileSystem(conf) dstFs.addDelegationTokens(delegTokenRenewer, credentials) } + val localResources = HashMap[String, LocalResource]() FileSystem.mkdirs(fs, dst, new FsPermission(STAGING_DIR_PERMISSION)) val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - Map(Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar, - Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF")) - .foreach { case(destName, _localPath) => + Map( + Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar, + Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF") + ).foreach { case(destName, _localPath) => val localPath: String = if (_localPath != null) _localPath.trim() else "" if (! localPath.isEmpty()) { var localURI = new URI(localPath) - // if not specified assume these are in the local filesystem to keep behavior like Hadoop + // If not specified assume these are in the local filesystem to keep behavior like Hadoop if (localURI.getScheme() == null) { localURI = new URI(FileSystem.getLocal(conf).makeQualified(new Path(localPath)).toString) } @@ -252,19 +271,21 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } } - // handle any add jars + // Handle jars local to the ApplicationMaster. if ((args.addJars != null) && (!args.addJars.isEmpty())){ args.addJars.split(',').foreach { case file: String => val localURI = new URI(file.trim()) val localPath = new Path(localURI) val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) val destPath = copyRemoteFile(dst, localPath, replication) + // Only add the resource to the Spark ApplicationMaster. + val appMasterOnly = true distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, - linkname, statCache, true) + linkname, statCache, appMasterOnly) } } - // handle any distributed cache files + // Handle any distributed cache files if ((args.files != null) && (!args.files.isEmpty())){ args.files.split(',').foreach { case file: String => val localURI = new URI(file.trim()) @@ -276,7 +297,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } } - // handle any distributed cache archives + // Handle any distributed cache archives if ((args.archives != null) && (!args.archives.isEmpty())) { args.archives.split(',').foreach { case file:String => val localURI = new URI(file.trim()) @@ -289,7 +310,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } UserGroupInformation.getCurrentUser().addCredentials(credentials) - return localResources + localResources } def setupLaunchEnv( @@ -311,8 +332,9 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl // Allow users to specify some environment variables. Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV")) - // Add each SPARK-* key to the environment. + // Add each SPARK_* key to the environment. System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } + env } @@ -335,33 +357,32 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl amContainer.setLocalResources(localResources) amContainer.setEnvironment(env) - val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() - - // TODO(harvey): This can probably be a val. - var amMemory = ((args.amMemory / minResMemory) * minResMemory) + - ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - - YarnAllocationHandler.MEMORY_OVERHEAD) + // TODO: Need a replacement for the following code to fix -Xmx? + // val minResMemory: Int = newApp.getMinimumResourceCapability().getMemory() + // var amMemory = ((args.amMemory / minResMemory) * minResMemory) + + // ((if ((args.amMemory % minResMemory) == 0) 0 else minResMemory) - + // YarnAllocationHandler.MEMORY_OVERHEAD) // Extra options for the JVM var JAVA_OPTS = "" - // Add Xmx for am memory - JAVA_OPTS += "-Xmx" + amMemory + "m " + // Add Xmx for AM memory + JAVA_OPTS += "-Xmx" + args.amMemory + "m" - JAVA_OPTS += " -Djava.io.tmpdir=" + - new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " " + val tmpDir = new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + JAVA_OPTS += " -Djava.io.tmpdir=" + tmpDir - // Commenting it out for now - so that people can refer to the properties if required. Remove - // it once cpuset version is pushed out. The context is, default gc for server class machines - // end up using all cores to do gc - hence if there are multiple containers in same node, - // spark gc effects all other containers performance (which can also be other spark containers) - // Instead of using this, rely on cpusets by YARN to enforce spark behaves 'properly' in - // multi-tenant environments. Not sure how default java gc behaves if it is limited to subset + // TODO: Remove once cpuset version is pushed out. + // The context is, default gc for server class machines ends up using all cores to do gc - + // hence if there are multiple containers in same node, Spark GC affects all other containers' + // performance (which can be that of other Spark containers) + // Instead of using this, rely on cpusets by YARN to enforce "proper" Spark behavior in + // multi-tenant environments. Not sure how default Java GC behaves if it is limited to subset // of cores on a node. val useConcurrentAndIncrementalGC = env.isDefinedAt("SPARK_USE_CONC_INCR_GC") && java.lang.Boolean.parseBoolean(env("SPARK_USE_CONC_INCR_GC")) if (useConcurrentAndIncrementalGC) { - // In our expts, using (default) throughput collector has severe perf ramnifications in + // In our expts, using (default) throughput collector has severe perf ramifications in // multi-tenant machines JAVA_OPTS += " -XX:+UseConcMarkSweepGC " JAVA_OPTS += " -XX:+CMSIncrementalMode " @@ -371,7 +392,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } if (env.isDefinedAt("SPARK_JAVA_OPTS")) { - JAVA_OPTS += env("SPARK_JAVA_OPTS") + " " + JAVA_OPTS += " " + env("SPARK_JAVA_OPTS") } // Command for the ApplicationMaster @@ -381,7 +402,8 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl javaCommand = Environment.JAVA_HOME.$() + "/bin/java" } - val commands = List[String](javaCommand + + val commands = List[String]( + javaCommand + " -server " + JAVA_OPTS + " org.apache.spark.deploy.yarn.ApplicationMaster" + @@ -393,18 +415,14 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl " --num-workers " + args.numWorkers + " 1> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stdout" + " 2> " + ApplicationConstants.LOG_DIR_EXPANSION_VAR + "/stderr") - logInfo("Command for the ApplicationMaster: " + commands(0)) - amContainer.setCommands(commands) - val capability = Records.newRecord(classOf[Resource]).asInstanceOf[Resource] - // Memory for the ApplicationMaster. - capability.setMemory(args.amMemory + YarnAllocationHandler.MEMORY_OVERHEAD) - amContainer.setResource(capability) + logInfo("Command for starting the Spark ApplicationMaster: " + commands(0)) + amContainer.setCommands(commands) // Setup security tokens. val dob = new DataOutputBuffer() credentials.writeTokenStorageToStream(dob) - amContainer.setContainerTokens(ByteBuffer.wrap(dob.getData())) + amContainer.setTokens(ByteBuffer.wrap(dob.getData())) amContainer } @@ -423,7 +441,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl logInfo("Application report from ASM: \n" + "\t application identifier: " + appId.toString() + "\n" + "\t appId: " + appId.getId() + "\n" + - "\t clientToken: " + report.getClientToken() + "\n" + + "\t clientToAMToken: " + report.getClientToAMToken() + "\n" + "\t appDiagnostics: " + report.getDiagnostics() + "\n" + "\t appMasterHost: " + report.getHost() + "\n" + "\t appQueue: " + report.getQueue() + "\n" + @@ -454,12 +472,13 @@ object Client { def main(argStrings: Array[String]) { // Set an env variable indicating we are running in YARN mode. - // Note that anything with SPARK prefix gets propagated to all (remote) processes + // Note: anything env variable with SPARK_ prefix gets propagated to all (remote) processes - + // see Client#setupLaunchEnv(). System.setProperty("SPARK_YARN_MODE", "true") val args = new ClientArguments(argStrings) - new Client(args).run + (new Client(args)).run() } // Based on code from org.apache.hadoop.mapreduce.v2.util.MRApps diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 852dbd7dab..6d3c95867e 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -17,12 +17,14 @@ package org.apache.spark.deploy.yarn -import org.apache.spark.util.MemoryParam -import org.apache.spark.util.IntParam -import collection.mutable.{ArrayBuffer, HashMap} +import scala.collection.mutable.{ArrayBuffer, HashMap} + import org.apache.spark.scheduler.{InputFormatInfo, SplitInfo} +import org.apache.spark.util.IntParam +import org.apache.spark.util.MemoryParam + -// TODO: Add code and support for ensuring that yarn resource 'asks' are location aware ! +// TODO: Add code and support for ensuring that yarn resource 'tasks' are location aware ! class ClientArguments(val args: Array[String]) { var addJars: String = null var files: String = null @@ -30,14 +32,16 @@ class ClientArguments(val args: Array[String]) { var userJar: String = null var userClass: String = null var userArgs: Seq[String] = Seq[String]() - var workerMemory = 1024 + var workerMemory = 1024 // MB var workerCores = 1 var numWorkers = 2 var amQueue = System.getProperty("QUEUE", "default") - var amMemory: Int = 512 + var amMemory: Int = 512 // MB var appName: String = "Spark" // TODO var inputFormatInfo: List[InputFormatInfo] = null + // TODO(harvey) + var priority = 0 parseArgs(args.toList) @@ -47,8 +51,7 @@ class ClientArguments(val args: Array[String]) { var args = inputArgs - while (! args.isEmpty) { - + while (!args.isEmpty) { args match { case ("--jar") :: value :: tail => userJar = value diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala index 6a90cc51cf..9f5523c4b9 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerRunnable.scala @@ -32,10 +32,12 @@ import org.apache.hadoop.security.UserGroupInformation import org.apache.hadoop.yarn.api._ import org.apache.hadoop.yarn.api.ApplicationConstants.Environment import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.records.impl.pb.ProtoUtils import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.client.api.NMClient import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC -import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records, ProtoUtils} +import org.apache.hadoop.yarn.util.{Apps, ConverterUtils, Records} import org.apache.spark.Logging @@ -51,12 +53,14 @@ class WorkerRunnable( extends Runnable with Logging { var rpc: YarnRPC = YarnRPC.create(conf) - var cm: ContainerManager = null + var nmClient: NMClient = _ val yarnConf: YarnConfiguration = new YarnConfiguration(conf) def run = { logInfo("Starting Worker Container") - cm = connectToCM + nmClient = NMClient.createNMClient() + nmClient.init(yarnConf) + nmClient.start() startContainer } @@ -66,8 +70,6 @@ class WorkerRunnable( val ctx = Records.newRecord(classOf[ContainerLaunchContext]) .asInstanceOf[ContainerLaunchContext] - ctx.setContainerId(container.getId()) - ctx.setResource(container.getResource()) val localResources = prepareLocalResources ctx.setLocalResources(localResources) @@ -111,12 +113,10 @@ class WorkerRunnable( } */ - ctx.setUser(UserGroupInformation.getCurrentUser().getShortUserName()) - val credentials = UserGroupInformation.getCurrentUser().getCredentials() val dob = new DataOutputBuffer() credentials.writeTokenStorageToStream(dob) - ctx.setContainerTokens(ByteBuffer.wrap(dob.getData())) + ctx.setTokens(ByteBuffer.wrap(dob.getData())) var javaCommand = "java" val javaHome = System.getenv("JAVA_HOME") @@ -144,10 +144,7 @@ class WorkerRunnable( ctx.setCommands(commands) // Send the start request to the ContainerManager - val startReq = Records.newRecord(classOf[StartContainerRequest]) - .asInstanceOf[StartContainerRequest] - startReq.setContainerLaunchContext(ctx) - cm.startContainer(startReq) + nmClient.startContainer(container, ctx) } private def setupDistributedCache( @@ -194,7 +191,7 @@ class WorkerRunnable( } logInfo("Prepared Local resources " + localResources) - return localResources + localResources } def prepareEnvironment: HashMap[String, String] = { @@ -206,30 +203,7 @@ class WorkerRunnable( Apps.setEnvFromInputString(env, System.getenv("SPARK_YARN_USER_ENV")) System.getenv().filterKeys(_.startsWith("SPARK")).foreach { case (k,v) => env(k) = v } - return env - } - - def connectToCM: ContainerManager = { - val cmHostPortStr = container.getNodeId().getHost() + ":" + container.getNodeId().getPort() - val cmAddress = NetUtils.createSocketAddr(cmHostPortStr) - logInfo("Connecting to ContainerManager at " + cmHostPortStr) - - // Use doAs and remoteUser here so we can add the container token and not pollute the current - // users credentials with all of the individual container tokens - val user = UserGroupInformation.createRemoteUser(container.getId().toString()) - val containerToken = container.getContainerToken() - if (containerToken != null) { - user.addToken(ProtoUtils.convertFromProtoFormat(containerToken, cmAddress)) - } - - val proxy = user - .doAs(new PrivilegedExceptionAction[ContainerManager] { - def run: ContainerManager = { - return rpc.getProxy(classOf[ContainerManager], - cmAddress, conf).asInstanceOf[ContainerManager] - } - }) - proxy + env } } diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala index 6ce470e8cb..dba0f7640e 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/YarnAllocationHandler.scala @@ -32,11 +32,13 @@ import org.apache.spark.scheduler.cluster.{ClusterScheduler, CoarseGrainedSchedu import org.apache.spark.util.Utils import org.apache.hadoop.conf.Configuration -import org.apache.hadoop.yarn.api.AMRMProtocol -import org.apache.hadoop.yarn.api.records.{AMResponse, ApplicationAttemptId} +import org.apache.hadoop.yarn.api.ApplicationMasterProtocol +import org.apache.hadoop.yarn.api.records.ApplicationAttemptId import org.apache.hadoop.yarn.api.records.{Container, ContainerId, ContainerStatus} import org.apache.hadoop.yarn.api.records.{Priority, Resource, ResourceRequest} import org.apache.hadoop.yarn.api.protocolrecords.{AllocateRequest, AllocateResponse} +import org.apache.hadoop.yarn.client.api.AMRMClient +import org.apache.hadoop.yarn.client.api.AMRMClient.ContainerRequest import org.apache.hadoop.yarn.util.{RackResolver, Records} @@ -56,7 +58,7 @@ object AllocationType extends Enumeration ("HOST", "RACK", "ANY") { // more info on how we are requesting for containers. private[yarn] class YarnAllocationHandler( val conf: Configuration, - val resourceManager: AMRMProtocol, + val amClient: AMRMClient[ContainerRequest], val appAttemptId: ApplicationAttemptId, val maxWorkers: Int, val workerMemory: Int, @@ -83,12 +85,17 @@ private[yarn] class YarnAllocationHandler( // Containers to be released in next request to RM private val pendingReleaseContainers = new ConcurrentHashMap[ContainerId, Boolean] + // Number of container requests that have been sent to, but not yet allocated by the + // ApplicationMaster. + private val numPendingAllocate = new AtomicInteger() private val numWorkersRunning = new AtomicInteger() // Used to generate a unique id per worker private val workerIdCounter = new AtomicInteger() private val lastResponseId = new AtomicInteger() private val numWorkersFailed = new AtomicInteger() + def getNumPendingAllocate: Int = numPendingAllocate.intValue + def getNumWorkersRunning: Int = numWorkersRunning.intValue def getNumWorkersFailed: Int = numWorkersFailed.intValue @@ -97,154 +104,163 @@ private[yarn] class YarnAllocationHandler( container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) } - def allocateContainers(workersToRequest: Int) { - // We need to send the request only once from what I understand ... but for now, not modifying - // this much. + def releaseContainer(container: Container) { + val containerId = container.getId + pendingReleaseContainers.put(containerId, true) + amClient.releaseAssignedContainer(containerId) + } + + def allocateResources() { + // We have already set the container request. Poll the ResourceManager for a response. + // This doubles as a heartbeat if there are no pending container requests. + val progressIndicator = 0.1f + val allocateResponse = amClient.allocate(progressIndicator) - // Keep polling the Resource Manager for containers - val amResp = allocateWorkerResources(workersToRequest).getAMResponse + val allocatedContainers = allocateResponse.getAllocatedContainers() + if (allocatedContainers.size > 0) { + var numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * allocatedContainers.size) - val _allocatedContainers = amResp.getAllocatedContainers() + if (numPendingAllocateNow < 0) { + numPendingAllocateNow = numPendingAllocate.addAndGet(-1 * numPendingAllocateNow) + } - if (_allocatedContainers.size > 0) { logDebug(""" Allocated containers: %d Current worker count: %d - Containers to-be-released: %d - pendingReleaseContainers: %s + Containers released: %s + Containers to-be-released: %s Cluster resources: %s """.format( allocatedContainers.size, numWorkersRunning.get(), releasedContainerList, pendingReleaseContainers, - amResp.getAvailableResources)) + allocateResponse.getAvailableResources)) val hostToContainers = new HashMap[String, ArrayBuffer[Container]]() - // Ignore if not satisfying constraints { - for (container <- _allocatedContainers) { + for (container <- allocatedContainers) { if (isResourceConstraintSatisfied(container)) { - // allocatedContainers += container - + // Add the accepted `container` to the host's list of already accepted, + // allocated containers val host = container.getNodeId.getHost - val containers = hostToContainers.getOrElseUpdate(host, new ArrayBuffer[Container]()) - - containers += container + val containersForHost = hostToContainers.getOrElseUpdate(host, + new ArrayBuffer[Container]()) + containersForHost += container + } else { + // Release container, since it doesn't satisfy resource constraints. + releaseContainer(container) } - // Add all ignored containers to released list - else releasedContainerList.add(container.getId()) } - // Find the appropriate containers to use. Slightly non trivial groupBy ... + // Find the appropriate containers to use. + // TODO: Cleanup this group-by... val dataLocalContainers = new HashMap[String, ArrayBuffer[Container]]() val rackLocalContainers = new HashMap[String, ArrayBuffer[Container]]() val offRackContainers = new HashMap[String, ArrayBuffer[Container]]() - for (candidateHost <- hostToContainers.keySet) - { + for (candidateHost <- hostToContainers.keySet) { val maxExpectedHostCount = preferredHostToCount.getOrElse(candidateHost, 0) val requiredHostCount = maxExpectedHostCount - allocatedContainersOnHost(candidateHost) - var remainingContainers = hostToContainers.get(candidateHost).getOrElse(null) - assert(remainingContainers != null) + val remainingContainersOpt = hostToContainers.get(candidateHost) + assert(remainingContainersOpt.isDefined) + var remainingContainers = remainingContainersOpt.get - if (requiredHostCount >= remainingContainers.size){ - // Since we got <= required containers, add all to dataLocalContainers + if (requiredHostCount >= remainingContainers.size) { + // Since we have <= required containers, add all remaining containers to + // `dataLocalContainers`. dataLocalContainers.put(candidateHost, remainingContainers) - // all consumed + // There are no more free containers remaining. remainingContainers = null - } - else if (requiredHostCount > 0) { + } else if (requiredHostCount > 0) { // Container list has more containers than we need for data locality. - // Split into two : data local container count of (remainingContainers.size - - // requiredHostCount) and rest as remainingContainer + // Split the list into two: one based on the data local container count, + // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining + // containers. val (dataLocal, remaining) = remainingContainers.splitAt( remainingContainers.size - requiredHostCount) dataLocalContainers.put(candidateHost, dataLocal) - // remainingContainers = remaining - // yarn has nasty habit of allocating a tonne of containers on a host - discourage this : - // add remaining to release list. If we have insufficient containers, next allocation - // cycle will reallocate (but wont treat it as data local) - for (container <- remaining) releasedContainerList.add(container.getId()) + // Invariant: remainingContainers == remaining + + // YARN has a nasty habit of allocating a ton of containers on a host - discourage this. + // Add each container in `remaining` to list of containers to release. If we have an + // insufficient number of containers, then the next allocation cycle will reallocate + // (but won't treat it as data local). + // TODO(harvey): Rephrase this comment some more. + for (container <- remaining) releaseContainer(container) remainingContainers = null } - // Now rack local - if (remainingContainers != null){ + // For rack local containers + if (remainingContainers != null) { val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) - - if (rack != null){ + if (rack != null) { val maxExpectedRackCount = preferredRackToCount.getOrElse(rack, 0) - val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - - rackLocalContainers.get(rack).getOrElse(List()).size - + val requiredRackCount = maxExpectedRackCount - allocatedContainersOnRack(rack) - + rackLocalContainers.getOrElse(rack, List()).size - if (requiredRackCount >= remainingContainers.size){ - // Add all to dataLocalContainers + if (requiredRackCount >= remainingContainers.size) { + // Add all remaining containers to to `dataLocalContainers`. dataLocalContainers.put(rack, remainingContainers) - // All consumed remainingContainers = null - } - else if (requiredRackCount > 0) { - // container list has more containers than we need for data locality. - // Split into two : data local container count of (remainingContainers.size - - // requiredRackCount) and rest as remainingContainer + } else if (requiredRackCount > 0) { + // Container list has more containers that we need for data locality. + // Split the list into two: one based on the data local container count, + // (`remainingContainers.size` - `requiredHostCount`), and the other to hold remaining + // containers. val (rackLocal, remaining) = remainingContainers.splitAt( remainingContainers.size - requiredRackCount) val existingRackLocal = rackLocalContainers.getOrElseUpdate(rack, new ArrayBuffer[Container]()) existingRackLocal ++= rackLocal + remainingContainers = remaining } } } - // If still not consumed, then it is off rack host - add to that list. - if (remainingContainers != null){ + if (remainingContainers != null) { + // Not all containers have been consumed - add them to the list of off-rack containers. offRackContainers.put(candidateHost, remainingContainers) } } - // Now that we have split the containers into various groups, go through them in order : - // first host local, then rack local and then off rack (everything else). - // Note that the list we create below tries to ensure that not all containers end up within a - // host if there are sufficiently large number of hosts/containers. - - val allocatedContainers = new ArrayBuffer[Container](_allocatedContainers.size) - allocatedContainers ++= ClusterScheduler.prioritizeContainers(dataLocalContainers) - allocatedContainers ++= ClusterScheduler.prioritizeContainers(rackLocalContainers) - allocatedContainers ++= ClusterScheduler.prioritizeContainers(offRackContainers) - - // Run each of the allocated containers - for (container <- allocatedContainers) { + // Now that we have split the containers into various groups, go through them in order: + // first host-local, then rack-local, and finally off-rack. + // Note that the list we create below tries to ensure that not all containers end up within + // a host if there is a sufficiently large number of hosts/containers. + val allocatedContainersToProcess = new ArrayBuffer[Container](allocatedContainers.size) + allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(dataLocalContainers) + allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(rackLocalContainers) + allocatedContainersToProcess ++= ClusterScheduler.prioritizeContainers(offRackContainers) + + // Run each of the allocated containers. + for (container <- allocatedContainersToProcess) { val numWorkersRunningNow = numWorkersRunning.incrementAndGet() val workerHostname = container.getNodeId.getHost val containerId = container.getId - assert( - container.getResource.getMemory >= (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD)) + val workerMemoryOverhead = (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) + assert(container.getResource.getMemory >= workerMemoryOverhead) if (numWorkersRunningNow > maxWorkers) { - logInfo("""Ignoring container %d at host %s, since we already have the required number of + logInfo("""Ignoring container %s at host %s, since we already have the required number of containers for it.""".format(containerId, workerHostname)) - releasedContainerList.add(containerId) - // reset counter back to old value. + releaseContainer(container) numWorkersRunning.decrementAndGet() - } - else { - // Deallocate + allocate can result in reusing id's wrongly - so use a different counter - // (workerIdCounter) + } else { val workerId = workerIdCounter.incrementAndGet().toString val driverUrl = "akka://spark@%s:%s/user/%s".format( - System.getProperty("spark.driver.host"), System.getProperty("spark.driver.port"), + System.getProperty("spark.driver.host"), + System.getProperty("spark.driver.port"), CoarseGrainedSchedulerBackend.ACTOR_NAME) - logInfo("launching container on " + containerId + " host " + workerHostname) - // Just to be safe, simply remove it from pendingReleaseContainers. - // Should not be there, but .. + logInfo("Launching container %s for on host %s".format(containerId, workerHostname)) + + // To be safe, remove the container from `pendingReleaseContainers`. pendingReleaseContainers.remove(containerId) val rack = YarnAllocationHandler.lookupRack(conf, workerHostname) @@ -254,45 +270,52 @@ private[yarn] class YarnAllocationHandler( containerSet += containerId allocatedContainerToHostMap.put(containerId, workerHostname) + if (rack != null) { allocatedRackCount.put(rack, allocatedRackCount.getOrElse(rack, 0) + 1) } } - - new Thread( - new WorkerRunnable(container, conf, driverUrl, workerId, - workerHostname, workerMemory, workerCores) - ).start() + logInfo("Launching WorkerRunnable. driverUrl: %s, workerHostname: %s".format(driverUrl, workerHostname)) + val workerRunnable = new WorkerRunnable( + container, + conf, + driverUrl, + workerId, + workerHostname, + workerMemory, + workerCores) + new Thread(workerRunnable).start() } } logDebug(""" - Finished processing %d completed containers. + Finished allocating %s containers (from %s originally). Current number of workers running: %d, releasedContainerList: %s, pendingReleaseContainers: %s """.format( - completedContainers.size, + allocatedContainersToProcess, + allocatedContainers, numWorkersRunning.get(), releasedContainerList, pendingReleaseContainers)) } + val completedContainers = allocateResponse.getCompletedContainersStatuses() + if (completedContainers.size > 0) { + logDebug("Completed %d containers".format(completedContainers.size)) - val completedContainers = amResp.getCompletedContainersStatuses() - if (completedContainers.size > 0){ - logDebug("Completed %d containers, to-be-released: %s".format( - completedContainers.size, releasedContainerList)) - for (completedContainer <- completedContainers){ + for (completedContainer <- completedContainers) { val containerId = completedContainer.getContainerId - // Was this released by us ? If yes, then simply remove from containerSet and move on. if (pendingReleaseContainers.containsKey(containerId)) { + // YarnAllocationHandler already marked the container for release, so remove it from + // `pendingReleaseContainers`. pendingReleaseContainers.remove(containerId) - } - else { - // Simply decrement count - next iteration of ReporterThread will take care of allocating. + } else { + // Decrement the number of workers running. The next iteration of the ApplicationMaster's + // reporting thread will take care of allocating. numWorkersRunning.decrementAndGet() - logInfo("Completed container %d (state: %s, http address: %s, exit status: %s)".format( + logInfo("Completed container %s (state: %s, exit status: %s)".format( containerId, completedContainer.getState, completedContainer.getExitStatus())) @@ -307,24 +330,32 @@ private[yarn] class YarnAllocationHandler( allocatedHostToContainersMap.synchronized { if (allocatedContainerToHostMap.containsKey(containerId)) { - val host = allocatedContainerToHostMap.get(containerId).getOrElse(null) - assert (host != null) - - val containerSet = allocatedHostToContainersMap.get(host).getOrElse(null) - assert (containerSet != null) - - containerSet -= containerId - if (containerSet.isEmpty) allocatedHostToContainersMap.remove(host) - else allocatedHostToContainersMap.update(host, containerSet) + val hostOpt = allocatedContainerToHostMap.get(containerId) + assert(hostOpt.isDefined) + val host = hostOpt.get + + val containerSetOpt = allocatedHostToContainersMap.get(host) + assert(containerSetOpt.isDefined) + val containerSet = containerSetOpt.get + + containerSet.remove(containerId) + if (containerSet.isEmpty) { + allocatedHostToContainersMap.remove(host) + } else { + allocatedHostToContainersMap.update(host, containerSet) + } - allocatedContainerToHostMap -= containerId + allocatedContainerToHostMap.remove(containerId) - // Doing this within locked context, sigh ... move to outside ? + // TODO: Move this part outside the synchronized block? val rack = YarnAllocationHandler.lookupRack(conf, host) if (rack != null) { val rackCount = allocatedRackCount.getOrElse(rack, 0) - 1 - if (rackCount > 0) allocatedRackCount.put(rack, rackCount) - else allocatedRackCount.remove(rack) + if (rackCount > 0) { + allocatedRackCount.put(rack, rackCount) + } else { + allocatedRackCount.remove(rack) + } } } } @@ -342,32 +373,34 @@ private[yarn] class YarnAllocationHandler( } } - def createRackResourceRequests(hostContainers: List[ResourceRequest]): List[ResourceRequest] = { - // First generate modified racks and new set of hosts under it : then issue requests + def createRackResourceRequests( + hostContainers: ArrayBuffer[ContainerRequest] + ): ArrayBuffer[ContainerRequest] = { + // Generate modified racks and new set of hosts under it before issuing requests. val rackToCounts = new HashMap[String, Int]() - // Within this lock - used to read/write to the rack related maps too. for (container <- hostContainers) { - val candidateHost = container.getHostName - val candidateNumContainers = container.getNumContainers + val candidateHost = container.getNodes.last assert(YarnAllocationHandler.ANY_HOST != candidateHost) val rack = YarnAllocationHandler.lookupRack(conf, candidateHost) if (rack != null) { var count = rackToCounts.getOrElse(rack, 0) - count += candidateNumContainers + count += 1 rackToCounts.put(rack, count) } } - val requestedContainers: ArrayBuffer[ResourceRequest] = - new ArrayBuffer[ResourceRequest](rackToCounts.size) - for ((rack, count) <- rackToCounts){ - requestedContainers += - createResourceRequest(AllocationType.RACK, rack, count, YarnAllocationHandler.PRIORITY) + val requestedContainers = new ArrayBuffer[ContainerRequest](rackToCounts.size) + for ((rack, count) <- rackToCounts) { + requestedContainers ++= createResourceRequests( + AllocationType.RACK, + rack, + count, + YarnAllocationHandler.PRIORITY) } - requestedContainers.toList + requestedContainers } def allocatedContainersOnHost(host: String): Int = { @@ -386,147 +419,128 @@ private[yarn] class YarnAllocationHandler( retval } - private def allocateWorkerResources(numWorkers: Int): AllocateResponse = { - - var resourceRequests: List[ResourceRequest] = null - - // default. - if (numWorkers <= 0 || preferredHostToCount.isEmpty) { - logDebug("numWorkers: " + numWorkers + ", host preferences: " + preferredHostToCount.isEmpty) - resourceRequests = List( - createResourceRequest(AllocationType.ANY, null, numWorkers, YarnAllocationHandler.PRIORITY)) - } - else { - // request for all hosts in preferred nodes and for numWorkers - - // candidates.size, request by default allocation policy. - val hostContainerRequests: ArrayBuffer[ResourceRequest] = - new ArrayBuffer[ResourceRequest](preferredHostToCount.size) - for ((candidateHost, candidateCount) <- preferredHostToCount) { - val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost) - - if (requiredCount > 0) { - hostContainerRequests += createResourceRequest( - AllocationType.HOST, - candidateHost, - requiredCount, - YarnAllocationHandler.PRIORITY) + def addResourceRequests(numWorkers: Int) { + val containerRequests: List[ContainerRequest] = + if (numWorkers <= 0 || preferredHostToCount.isEmpty) { + logDebug("numWorkers: " + numWorkers + ", host preferences: " + + preferredHostToCount.isEmpty) + createResourceRequests( + AllocationType.ANY, + resource = null, + numWorkers, + YarnAllocationHandler.PRIORITY).toList + } else { + // Request for all hosts in preferred nodes and for numWorkers - + // candidates.size, request by default allocation policy. + val hostContainerRequests = new ArrayBuffer[ContainerRequest](preferredHostToCount.size) + for ((candidateHost, candidateCount) <- preferredHostToCount) { + val requiredCount = candidateCount - allocatedContainersOnHost(candidateHost) + + if (requiredCount > 0) { + hostContainerRequests ++= createResourceRequests( + AllocationType.HOST, + candidateHost, + requiredCount, + YarnAllocationHandler.PRIORITY) + } } + val rackContainerRequests: List[ContainerRequest] = createRackResourceRequests( + hostContainerRequests).toList + + val anyContainerRequests = createResourceRequests( + AllocationType.ANY, + resource = null, + numWorkers, + YarnAllocationHandler.PRIORITY) + + val containerRequestBuffer = new ArrayBuffer[ContainerRequest]( + hostContainerRequests.size + rackContainerRequests.size() + anyContainerRequests.size) + + containerRequestBuffer ++= hostContainerRequests + containerRequestBuffer ++= rackContainerRequests + containerRequestBuffer ++= anyContainerRequests + containerRequestBuffer.toList } - val rackContainerRequests: List[ResourceRequest] = createRackResourceRequests( - hostContainerRequests.toList) - val anyContainerRequests: ResourceRequest = createResourceRequest( - AllocationType.ANY, - resource = null, - numWorkers, - YarnAllocationHandler.PRIORITY) - - val containerRequests: ArrayBuffer[ResourceRequest] = new ArrayBuffer[ResourceRequest]( - hostContainerRequests.size() + rackContainerRequests.size() + 1) - - containerRequests ++= hostContainerRequests - containerRequests ++= rackContainerRequests - containerRequests += anyContainerRequests - - resourceRequests = containerRequests.toList + for (request <- containerRequests) { + amClient.addContainerRequest(request) } - val req = Records.newRecord(classOf[AllocateRequest]) - req.setResponseId(lastResponseId.incrementAndGet) - req.setApplicationAttemptId(appAttemptId) - - req.addAllAsks(resourceRequests) - - val releasedContainerList = createReleasedContainerList() - req.addAllReleases(releasedContainerList) - if (numWorkers > 0) { - logInfo("Allocating %d worker containers with %d of memory each.").format(numWorkers, - workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) - } - else { - logDebug("Empty allocation req .. release : " + releasedContainerList) + numPendingAllocate.addAndGet(numWorkers) + logInfo("Will Allocate %d worker containers, each with %d memory".format( + numWorkers, + (workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD))) + } else { + logDebug("Empty allocation request ...") } - for (request <- resourceRequests) { - logInfo("ResourceRequest (host : %s, num containers: %d, priority = %d , capability : %s)"). - format( - request.getHostName, - request.getNumContainers, - request.getPriority, - request.getCapability) + for (request <- containerRequests) { + val nodes = request.getNodes + var hostStr = if (nodes == null || nodes.isEmpty) { + "Any" + } else { + nodes.last + } + logInfo("Container request (host: %s, priority: %s, capability: %s".format( + hostStr, + request.getPriority().getPriority, + request.getCapability)) } - resourceManager.allocate(req) } + private def createResourceRequests( + requestType: AllocationType.AllocationType, + resource: String, + numWorkers: Int, + priority: Int + ): ArrayBuffer[ContainerRequest] = { - private def createResourceRequest( - requestType: AllocationType.AllocationType, - resource:String, - numWorkers: Int, - priority: Int): ResourceRequest = { - - // If hostname specified, we need atleast two requests - node local and rack local. - // There must be a third request - which is ANY : that will be specially handled. + // If hostname is specified, then we need at least two requests - node local and rack local. + // There must be a third request, which is ANY. That will be specially handled. requestType match { case AllocationType.HOST => { assert(YarnAllocationHandler.ANY_HOST != resource) val hostname = resource - val nodeLocal = createResourceRequestImpl(hostname, numWorkers, priority) + val nodeLocal = constructContainerRequests( + Array(hostname), + racks = null, + numWorkers, + priority) - // Add to host->rack mapping + // Add `hostname` to the global (singleton) host->rack mapping in YarnAllocationHandler. YarnAllocationHandler.populateRackInfo(conf, hostname) - nodeLocal } case AllocationType.RACK => { val rack = resource - createResourceRequestImpl(rack, numWorkers, priority) + constructContainerRequests(hosts = null, Array(rack), numWorkers, priority) } - case AllocationType.ANY => createResourceRequestImpl( - YarnAllocationHandler.ANY_HOST, numWorkers, priority) + case AllocationType.ANY => constructContainerRequests( + hosts = null, racks = null, numWorkers, priority) case _ => throw new IllegalArgumentException( "Unexpected/unsupported request type: " + requestType) } } - private def createResourceRequestImpl( - hostname:String, - numWorkers: Int, - priority: Int): ResourceRequest = { - - val rsrcRequest = Records.newRecord(classOf[ResourceRequest]) - val memCapability = Records.newRecord(classOf[Resource]) - // There probably is some overhead here, let's reserve a bit more memory. - memCapability.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) - rsrcRequest.setCapability(memCapability) + private def constructContainerRequests( + hosts: Array[String], + racks: Array[String], + numWorkers: Int, + priority: Int + ): ArrayBuffer[ContainerRequest] = { - val pri = Records.newRecord(classOf[Priority]) - pri.setPriority(priority) - rsrcRequest.setPriority(pri) + val memoryResource = Records.newRecord(classOf[Resource]) + memoryResource.setMemory(workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD) - rsrcRequest.setHostName(hostname) - - rsrcRequest.setNumContainers(java.lang.Math.max(numWorkers, 0)) - rsrcRequest - } + val prioritySetting = Records.newRecord(classOf[Priority]) + prioritySetting.setPriority(priority) - def createReleasedContainerList(): ArrayBuffer[ContainerId] = { - - val retval = new ArrayBuffer[ContainerId](1) - // Iterator on COW list ... - for (container <- releasedContainerList.iterator()){ - retval += container - } - // Remove from the original list. - if (! retval.isEmpty) { - releasedContainerList.removeAll(retval) - for (v <- retval) pendingReleaseContainers.put(v, true) - logInfo("Releasing " + retval.size + " containers. pendingReleaseContainers : " + - pendingReleaseContainers) + val requests = new ArrayBuffer[ContainerRequest]() + for (i <- 0 until numWorkers) { + requests += new ContainerRequest(memoryResource, hosts, racks, prioritySetting) } - - retval + requests } } @@ -537,26 +551,25 @@ object YarnAllocationHandler { // request types (like map/reduce in hadoop for example) val PRIORITY = 1 - // Additional memory overhead - in mb + // Additional memory overhead - in mb. val MEMORY_OVERHEAD = 384 - // Host to rack map - saved from allocation requests - // We are expecting this not to change. - // Note that it is possible for this to change : and RM will indicate that to us via update - // response to allocate. But we are punting on handling that for now. + // Host to rack map - saved from allocation requests. We are expecting this not to change. + // Note that it is possible for this to change : and ResurceManager will indicate that to us via + // update response to allocate. But we are punting on handling that for now. private val hostToRack = new ConcurrentHashMap[String, String]() private val rackToHostSet = new ConcurrentHashMap[String, JSet[String]]() def newAllocator( - conf: Configuration, - resourceManager: AMRMProtocol, - appAttemptId: ApplicationAttemptId, - args: ApplicationMasterArguments): YarnAllocationHandler = { - + conf: Configuration, + amClient: AMRMClient[ContainerRequest], + appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments + ): YarnAllocationHandler = { new YarnAllocationHandler( conf, - resourceManager, + amClient, appAttemptId, args.numWorkers, args.workerMemory, @@ -566,39 +579,38 @@ object YarnAllocationHandler { } def newAllocator( - conf: Configuration, - resourceManager: AMRMProtocol, - appAttemptId: ApplicationAttemptId, - args: ApplicationMasterArguments, - map: collection.Map[String, - collection.Set[SplitInfo]]): YarnAllocationHandler = { - - val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) + conf: Configuration, + amClient: AMRMClient[ContainerRequest], + appAttemptId: ApplicationAttemptId, + args: ApplicationMasterArguments, + map: collection.Map[String, + collection.Set[SplitInfo]] + ): YarnAllocationHandler = { + val (hostToSplitCount, rackToSplitCount) = generateNodeToWeight(conf, map) new YarnAllocationHandler( conf, - resourceManager, + amClient, appAttemptId, args.numWorkers, args.workerMemory, args.workerCores, - hostToCount, - rackToCount) + hostToSplitCount, + rackToSplitCount) } def newAllocator( - conf: Configuration, - resourceManager: AMRMProtocol, - appAttemptId: ApplicationAttemptId, - maxWorkers: Int, - workerMemory: Int, - workerCores: Int, - map: collection.Map[String, collection.Set[SplitInfo]]): YarnAllocationHandler = { - + conf: Configuration, + amClient: AMRMClient[ContainerRequest], + appAttemptId: ApplicationAttemptId, + maxWorkers: Int, + workerMemory: Int, + workerCores: Int, + map: collection.Map[String, collection.Set[SplitInfo]] + ): YarnAllocationHandler = { val (hostToCount, rackToCount) = generateNodeToWeight(conf, map) - new YarnAllocationHandler( conf, - resourceManager, + amClient, appAttemptId, maxWorkers, workerMemory, @@ -609,12 +621,13 @@ object YarnAllocationHandler { // A simple method to copy the split info map. private def generateNodeToWeight( - conf: Configuration, - input: collection.Map[String, collection.Set[SplitInfo]]) : - // host to count, rack to count - (Map[String, Int], Map[String, Int]) = { + conf: Configuration, + input: collection.Map[String, collection.Set[SplitInfo]] + ): (Map[String, Int], Map[String, Int]) = { - if (input == null) return (Map[String, Int](), Map[String, Int]()) + if (input == null) { + return (Map[String, Int](), Map[String, Int]()) + } val hostToCount = new HashMap[String, Int] val rackToCount = new HashMap[String, Int] @@ -634,24 +647,25 @@ object YarnAllocationHandler { } def lookupRack(conf: Configuration, host: String): String = { - if (!hostToRack.contains(host)) populateRackInfo(conf, host) + if (!hostToRack.contains(host)) { + populateRackInfo(conf, host) + } hostToRack.get(host) } def fetchCachedHostsForRack(rack: String): Option[Set[String]] = { - val set = rackToHostSet.get(rack) - if (set == null) return None - - // No better way to get a Set[String] from JSet ? - val convertedSet: collection.mutable.Set[String] = set - Some(convertedSet.toSet) + Option(rackToHostSet.get(rack)).map { set => + val convertedSet: collection.mutable.Set[String] = set + // TODO: Better way to get a Set[String] from JSet. + convertedSet.toSet + } } def populateRackInfo(conf: Configuration, hostname: String) { Utils.checkHost(hostname) if (!hostToRack.containsKey(hostname)) { - // If there are repeated failures to resolve, all to an ignore list ? + // If there are repeated failures to resolve, all to an ignore list. val rackInfo = RackResolver.resolve(conf, hostname) if (rackInfo != null && rackInfo.getNetworkLocation != null) { val rack = rackInfo.getNetworkLocation @@ -662,7 +676,7 @@ object YarnAllocationHandler { } rackToHostSet.get(rack).add(hostname) - // TODO(harvey): Figure out this comment... + // TODO(harvey): Figure out what this comment means... // Since RackResolver caches, we are disabling this for now ... } /* else { // right ? Else we will keep calling rack resolver in case we cant resolve rack info ... -- cgit v1.2.3 From 7535d7fbcbe3c0c2515a2d17a806fa523917e398 Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 23 Nov 2013 17:21:37 -0800 Subject: Fixes to AppendOnlyMap: - Use Murmur Hash 3 finalization step to scramble the bits of HashCode instead of the simpler version in java.util.HashMap; the latter one had trouble with ranges of consecutive integers. Murmur Hash 3 is used by fastutil. - Use Object.equals() instead of Scala's == to compare keys, because the latter does extra casts for numeric types (see the equals method in https://github.com/scala/scala/blob/master/src/library/scala/runtime/BoxesRunTime.java) --- .../main/scala/org/apache/spark/util/AppendOnlyMap.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala index f60deafc6f..8542541fe6 100644 --- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala @@ -56,7 +56,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k == curKey) { + if (k.eq(curKey) || k.equals(curKey)) { return data(2 * pos + 1).asInstanceOf[V] } else if (curKey.eq(null)) { return null.asInstanceOf[V] @@ -104,7 +104,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi var i = 1 while (true) { val curKey = data(2 * pos) - if (k.eq(curKey) || k == curKey) { + if (k.eq(curKey) || k.equals(curKey)) { val newValue = updateFunc(true, data(2 * pos + 1).asInstanceOf[V]) data(2 * pos + 1) = newValue.asInstanceOf[AnyRef] return newValue @@ -167,12 +167,11 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi } /** - * Re-hash a value to deal better with hash functions that don't differ - * in the lower bits, similar to java.util.HashMap + * Re-hash a value to deal better with hash functions that don't differ in the lower bits. + * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ private def rehash(h: Int): Int = { - val r = h ^ (h >>> 20) ^ (h >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) + it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) } /** @@ -190,7 +189,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi data(2 * pos) = key data(2 * pos + 1) = value.asInstanceOf[AnyRef] return true - } else if (curKey.eq(key) || curKey == key) { + } else if (curKey.eq(key) || curKey.equals(key)) { data(2 * pos + 1) = value.asInstanceOf[AnyRef] return false } else { -- cgit v1.2.3 From 9837a60234964333916ccbf02c8610909462a7ad Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 23 Nov 2013 17:38:29 -0800 Subject: Some other optimizations to AppendOnlyMap: - Don't check keys for equality when re-inserting due to growing the table; the keys will already be unique - Remember the grow threshold instead of recomputing it on each insert --- .../org/apache/spark/util/AppendOnlyMap.scala | 82 ++++++++++++---------- 1 file changed, 45 insertions(+), 37 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala index 8542541fe6..8bb4ee3bfa 100644 --- a/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala +++ b/core/src/main/scala/org/apache/spark/util/AppendOnlyMap.scala @@ -35,6 +35,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi private var capacity = nextPowerOf2(initialCapacity) private var mask = capacity - 1 private var curSize = 0 + private var growThreshold = LOAD_FACTOR * capacity // Holds keys and values in the same array for memory locality; specifically, the order of // elements is key0, value0, key1, value1, key2, value2, etc. @@ -80,9 +81,23 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi haveNullValue = true return } - val isNewEntry = putInto(data, k, value.asInstanceOf[AnyRef]) - if (isNewEntry) { - incrementSize() + var pos = rehash(key.hashCode) & mask + var i = 1 + while (true) { + val curKey = data(2 * pos) + if (curKey.eq(null)) { + data(2 * pos) = k + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + incrementSize() // Since we added a new key + return + } else if (k.eq(curKey) || k.equals(curKey)) { + data(2 * pos + 1) = value.asInstanceOf[AnyRef] + return + } else { + val delta = i + pos = (pos + delta) & mask + i += 1 + } } } @@ -161,7 +176,7 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi /** Increase table size by 1, rehashing if necessary */ private def incrementSize() { curSize += 1 - if (curSize > LOAD_FACTOR * capacity) { + if (curSize > growThreshold) { growTable() } } @@ -174,33 +189,6 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) } - /** - * Put an entry into a table represented by data, returning true if - * this increases the size of the table or false otherwise. Assumes - * that "data" has at least one empty slot. - */ - private def putInto(data: Array[AnyRef], key: AnyRef, value: AnyRef): Boolean = { - val mask = (data.length / 2) - 1 - var pos = rehash(key.hashCode) & mask - var i = 1 - while (true) { - val curKey = data(2 * pos) - if (curKey.eq(null)) { - data(2 * pos) = key - data(2 * pos + 1) = value.asInstanceOf[AnyRef] - return true - } else if (curKey.eq(key) || curKey.equals(key)) { - data(2 * pos + 1) = value.asInstanceOf[AnyRef] - return false - } else { - val delta = i - pos = (pos + delta) & mask - i += 1 - } - } - return false // Never reached but needed to keep compiler happy - } - /** Double the table's size and re-hash everything */ private def growTable() { val newCapacity = capacity * 2 @@ -210,16 +198,36 @@ class AppendOnlyMap[K, V](initialCapacity: Int = 64) extends Iterable[(K, V)] wi throw new Exception("Can't make capacity bigger than 2^29 elements") } val newData = new Array[AnyRef](2 * newCapacity) - var pos = 0 - while (pos < capacity) { - if (!data(2 * pos).eq(null)) { - putInto(newData, data(2 * pos), data(2 * pos + 1)) + val newMask = newCapacity - 1 + // Insert all our old values into the new array. Note that because our old keys are + // unique, there's no need to check for equality here when we insert. + var oldPos = 0 + while (oldPos < capacity) { + if (!data(2 * oldPos).eq(null)) { + val key = data(2 * oldPos) + val value = data(2 * oldPos + 1) + var newPos = rehash(key.hashCode) & newMask + var i = 1 + var keepGoing = true + while (keepGoing) { + val curKey = newData(2 * newPos) + if (curKey.eq(null)) { + newData(2 * newPos) = key + newData(2 * newPos + 1) = value + keepGoing = false + } else { + val delta = i + newPos = (newPos + delta) & newMask + i += 1 + } + } } - pos += 1 + oldPos += 1 } data = newData capacity = newCapacity - mask = newCapacity - 1 + mask = newMask + growThreshold = LOAD_FACTOR * newCapacity } private def nextPowerOf2(n: Int): Int = { -- cgit v1.2.3 From e9ff13ec72718ada705b85cc10da1b09bcc86dcc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 24 Nov 2013 17:56:43 +0800 Subject: Consolidated both mapPartitions related RDDs into a single MapPartitionsRDD. Also changed the semantics of the index parameter in mapPartitionsWithIndex from the partition index of the output partition to the partition index in the current RDD. --- .../org/apache/spark/rdd/MapPartitionsRDD.scala | 10 +++--- .../spark/rdd/MapPartitionsWithContextRDD.scala | 41 ---------------------- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 39 ++++++++++---------- .../scala/org/apache/spark/CheckpointSuite.scala | 2 -- 4 files changed, 22 insertions(+), 70 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 203179c4ea..ae70d55951 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -20,18 +20,16 @@ package org.apache.spark.rdd import org.apache.spark.{Partition, TaskContext} -private[spark] -class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( +private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], - f: Iterator[T] => Iterator[U], + f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) preservesPartitioning: Boolean = false) extends RDD[U](prev) { - override val partitioner = - if (preservesPartitioning) firstParent[T].partitioner else None + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None override def getPartitions: Array[Partition] = firstParent[T].partitions override def compute(split: Partition, context: TaskContext) = - f(firstParent[T].iterator(split, context)) + f(context, split.index, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala deleted file mode 100644 index aea08ff81b..0000000000 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import org.apache.spark.{Partition, TaskContext} - - -/** - * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the - * TaskContext, the closure can either get access to the interruptible flag or get the index - * of the partition in the RDD. - */ -private[spark] -class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: (TaskContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean - ) extends RDD[U](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override val partitioner = if (preservesPartitioning) prev.partitioner else None - - override def compute(split: Partition, context: TaskContext) = - f(context, firstParent[T].iterator(split, context)) -} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 7623c44d88..5b1285307d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -408,7 +408,6 @@ abstract class RDD[T: ClassManifest]( def pipe(command: String, env: Map[String, String]): RDD[String] = new PipedRDD(this, command, env) - /** * Return an RDD created by piping elements to a forked external process. * The print behavior can be customized by providing two functions. @@ -442,7 +441,8 @@ abstract class RDD[T: ClassManifest]( */ def mapPartitions[U: ClassManifest]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -451,8 +451,8 @@ abstract class RDD[T: ClassManifest]( */ def mapPartitionsWithIndex[U: ClassManifest]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter) - new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -462,7 +462,8 @@ abstract class RDD[T: ClassManifest]( def mapPartitionsWithContext[U: ClassManifest]( f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -483,11 +484,10 @@ abstract class RDD[T: ClassManifest]( def mapWith[A: ClassManifest, U: ClassManifest] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => U): RDD[U] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.map(t => f(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) + }, preservesPartitioning) } /** @@ -498,11 +498,10 @@ abstract class RDD[T: ClassManifest]( def flatMapWith[A: ClassManifest, U: ClassManifest] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => Seq[U]): RDD[U] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) + }, preservesPartitioning) } /** @@ -511,11 +510,10 @@ abstract class RDD[T: ClassManifest]( * partition with the index of that partition. */ def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex { (index, iter) => + val a = constructA(index) iter.map(t => {f(t, a); t}) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {}) + }.foreach(_ => {}) } /** @@ -524,11 +522,10 @@ abstract class RDD[T: ClassManifest]( * partition with the index of that partition. */ def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.filter(t => p(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true) + }, preservesPartitioning = true) } /** diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index f26c44d3e7..d2226aa5a5 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -62,8 +62,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testCheckpointing(_.sample(false, 0.5, 0)) testCheckpointing(_.glom()) testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithContextRDD(r, - (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false )) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) -- cgit v1.2.3 From 466fd06475d8ed262c456421ed2dceba54229db1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 25 Nov 2013 18:27:26 +0800 Subject: Incorporated ideas from pull request #200. - Use Murmur Hash 3 finalization step to scramble the bits of HashCode instead of the simpler version in java.util.HashMap; the latter one had trouble with ranges of consecutive integers. Murmur Hash 3 is used by fastutil. - Don't check keys for equality when re-inserting due to growing the table; the keys will already be unique - Remember the grow threshold instead of recomputing it on each insert --- .../apache/spark/util/collection/OpenHashSet.scala | 107 +++++++++++---------- 1 file changed, 57 insertions(+), 50 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 4592e4f939..40986e3731 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -79,6 +79,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( protected var _capacity = nextPowerOf2(initialCapacity) protected var _mask = _capacity - 1 protected var _size = 0 + protected var _growThreshold = (loadFactor * _capacity).toInt protected var _bitset = new BitSet(_capacity) @@ -115,7 +116,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( * @return The position where the key is placed, plus the highest order bit is set if the key * exists previously. */ - def addWithoutResize(k: T): Int = putInto(_bitset, _data, k) + def addWithoutResize(k: T): Int = { + var pos = hashcode(hasher.hash(k)) & _mask + var i = 1 + while (true) { + if (!_bitset.get(pos)) { + // This is a new key. + _data(pos) = k + _bitset.set(pos) + _size += 1 + return pos | NONEXISTENCE_MASK + } else if (_data(pos) == k) { + // Found an existing key. + return pos + } else { + val delta = i + pos = (pos + delta) & _mask + i += 1 + } + } + // Never reached here + assert(INVALID_POS != INVALID_POS) + INVALID_POS + } /** * Rehash the set if it is overloaded. @@ -126,7 +149,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( * to a new position (in the new data array). */ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { - if (_size > loadFactor * _capacity) { + if (_size > _growThreshold) { rehash(k, allocateFunc, moveFunc) } } @@ -160,37 +183,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( */ def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos) - /** - * Put an entry into the set. Return the position where the key is placed. In addition, the - * highest bit in the returned position is set if the key exists prior to this put. - * - * This function assumes the data array has at least one empty slot. - */ - private def putInto(bitset: BitSet, data: Array[T], k: T): Int = { - val mask = data.length - 1 - var pos = hashcode(hasher.hash(k)) & mask - var i = 1 - while (true) { - if (!bitset.get(pos)) { - // This is a new key. - data(pos) = k - bitset.set(pos) - _size += 1 - return pos | NONEXISTENCE_MASK - } else if (data(pos) == k) { - // Found an existing key. - return pos - } else { - val delta = i - pos = (pos + delta) & mask - i += 1 - } - } - // Never reached here - assert(INVALID_POS != INVALID_POS) - INVALID_POS - } - /** * Double the table's size and re-hash everything. We are not really using k, but it is declared * so Scala compiler can specialize this method (which leads to calling the specialized version @@ -204,34 +196,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( */ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { val newCapacity = _capacity * 2 - require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") - allocateFunc(newCapacity) - val newData = new Array[T](newCapacity) val newBitset = new BitSet(newCapacity) - var pos = 0 - _size = 0 - while (pos < _capacity) { - if (_bitset.get(pos)) { - val newPos = putInto(newBitset, newData, _data(pos)) - moveFunc(pos, newPos & POSITION_MASK) + val newData = new Array[T](newCapacity) + val newMask = newCapacity - 1 + + var oldPos = 0 + while (oldPos < capacity) { + if (_bitset.get(oldPos)) { + val key = _data(oldPos) + var newPos = hashcode(hasher.hash(key)) & newMask + var i = 1 + var keepGoing = true + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!newBitset.get(newPos)) { + // Inserting the key at newPos + newData(newPos) = key + newBitset.set(newPos) + moveFunc(oldPos, newPos) + keepGoing = false + } else { + val delta = i + newPos = (newPos + delta) & newMask + i += 1 + } + } } - pos += 1 + oldPos += 1 } + _bitset = newBitset _data = newData _capacity = newCapacity - _mask = newCapacity - 1 + _mask = newMask + _growThreshold = (loadFactor * newCapacity).toInt } /** - * Re-hash a value to deal better with hash functions that don't differ - * in the lower bits, similar to java.util.HashMap + * Re-hash a value to deal better with hash functions that don't differ in the lower bits. + * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ - private def hashcode(h: Int): Int = { - val r = h ^ (h >>> 20) ^ (h >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) - } + private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) private def nextPowerOf2(n: Int): Int = { val highBit = Integer.highestOneBit(n) -- cgit v1.2.3 From 7222ee29779c3c5146aa5a3d6d060f3b039c1ff7 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 25 Nov 2013 21:06:42 -0800 Subject: Fix the test --- core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala | 4 ++-- core/src/test/scala/org/apache/spark/JavaAPISuite.java | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 70f7f01d2b..dad5c72e1c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -191,8 +191,8 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav * the maximum value of the last position and all NaN entries will be counted * in that bucket. */ - def histogram(buckets: Array[Double]): Array[Long] = { - srdd.histogram(buckets.map(_.toDouble), false) + def histogram(buckets: Array[scala.Double]): Array[Long] = { + srdd.histogram(buckets, false) } def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = { diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 8a9c6e63e0..44483fd4ab 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -368,10 +368,10 @@ public class JavaAPISuite implements Serializable { public void javaDoubleRDDHistoGram() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); // Test using generated buckets - Tuple2 results = rdd.histogram(2); - Double[] expected_buckets = {1.0, 2.5, 4.0}; + Tuple2 results = rdd.histogram(2); + double[] expected_buckets = {1.0, 2.5, 4.0}; long[] expected_counts = {2, 2}; - Assert.assertArrayEquals(expected_buckets, results._1); + Assert.assertArrayEquals(expected_buckets, results._1, 0.1); Assert.assertArrayEquals(expected_counts, results._2); // Test with provided buckets long[] histogram = rdd.histogram(expected_buckets); -- cgit v1.2.3 From 297c09d4bb26ba815c7fcb0a0ff04974959f551e Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 25 Nov 2013 22:51:33 -0800 Subject: Improve docs for shuffle instrumentation --- .../org/apache/spark/executor/TaskMetrics.scala | 23 ++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 0b4892f98f..c0ce46e379 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -61,50 +61,53 @@ object TaskMetrics { class ShuffleReadMetrics extends Serializable { /** - * Time when shuffle finishs + * Absolute time when this task finished reading shuffle data */ var shuffleFinishTime: Long = _ /** - * Total number of blocks fetched in a shuffle (remote or local) + * Number of blocks fetched in this shuffle by this task (remote or local) */ var totalBlocksFetched: Int = _ /** - * Number of remote blocks fetched in a shuffle + * Number of remote blocks fetched in this shuffle by this task */ var remoteBlocksFetched: Int = _ /** - * Local blocks fetched in a shuffle + * Number of local blocks fetched in this shuffle by this task */ var localBlocksFetched: Int = _ /** - * Total time that is spent blocked waiting for shuffle to fetch data + * Time the task spent waiting for remote shuffle blocks. This only includes the time + * blocking on shuffle input data. For instance if block B is being fetched while the task is + * still not finished processing block A, it is not considered to be blocking on block B. */ var fetchWaitTime: Long = _ /** - * The total amount of time for all the shuffle fetches. This adds up time from overlapping - * shuffles, so can be longer than task time + * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all + * input blocks. Since block fetches are both pipelined and parallelized, this can + * exceed fetchWaitTime and executorRunTime. */ var remoteFetchTime: Long = _ /** - * Total number of remote bytes read from a shuffle + * Total number of remote bytes read from the shuffle by this task */ var remoteBytesRead: Long = _ } class ShuffleWriteMetrics extends Serializable { /** - * Number of bytes written for a shuffle + * Number of bytes written for the shuffle by this task */ var shuffleBytesWritten: Long = _ /** - * Time spent blocking on writes to disk or buffer cache, in nanoseconds. + * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ var shuffleWriteTime: Long = _ } -- cgit v1.2.3 From db998a6e14389768f93b1fdd6be7847d5f7604fd Mon Sep 17 00:00:00 2001 From: "haitao.yao" Date: Tue, 26 Nov 2013 18:23:48 +0800 Subject: add http timeout for httpbroadcast --- .../main/scala/org/apache/spark/broadcast/HttpBroadcast.scala | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index 609464e38d..47db720416 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -19,6 +19,7 @@ package org.apache.spark.broadcast import java.io.{File, FileOutputStream, ObjectInputStream, OutputStream} import java.net.URL +import java.util.concurrent.TimeUnit import it.unimi.dsi.fastutil.io.FastBufferedInputStream import it.unimi.dsi.fastutil.io.FastBufferedOutputStream @@ -83,6 +84,8 @@ private object HttpBroadcast extends Logging { private val files = new TimeStampedHashSet[String] private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup) + private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5,TimeUnit.MINUTES).toInt + private lazy val compressionCodec = CompressionCodec.createCodec() def initialize(isDriver: Boolean) { @@ -138,10 +141,13 @@ private object HttpBroadcast extends Logging { def read[T](id: Long): T = { val url = serverUri + "/" + BroadcastBlockId(id).name val in = { + val httpConnection = new URL(url).openConnection() + httpConnection.setReadTimeout(httpReadTimeout) + val inputStream = httpConnection.getInputStream() if (compress) { - compressionCodec.compressedInputStream(new URL(url).openStream()) + compressionCodec.compressedInputStream(inputStream) } else { - new FastBufferedInputStream(new URL(url).openStream(), bufferSize) + new FastBufferedInputStream(inputStream, bufferSize) } } val ser = SparkEnv.get.serializer.newInstance() -- cgit v1.2.3 From 57579934f0454f258615c10e69ac2adafc5b9835 Mon Sep 17 00:00:00 2001 From: hhd Date: Mon, 25 Nov 2013 17:17:17 -0500 Subject: Emit warning when task size > 100KB --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 15 +++++++++++++++ .../main/scala/org/apache/spark/scheduler/StageInfo.scala | 1 + .../main/scala/org/apache/spark/scheduler/TaskInfo.scala | 2 ++ .../spark/scheduler/cluster/ClusterTaskSetManager.scala | 1 + 4 files changed, 19 insertions(+) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 42bb3884c8..4457525ac8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -110,6 +110,9 @@ class DAGScheduler( // resubmit failed stages val POLL_TIMEOUT = 10L + // Warns the user if a stage contains a task with size greater than this value (in KB) + val TASK_SIZE_TO_WARN = 100 + private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor { override def preStart() { context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) { @@ -430,6 +433,18 @@ class DAGScheduler( handleExecutorLost(execId) case BeginEvent(task, taskInfo) => + for ( + job <- idToActiveJob.get(task.stageId); + stage <- stageIdToStage.get(task.stageId); + stageInfo <- stageToInfos.get(stage) + ) { + if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) { + stageInfo.emittedTaskSizeWarning = true + logWarning(("Stage %d (%s) contains a task of very large " + + "size (%d KB). The maximum recommended task size is %d KB.").format( + task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN)) + } + } listenerBus.post(SparkListenerTaskStart(task, taskInfo)) case GettingResultEvent(task, taskInfo) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 93599dfdc8..e9f2198a00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -33,4 +33,5 @@ class StageInfo( val name = stage.name val numPartitions = stage.numPartitions val numTasks = stage.numTasks + var emittedTaskSizeWarning = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 4bae26f3a6..3c22edd524 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -46,6 +46,8 @@ class TaskInfo( var failed = false + var serializedSize: Int = 0 + def markGettingResult(time: Long = System.currentTimeMillis) { gettingResultTime = time } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 4c5eca8537..8884ea85a3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -377,6 +377,7 @@ private[spark] class ClusterTaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) + info.serializedSize = serializedTask.limit if (taskAttempts(index).size == 1) taskStarted(task,info) return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) -- cgit v1.2.3 From 18def5d6f20b33c946f9b8b2cea8cfb6848dcc34 Mon Sep 17 00:00:00 2001 From: "Lian, Cheng" Date: Thu, 28 Nov 2013 17:46:06 +0800 Subject: Bugfix: SPARK-965 & SPARK-966 SPARK-965: https://spark-project.atlassian.net/browse/SPARK-965 SPARK-966: https://spark-project.atlassian.net/browse/SPARK-966 * Add back DAGScheduler.start(), eventProcessActor is created and started here. Notice that function is only called by SparkContext. * Cancel the scheduled stage resubmission task when stopping eventProcessActor * Add a new DAGSchedulerEvent ResubmitFailedStages This event message is sent by the scheduled stage resubmission task to eventProcessActor. In this way, DAGScheduler.resubmitFailedStages is guaranteed to be executed from the same thread that runs DAGScheduler.processEvent. Please refer to discussion in SPARK-966 for details. --- .../main/scala/org/apache/spark/SparkContext.scala | 1 + .../org/apache/spark/scheduler/DAGScheduler.scala | 62 +++++++++++++--------- .../apache/spark/scheduler/DAGSchedulerEvent.scala | 2 + 3 files changed, 40 insertions(+), 25 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3a80241daa..c314f01894 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -270,6 +270,7 @@ class SparkContext( taskScheduler.start() @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) + dagScheduler.start() ui.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4457525ac8..e2bf08c33f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -113,30 +113,7 @@ class DAGScheduler( // Warns the user if a stage contains a task with size greater than this value (in KB) val TASK_SIZE_TO_WARN = 100 - private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor { - override def preStart() { - context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) { - if (failed.size > 0) { - resubmitFailedStages() - } - } - } - - /** - * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure - * events and responds by launching tasks. This runs in a dedicated thread and receives events - * via the eventQueue. - */ - def receive = { - case event: DAGSchedulerEvent => - logDebug("Got event of type " + event.getClass.getName) - - if (!processEvent(event)) - submitWaitingStages() - else - context.stop(self) - } - })) + private var eventProcessActor: ActorRef = _ private[scheduler] val nextJobId = new AtomicInteger(0) @@ -177,6 +154,34 @@ class DAGScheduler( val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup) + def start() { + eventProcessActor = env.actorSystem.actorOf(Props(new Actor { + var resubmissionTask: Cancellable = _ + + override def preStart() { + resubmissionTask = context.system.scheduler.schedule( + RESUBMIT_TIMEOUT.millis, RESUBMIT_TIMEOUT.millis, self, ResubmitFailedStages) + } + + /** + * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure + * events and responds by launching tasks. This runs in a dedicated thread and receives events + * via the eventQueue. + */ + def receive = { + case event: DAGSchedulerEvent => + logDebug("Got event of type " + event.getClass.getName) + + if (!processEvent(event)) { + submitWaitingStages() + } else { + resubmissionTask.cancel() + context.stop(self) + } + } + })) + } + def addSparkListener(listener: SparkListener) { listenerBus.addListener(listener) } @@ -457,6 +462,11 @@ class DAGScheduler( case TaskSetFailed(taskSet, reason) => abortStage(stageIdToStage(taskSet.stageId), reason) + case ResubmitFailedStages => + if (failed.size > 0) { + resubmitFailedStages() + } + case StopDAGScheduler => // Cancel any active jobs for (job <- activeJobs) { @@ -900,7 +910,9 @@ class DAGScheduler( } def stop() { - eventProcessActor ! StopDAGScheduler + if (eventProcessActor != null) { + eventProcessActor ! StopDAGScheduler + } metadataCleaner.cancel() taskSched.stop() } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 708d221d60..5353cd24dc 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -73,4 +73,6 @@ private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerE private[scheduler] case class TaskSetFailed(taskSet: TaskSet, reason: String) extends DAGSchedulerEvent +private[scheduler] case object ResubmitFailedStages extends DAGSchedulerEvent + private[scheduler] case object StopDAGScheduler extends DAGSchedulerEvent -- cgit v1.2.3 From 37f161cf6b19eb5b70a251340df0caf21afed84a Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 28 Nov 2013 20:36:18 -0800 Subject: Re-enable zk:// urls for Mesos SparkContexts This was broken in PR #71 when we explicitly disallow anything that didn't fit a mesos:// url. Although it is not really clear that a zk:// url should match Mesos, it is what the docs say and it is necessary for backwards compatibility. --- core/src/main/scala/org/apache/spark/SparkContext.scala | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 3a80241daa..cf1fd497f0 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -162,8 +162,8 @@ class SparkContext( val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r // Regular expression for connecting to Spark deploy clusters val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster - val MESOS_REGEX = """mesos://(.*)""".r + // Regular expression for connection to Mesos cluster by mesos:// or zk:// url + val MESOS_REGEX = """(mesos|zk)://.*""".r // Regular expression for connection to Simr cluster val SIMR_REGEX = """simr://(.*)""".r @@ -251,14 +251,15 @@ class SparkContext( scheduler.initialize(backend) scheduler - case MESOS_REGEX(mesosUrl) => + case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() val scheduler = new ClusterScheduler(this) val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean + val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, mesosUrl, appName) + new CoarseMesosSchedulerBackend(scheduler, this, url, appName) } else { - new MesosSchedulerBackend(scheduler, this, mesosUrl, appName) + new MesosSchedulerBackend(scheduler, this, url, appName) } scheduler.initialize(backend) scheduler -- cgit v1.2.3 From 081a0b6861321d262a82166bc1df61959e9c6387 Mon Sep 17 00:00:00 2001 From: Aaron Davidson Date: Thu, 28 Nov 2013 20:39:10 -0800 Subject: Add unit test for SparkContext scheduler creation Since YARN and Mesos are not necessarily available in the system, they are allowed to pass as long as the YARN/Mesos code paths are exercised. --- .../main/scala/org/apache/spark/SparkContext.scala | 234 +++++++++++---------- .../spark/scheduler/local/LocalScheduler.scala | 2 +- .../spark/SparkContextSchedulerCreationSuite.scala | 135 ++++++++++++ 3 files changed, 255 insertions(+), 116 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index cf1fd497f0..1eb00e79e1 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -153,121 +153,7 @@ class SparkContext( executorEnvs("SPARK_USER") = sparkUser // Create and start the scheduler - private[spark] var taskScheduler: TaskScheduler = { - // Regular expression used for local[N] master format - val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r - // Regular expression for local[N, maxRetries], used in tests with failing tasks - val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r - // Regular expression for simulating a Spark cluster of [N, cores, memory] locally - val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r - // Regular expression for connecting to Spark deploy clusters - val SPARK_REGEX = """spark://(.*)""".r - // Regular expression for connection to Mesos cluster by mesos:// or zk:// url - val MESOS_REGEX = """(mesos|zk)://.*""".r - // Regular expression for connection to Simr cluster - val SIMR_REGEX = """simr://(.*)""".r - - master match { - case "local" => - new LocalScheduler(1, 0, this) - - case LOCAL_N_REGEX(threads) => - new LocalScheduler(threads.toInt, 0, this) - - case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => - new LocalScheduler(threads.toInt, maxFailures.toInt, this) - - case SPARK_REGEX(sparkUrl) => - val scheduler = new ClusterScheduler(this) - val masterUrls = sparkUrl.split(",").map("spark://" + _) - val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) - scheduler.initialize(backend) - scheduler - - case SIMR_REGEX(simrUrl) => - val scheduler = new ClusterScheduler(this) - val backend = new SimrSchedulerBackend(scheduler, this, simrUrl) - scheduler.initialize(backend) - scheduler - - case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => - // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. - val memoryPerSlaveInt = memoryPerSlave.toInt - if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { - throw new SparkException( - "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( - memoryPerSlaveInt, SparkContext.executorMemoryRequested)) - } - - val scheduler = new ClusterScheduler(this) - val localCluster = new LocalSparkCluster( - numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) - val masterUrls = localCluster.start() - val backend = new SparkDeploySchedulerBackend(scheduler, this, masterUrls, appName) - scheduler.initialize(backend) - backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { - localCluster.stop() - } - scheduler - - case "yarn-standalone" => - val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") - val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(this).asInstanceOf[ClusterScheduler] - } catch { - // TODO: Enumerate the exact reasons why it can fail - // But irrespective of it, it means we cannot proceed ! - case th: Throwable => { - throw new SparkException("YARN mode not available ?", th) - } - } - val backend = new CoarseGrainedSchedulerBackend(scheduler, this.env.actorSystem) - scheduler.initialize(backend) - scheduler - - case "yarn-client" => - val scheduler = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") - val cons = clazz.getConstructor(classOf[SparkContext]) - cons.newInstance(this).asInstanceOf[ClusterScheduler] - - } catch { - case th: Throwable => { - throw new SparkException("YARN mode not available ?", th) - } - } - - val backend = try { - val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") - val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext]) - cons.newInstance(scheduler, this).asInstanceOf[CoarseGrainedSchedulerBackend] - } catch { - case th: Throwable => { - throw new SparkException("YARN mode not available ?", th) - } - } - - scheduler.initialize(backend) - scheduler - - case mesosUrl @ MESOS_REGEX(_) => - MesosNativeLibrary.load() - val scheduler = new ClusterScheduler(this) - val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean - val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs - val backend = if (coarseGrained) { - new CoarseMesosSchedulerBackend(scheduler, this, url, appName) - } else { - new MesosSchedulerBackend(scheduler, this, url, appName) - } - scheduler.initialize(backend) - scheduler - - case _ => - throw new SparkException("Could not parse Master URL: '" + master + "'") - } - } + private[spark] var taskScheduler = SparkContext.createTaskScheduler(this, master, appName) taskScheduler.start() @volatile private[spark] var dagScheduler = new DAGScheduler(taskScheduler) @@ -1137,6 +1023,124 @@ object SparkContext { .map(Utils.memoryStringToMb) .getOrElse(512) } + + // Creates a task scheduler based on a given master URL. Extracted for testing. + private + def createTaskScheduler(sc: SparkContext, master: String, appName: String): TaskScheduler = { + // Regular expression used for local[N] master format + val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r + // Regular expression for local[N, maxRetries], used in tests with failing tasks + val LOCAL_N_FAILURES_REGEX = """local\[([0-9]+)\s*,\s*([0-9]+)\]""".r + // Regular expression for simulating a Spark cluster of [N, cores, memory] locally + val LOCAL_CLUSTER_REGEX = """local-cluster\[\s*([0-9]+)\s*,\s*([0-9]+)\s*,\s*([0-9]+)\s*]""".r + // Regular expression for connecting to Spark deploy clusters + val SPARK_REGEX = """spark://(.*)""".r + // Regular expression for connection to Mesos cluster by mesos:// or zk:// url + val MESOS_REGEX = """(mesos|zk)://.*""".r + // Regular expression for connection to Simr cluster + val SIMR_REGEX = """simr://(.*)""".r + + master match { + case "local" => + new LocalScheduler(1, 0, sc) + + case LOCAL_N_REGEX(threads) => + new LocalScheduler(threads.toInt, 0, sc) + + case LOCAL_N_FAILURES_REGEX(threads, maxFailures) => + new LocalScheduler(threads.toInt, maxFailures.toInt, sc) + + case SPARK_REGEX(sparkUrl) => + val scheduler = new ClusterScheduler(sc) + val masterUrls = sparkUrl.split(",").map("spark://" + _) + val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) + scheduler.initialize(backend) + scheduler + + case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => + // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. + val memoryPerSlaveInt = memoryPerSlave.toInt + if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { + throw new SparkException( + "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( + memoryPerSlaveInt, SparkContext.executorMemoryRequested)) + } + + val scheduler = new ClusterScheduler(sc) + val localCluster = new LocalSparkCluster( + numSlaves.toInt, coresPerSlave.toInt, memoryPerSlaveInt) + val masterUrls = localCluster.start() + val backend = new SparkDeploySchedulerBackend(scheduler, sc, masterUrls, appName) + scheduler.initialize(backend) + backend.shutdownCallback = (backend: SparkDeploySchedulerBackend) => { + localCluster.stop() + } + scheduler + + case "yarn-standalone" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(sc).asInstanceOf[ClusterScheduler] + } catch { + // TODO: Enumerate the exact reasons why it can fail + // But irrespective of it, it means we cannot proceed ! + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + val backend = new CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + scheduler.initialize(backend) + scheduler + + case "yarn-client" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(sc).asInstanceOf[ClusterScheduler] + + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + val backend = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext]) + cons.newInstance(scheduler, sc).asInstanceOf[CoarseGrainedSchedulerBackend] + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + scheduler.initialize(backend) + scheduler + + case mesosUrl @ MESOS_REGEX(_) => + MesosNativeLibrary.load() + val scheduler = new ClusterScheduler(sc) + val coarseGrained = System.getProperty("spark.mesos.coarse", "false").toBoolean + val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs + val backend = if (coarseGrained) { + new CoarseMesosSchedulerBackend(scheduler, sc, url, appName) + } else { + new MesosSchedulerBackend(scheduler, sc, url, appName) + } + scheduler.initialize(backend) + scheduler + + case SIMR_REGEX(simrUrl) => + val scheduler = new ClusterScheduler(sc) + val backend = new SimrSchedulerBackend(scheduler, sc, simrUrl) + scheduler.initialize(backend) + scheduler + + case _ => + throw new SparkException("Could not parse Master URL: '" + master + "'") + } + } } /** diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 2699f0b33e..5af51164f7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -74,7 +74,7 @@ class LocalActor(localScheduler: LocalScheduler, private var freeCores: Int) } } -private[spark] class LocalScheduler(threads: Int, val maxFailures: Int, val sc: SparkContext) +private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val sc: SparkContext) extends TaskScheduler with ExecutorBackend with Logging { diff --git a/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala new file mode 100644 index 0000000000..61d6163659 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/SparkContextSchedulerCreationSuite.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark + +import org.scalatest.{FunSuite, PrivateMethodTester} + +import org.apache.spark.scheduler.TaskScheduler +import org.apache.spark.scheduler.cluster.{ClusterScheduler, SimrSchedulerBackend, SparkDeploySchedulerBackend} +import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend} +import org.apache.spark.scheduler.local.LocalScheduler + +class SparkContextSchedulerCreationSuite + extends FunSuite with PrivateMethodTester with LocalSparkContext with Logging { + + def createTaskScheduler(master: String): TaskScheduler = { + // Create local SparkContext to setup a SparkEnv. We don't actually want to start() the + // real schedulers, so we don't want to create a full SparkContext with the desired scheduler. + sc = new SparkContext("local", "test") + val createTaskSchedulerMethod = PrivateMethod[TaskScheduler]('createTaskScheduler) + SparkContext invokePrivate createTaskSchedulerMethod(sc, master, "test") + } + + test("bad-master") { + val e = intercept[SparkException] { + createTaskScheduler("localhost:1234") + } + assert(e.getMessage.contains("Could not parse Master URL")) + } + + test("local") { + createTaskScheduler("local") match { + case s: LocalScheduler => + assert(s.threads === 1) + assert(s.maxFailures === 0) + case _ => fail() + } + } + + test("local-n") { + createTaskScheduler("local[5]") match { + case s: LocalScheduler => + assert(s.threads === 5) + assert(s.maxFailures === 0) + case _ => fail() + } + } + + test("local-n-failures") { + createTaskScheduler("local[4, 2]") match { + case s: LocalScheduler => + assert(s.threads === 4) + assert(s.maxFailures === 2) + case _ => fail() + } + } + + test("simr") { + createTaskScheduler("simr://uri") match { + case s: ClusterScheduler => + assert(s.backend.isInstanceOf[SimrSchedulerBackend]) + case _ => fail() + } + } + + test("local-cluster") { + createTaskScheduler("local-cluster[3, 14, 512]") match { + case s: ClusterScheduler => + assert(s.backend.isInstanceOf[SparkDeploySchedulerBackend]) + case _ => fail() + } + } + + def testYarn(master: String, expectedClassName: String) { + try { + createTaskScheduler(master) match { + case s: ClusterScheduler => + assert(s.getClass === Class.forName(expectedClassName)) + case _ => fail() + } + } catch { + case e: SparkException => + assert(e.getMessage.contains("YARN mode not available")) + logWarning("YARN not available, could not test actual YARN scheduler creation") + case e: Throwable => fail(e) + } + } + test("yarn-standalone") { + testYarn("yarn-standalone", "org.apache.spark.scheduler.cluster.YarnClusterScheduler") + } + test("yarn-client") { + testYarn("yarn-client", "org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + } + + def testMesos(master: String, expectedClass: Class[_]) { + try { + createTaskScheduler(master) match { + case s: ClusterScheduler => + assert(s.backend.getClass === expectedClass) + case _ => fail() + } + } catch { + case e: UnsatisfiedLinkError => + assert(e.getMessage.contains("no mesos in")) + logWarning("Mesos not available, could not test actual Mesos scheduler creation") + case e: Throwable => fail(e) + } + } + test("mesos fine-grained") { + System.setProperty("spark.mesos.coarse", "false") + testMesos("mesos://localhost:1234", classOf[MesosSchedulerBackend]) + } + test("mesos coarse-grained") { + System.setProperty("spark.mesos.coarse", "true") + testMesos("mesos://localhost:1234", classOf[CoarseMesosSchedulerBackend]) + } + test("mesos with zookeeper") { + System.setProperty("spark.mesos.coarse", "false") + testMesos("zk://localhost:1234,localhost:2345", classOf[MesosSchedulerBackend]) + } +} -- cgit v1.2.3 From 1e25086009ff6421790609e406d00e1b978d6dbe Mon Sep 17 00:00:00 2001 From: "Lian, Cheng" Date: Fri, 29 Nov 2013 15:56:47 +0800 Subject: Updated some inline comments in DAGScheduler --- .../org/apache/spark/scheduler/DAGScheduler.scala | 31 ++++++++++++++++++---- 1 file changed, 26 insertions(+), 5 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index e2bf08c33f..08cf76325b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -154,24 +154,43 @@ class DAGScheduler( val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup) + /** + * Starts the event processing actor. The actor has two responsibilities: + * + * 1. Waits for events like job submission, task finished, task failure etc., and calls + * [[org.apache.spark.scheduler.DAGScheduler.processEvent()]] to process them. + * 2. Schedules a periodical task to resubmit failed stages. + * + * NOTE: the actor cannot be started in the constructor, because the periodical task references + * some internal states of the enclosing [[org.apache.spark.scheduler.DAGScheduler]] object, thus + * cannot be scheduled until the [[org.apache.spark.scheduler.DAGScheduler]] is fully constructed. + */ def start() { eventProcessActor = env.actorSystem.actorOf(Props(new Actor { var resubmissionTask: Cancellable = _ override def preStart() { + /** + * A message is sent to the actor itself periodically to remind the actor to resubmit failed + * stages. In this way, stage resubmission can be done within the same thread context of + * other event processing logic to avoid unnecessary synchronization overhead. + */ resubmissionTask = context.system.scheduler.schedule( RESUBMIT_TIMEOUT.millis, RESUBMIT_TIMEOUT.millis, self, ResubmitFailedStages) } /** - * The main event loop of the DAG scheduler, which waits for new-job / task-finished / failure - * events and responds by launching tasks. This runs in a dedicated thread and receives events - * via the eventQueue. + * The main event loop of the DAG scheduler. */ def receive = { case event: DAGSchedulerEvent => logDebug("Got event of type " + event.getClass.getName) + /** + * All events are forwarded to `processEvent()`, so that the event processing logic can + * easily tested without starting a dedicated actor. Please refer to `DAGSchedulerSuite` + * for details. + */ if (!processEvent(event)) { submitWaitingStages() } else { @@ -383,8 +402,10 @@ class DAGScheduler( } /** - * Process one event retrieved from the event queue. - * Returns true if we should stop the event loop. + * Process one event retrieved from the event processing actor. + * + * @param event The event to be processed. + * @return `true` if we should stop the event loop. */ private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { -- cgit v1.2.3 From 4a1d966e26e56fc5d42a828f414b4eca433c3a22 Mon Sep 17 00:00:00 2001 From: "Lian, Cheng" Date: Fri, 29 Nov 2013 16:02:58 +0800 Subject: More comments --- core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 3 +++ 1 file changed, 3 insertions(+) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 08cf76325b..bc37a70e98 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -167,6 +167,9 @@ class DAGScheduler( */ def start() { eventProcessActor = env.actorSystem.actorOf(Props(new Actor { + /** + * A handle to the periodical task, used to cancel the task when the actor is stopped. + */ var resubmissionTask: Cancellable = _ override def preStart() { -- cgit v1.2.3 From 4d53830eb79174cfd9641f6342727bc980d5c3e0 Mon Sep 17 00:00:00 2001 From: Sundeep Narravula Date: Sat, 30 Nov 2013 16:18:12 -0800 Subject: Scheduler quits when createStage fails. The current scheduler thread does not handle exceptions from createStage stage while launching new jobs. The thread fails on any exception that gets triggered at that level, leaving the cluster hanging with no schduler. --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 4457525ac8..f6a4482679 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -384,7 +384,15 @@ class DAGScheduler( private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - val finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) + var finalStage:Stage = null + try { + finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) + } catch { + case e: Exception => + logWarning("Creating new stage failed due to exception - job: " + jobId ) + listener.jobFailed(e) + return false + } val job = new ActiveJob(jobId, finalStage, func, partitions, callSite, listener, properties) clearCacheLocs() logInfo("Got job " + job.jobId + " (" + callSite + ") with " + partitions.length + -- cgit v1.2.3 From 9cf7f31e4d4e542b88b6a474bdf08d07fdd3652c Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sat, 30 Nov 2013 18:07:36 -0800 Subject: Memoize preferred locations in ZippedPartitionsBaseRDD so preferred location computation doesn't lead to exponential explosion. (cherry picked from commit e36fe55a031d2c01c9d7c5d85965951c681a0c74) Signed-off-by: Reynold Xin --- .../org/apache/spark/rdd/ZippedPartitionsRDD.scala | 27 +++++++++------------- 1 file changed, 11 insertions(+), 16 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala index faeb316664..a97d2a01c8 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala @@ -22,7 +22,8 @@ import java.io.{ObjectOutputStream, IOException} private[spark] class ZippedPartitionsPartition( idx: Int, - @transient rdds: Seq[RDD[_]]) + @transient rdds: Seq[RDD[_]], + @transient val preferredLocations: Seq[String]) extends Partition { override val index: Int = idx @@ -47,27 +48,21 @@ abstract class ZippedPartitionsBaseRDD[V: ClassManifest]( if (preservesPartitioning) firstParent[Any].partitioner else None override def getPartitions: Array[Partition] = { - val sizes = rdds.map(x => x.partitions.size) - if (!sizes.forall(x => x == sizes(0))) { + val numParts = rdds.head.partitions.size + if (!rdds.forall(rdd => rdd.partitions.size == numParts)) { throw new IllegalArgumentException("Can't zip RDDs with unequal numbers of partitions") } - val array = new Array[Partition](sizes(0)) - for (i <- 0 until sizes(0)) { - array(i) = new ZippedPartitionsPartition(i, rdds) + Array.tabulate[Partition](numParts) { i => + val prefs = rdds.map(rdd => rdd.preferredLocations(rdd.partitions(i))) + // Check whether there are any hosts that match all RDDs; otherwise return the union + val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) + val locs = if (!exactMatchLocations.isEmpty) exactMatchLocations else prefs.flatten.distinct + new ZippedPartitionsPartition(i, rdds, locs) } - array } override def getPreferredLocations(s: Partition): Seq[String] = { - val parts = s.asInstanceOf[ZippedPartitionsPartition].partitions - val prefs = rdds.zip(parts).map { case (rdd, p) => rdd.preferredLocations(p) } - // Check whether there are any hosts that match all RDDs; otherwise return the union - val exactMatchLocations = prefs.reduce((x, y) => x.intersect(y)) - if (!exactMatchLocations.isEmpty) { - exactMatchLocations - } else { - prefs.flatten.distinct - } + s.asInstanceOf[ZippedPartitionsPartition].preferredLocations } override def clearDependencies() { -- cgit v1.2.3 From be3ea2394fa2e626fb6b5f2cd46e7156016c9b3f Mon Sep 17 00:00:00 2001 From: Sundeep Narravula Date: Sun, 1 Dec 2013 00:50:34 -0800 Subject: Log exception in scheduler in addition to passing it to the caller. Code Styling changes. --- core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index f6a4482679..915918630b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -384,12 +384,14 @@ class DAGScheduler( private[scheduler] def processEvent(event: DAGSchedulerEvent): Boolean = { event match { case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => - var finalStage:Stage = null + var finalStage: Stage = null try { + // New stage creation at times and if its not protected, the scheduler thread is killed. + // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) } catch { case e: Exception => - logWarning("Creating new stage failed due to exception - job: " + jobId ) + logWarning("Creating new stage failed due to exception - job: " + jobId, e) listener.jobFailed(e) return false } -- cgit v1.2.3 From 58b3aff9a871a38446aacc2d60b65199d44e56bb Mon Sep 17 00:00:00 2001 From: Kay Ousterhout Date: Mon, 2 Dec 2013 20:30:03 -0800 Subject: Fixed problem with scheduler delay --- core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala | 11 +++++++---- 1 file changed, 7 insertions(+), 4 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index fc8c334cb5..8deb495068 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -120,11 +120,14 @@ private[spark] class StagePage(parent: JobProgressUI) { // machine and to send back the result (but not the time to fetch the task result, // if it needed to be fetched from the block manager on the worker). val schedulerDelays = validTasks.map{case (info, metrics, exception) => - if (info.gettingResultTime > 0) { - (info.gettingResultTime - info.launchTime).toDouble - } else { - (info.finishTime - info.launchTime).toDouble + val totalExecutionTime = { + if (info.gettingResultTime > 0) { + (info.gettingResultTime - info.launchTime).toDouble + } else { + (info.finishTime - info.launchTime).toDouble + } } + totalExecutionTime - metrics.get.executorRunTime } val schedulerDelayQuantiles = ("Scheduler delay" +: Distribution(schedulerDelays).get.getQuantiles().map( -- cgit v1.2.3 From e34b4693d380c39d4a142515e416588e63d06297 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 2 Dec 2013 21:24:44 -0800 Subject: Mark partitioner, name, and generator field in RDD as @transient. --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 5b1285307d..96e4841c78 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -101,7 +101,7 @@ abstract class RDD[T: ClassManifest]( protected def getPreferredLocations(split: Partition): Seq[String] = Nil /** Optionally overridden by subclasses to specify how they are partitioned. */ - val partitioner: Option[Partitioner] = None + @transient val partitioner: Option[Partitioner] = None // ======================================================================= // Methods and fields available on all RDDs @@ -114,7 +114,7 @@ abstract class RDD[T: ClassManifest]( val id: Int = sc.newRddId() /** A friendly name for this RDD */ - var name: String = null + @transient var name: String = null /** Assign a name to this RDD */ def setName(_name: String) = { @@ -123,7 +123,7 @@ abstract class RDD[T: ClassManifest]( } /** User-defined generator of this RDD*/ - var generator = Utils.getCallSiteInfo.firstUserClass + @transient var generator = Utils.getCallSiteInfo.firstUserClass /** Reset generator*/ def setGenerator(_generator: String) = { -- cgit v1.2.3 From 51458ab4a16a2d365f5de756d2fac942b766feca Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Mon, 11 Nov 2013 16:06:12 -0800 Subject: Added stageId <--> jobId mapping in DAGScheduler ...and make sure that DAGScheduler data structures are cleaned up on job completion. Initial effort and discussion at https://github.com/mesos/spark/pull/842 --- .../scala/org/apache/spark/MapOutputTracker.scala | 8 +- .../org/apache/spark/scheduler/DAGScheduler.scala | 277 ++++++++++++++++----- .../apache/spark/scheduler/DAGSchedulerEvent.scala | 5 +- .../org/apache/spark/scheduler/SparkListener.scala | 2 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 4 +- .../scheduler/cluster/ClusterTaskSetManager.scala | 2 +- .../spark/scheduler/local/LocalScheduler.scala | 27 +- .../org/apache/spark/JobCancellationSuite.scala | 4 +- .../apache/spark/scheduler/DAGSchedulerSuite.scala | 45 +++- 9 files changed, 286 insertions(+), 88 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 5e465fa22c..b4d0b7017c 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -244,12 +244,12 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker { case Some(bytes) => return bytes case None => - statuses = mapStatuses(shuffleId) + statuses = mapStatuses.getOrElse(shuffleId, Array[MapStatus]()) epochGotten = epoch } } // If we got here, we failed to find the serialized locations in the cache, so we pulled - // out a snapshot of the locations as "locs"; let's serialize and return that + // out a snapshot of the locations as "statuses"; let's serialize and return that val bytes = MapOutputTracker.serializeMapStatuses(statuses) logInfo("Size of output statuses for shuffle %d is %d bytes".format(shuffleId, bytes.length)) // Add them into the table only if the epoch hasn't changed while we were working @@ -274,6 +274,10 @@ private[spark] class MapOutputTrackerMaster extends MapOutputTracker { override def updateEpoch(newEpoch: Long) { // This might be called on the MapOutputTrackerMaster if we're running in local mode. } + + def has(shuffleId: Int): Boolean = { + cachedSerializedStatuses.get(shuffleId).isDefined || mapStatuses.contains(shuffleId) + } } private[spark] object MapOutputTracker { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index a785a16a36..10417b9343 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -121,9 +121,13 @@ class DAGScheduler( private val nextStageId = new AtomicInteger(0) - private val stageIdToStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val jobIdToStageIds = new TimeStampedHashMap[Int, HashSet[Int]] - private val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] + private[scheduler] val stageIdToJobIds = new TimeStampedHashMap[Int, HashSet[Int]] + + private[scheduler] val stageIdToStage = new TimeStampedHashMap[Int, Stage] + + private[scheduler] val shuffleToMapStage = new TimeStampedHashMap[Int, Stage] private[spark] val stageToInfos = new TimeStampedHashMap[Stage, StageInfo] @@ -232,7 +236,7 @@ class DAGScheduler( shuffleToMapStage.get(shuffleDep.shuffleId) match { case Some(stage) => stage case None => - val stage = newStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, Some(shuffleDep), jobId) + val stage = newOrUsedStage(shuffleDep.rdd, shuffleDep.rdd.partitions.size, shuffleDep, jobId) shuffleToMapStage(shuffleDep.shuffleId) = stage stage } @@ -241,7 +245,8 @@ class DAGScheduler( /** * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or * as a result stage for the final RDD used directly in an action. The stage will also be - * associated with the provided jobId. + * associated with the provided jobId.. Shuffle map stages, whose shuffleId may have previously + * been registered in the MapOutputTracker, should be (re)-created using newOrUsedStage. */ private def newStage( rdd: RDD[_], @@ -251,20 +256,44 @@ class DAGScheduler( callSite: Option[String] = None) : Stage = { - if (shuffleDep != None) { - // Kind of ugly: need to register RDDs with the cache and map output tracker here - // since we can't do it in the RDD constructor because # of partitions is unknown - logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") - mapOutputTracker.registerShuffle(shuffleDep.get.shuffleId, rdd.partitions.size) - } val id = nextStageId.getAndIncrement() val stage = new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) stageIdToStage(id) = stage + registerJobIdWithStages(jobId, stage) stageToInfos(stage) = new StageInfo(stage) stage } + /** + * Create a shuffle map Stage for the given RDD. The stage will also be associated with the + * provided jobId. If a stage for the shuffleId existed previously so that the shuffleId is + * present in the MapOutputTracker, then the number and location of available outputs are + * recovered from the MapOutputTracker + */ + private def newOrUsedStage( + rdd: RDD[_], + numTasks: Int, + shuffleDep: ShuffleDependency[_,_], + jobId: Int, + callSite: Option[String] = None) + : Stage = + { + val stage = newStage(rdd, numTasks, Some(shuffleDep), jobId, callSite) + if (mapOutputTracker.has(shuffleDep.shuffleId)) { + val serLocs = mapOutputTracker.getSerializedMapOutputStatuses(shuffleDep.shuffleId) + val locs = MapOutputTracker.deserializeMapStatuses(serLocs) + for (i <- 0 until locs.size) stage.outputLocs(i) = List(locs(i)) + stage.numAvailableOutputs = locs.size + } else { + // Kind of ugly: need to register RDDs with the cache and map output tracker here + // since we can't do it in the RDD constructor because # of partitions is unknown + logInfo("Registering RDD " + rdd.id + " (" + rdd.origin + ")") + mapOutputTracker.registerShuffle(shuffleDep.shuffleId, rdd.partitions.size) + } + stage + } + /** * Get or create the list of parent stages for a given RDD. The stages will be assigned the * provided jobId if they haven't already been created with a lower jobId. @@ -316,6 +345,91 @@ class DAGScheduler( missing.toList } + /** + * Registers the given jobId among the jobs that need the given stage and + * all of that stage's ancestors. + */ + private def registerJobIdWithStages(jobId: Int, stage: Stage) { + def registerJobIdWithStageList(stages: List[Stage]) { + if (!stages.isEmpty) { + val s = stages.head + stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId + val parents = getParentStages(s.rdd, jobId) + val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) + registerJobIdWithStageList(parentsWithoutThisJobId ++ stages.tail) + } + } + registerJobIdWithStageList(List(stage)) + } + + private def jobIdToStageIdsAdd(jobId: Int) { + val stageSet = jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) + stageIdToJobIds.foreach { case (stageId, jobSet) => + if (jobSet.contains(jobId)) { + stageSet += stageId + } + } + } + + // Removes job and applies p to any stages that aren't needed by any other jobs + private def forIndependentStagesOfRemovedJob(jobId: Int)(p: Int => Unit) { + val registeredStages = jobIdToStageIds(jobId) + if (registeredStages.isEmpty) { + logError("No stages registered for job " + jobId) + } else { + stageIdToJobIds.filterKeys(stageId => registeredStages.contains(stageId)).foreach { + case (stageId, jobSet) => + if (!jobSet.contains(jobId)) { + logError("Job %d not registered for stage %d even though that stage was registered for the job" + .format(jobId, stageId)) + } else { + jobSet -= jobId + if ((jobSet - jobId).isEmpty) { // no other job needs this stage + p(stageId) + } + } + } + } + } + + private def removeStage(stageId: Int) { + // data structures based on Stage + stageIdToStage.get(stageId).foreach { s => + if (running.contains(s)) { + logDebug("Removing running stage %d".format(stageId)) + running -= s + } + stageToInfos -= s + shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove(_)) + if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) { + logDebug("Removing pending status for stage %d".format(stageId)) + } + pendingTasks -= s + if (waiting.contains(s)) { + logDebug("Removing stage %d from waiting set.".format(stageId)) + waiting -= s + } + if (failed.contains(s)) { + logDebug("Removing stage %d from failed set.".format(stageId)) + failed -= s + } + } + // data structures based on StageId + stageIdToStage -= stageId + stageIdToJobIds -= stageId + + logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size)) + } + + private def jobIdToStageIdsRemove(jobId: Int) { + if (!jobIdToStageIds.contains(jobId)) { + logDebug("Trying to remove unregistered job " + jobId) + } else { + forIndependentStagesOfRemovedJob(jobId) { removeStage } + jobIdToStageIds -= jobId + } + } + /** * Submit a job to the job scheduler and get a JobWaiter object back. The JobWaiter object * can be used to block until the the job finishes executing or can be used to cancel the job. @@ -435,35 +549,33 @@ class DAGScheduler( // Compute very short actions like first() or take() with no parent stages locally. runLocally(job) } else { - listenerBus.post(SparkListenerJobStart(job, properties)) idToActiveJob(jobId) = job activeJobs += job resultStageToJob(finalStage) = job + jobIdToStageIdsAdd(jobId) + listenerBus.post(SparkListenerJobStart(job, jobIdToStageIds(jobId).toArray, properties)) submitStage(finalStage) } case JobCancelled(jobId) => - // Cancel a job: find all the running stages that are linked to this job, and cancel them. - running.filter(_.jobId == jobId).foreach { stage => - taskSched.cancelTasks(stage.id) - } + handleJobCancellation(jobId) + idToActiveJob.get(jobId).foreach(job => activeJobs -= job) + idToActiveJob -= jobId case JobGroupCancelled(groupId) => // Cancel all jobs belonging to this job group. // First finds all active jobs with this group id, and then kill stages for them. - val jobIds = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) - .map(_.jobId) - if (!jobIds.isEmpty) { - running.filter(stage => jobIds.contains(stage.jobId)).foreach { stage => - taskSched.cancelTasks(stage.id) - } - } + val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) + val jobIds = activeInGroup.map(_.jobId) + jobIds.foreach { handleJobCancellation } + activeJobs -- activeInGroup + idToActiveJob -- jobIds case AllJobsCancelled => // Cancel all running jobs. - running.foreach { stage => - taskSched.cancelTasks(stage.id) - } + running.map(_.jobId).foreach { handleJobCancellation } + activeJobs.clear() + idToActiveJob.clear() case ExecutorGained(execId, host) => handleExecutorGained(execId, host) @@ -493,8 +605,13 @@ class DAGScheduler( listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics)) handleTaskCompletion(completion) + case LocalJobCompleted(stage) => + stageIdToJobIds -= stage.id // clean up data structures that were populated for a local job, + stageIdToStage -= stage.id // but that won't get cleaned up via the normal paths through + stageToInfos -= stage // completion events or stage abort + case TaskSetFailed(taskSet, reason) => - abortStage(stageIdToStage(taskSet.stageId), reason) + stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) } case ResubmitFailedStages => if (failed.size > 0) { @@ -576,30 +693,52 @@ class DAGScheduler( } catch { case e: Exception => job.listener.jobFailed(e) + } finally { + eventQueue.put(LocalJobCompleted(job.finalStage)) + } + } + + /** Finds the earliest-created active job that needs the stage */ + // TODO: Probably should actually find among the active jobs that need this + // stage the one with the highest priority (highest-priority pool, earliest created). + // That should take care of at least part of the priority inversion problem with + // cross-job dependencies. + private def activeJobForStage(stage: Stage): Option[Int] = { + if (stageIdToJobIds.contains(stage.id)) { + val jobsThatUseStage: Array[Int] = stageIdToJobIds(stage.id).toArray.sorted + jobsThatUseStage.find(idToActiveJob.contains(_)) + } else { + None } } /** Submits stage, but first recursively submits any missing parents. */ private def submitStage(stage: Stage) { - logDebug("submitStage(" + stage + ")") - if (!waiting(stage) && !running(stage) && !failed(stage)) { - val missing = getMissingParentStages(stage).sortBy(_.id) - logDebug("missing: " + missing) - if (missing == Nil) { - logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") - submitMissingTasks(stage) - running += stage - } else { - for (parent <- missing) { - submitStage(parent) + val jobId = activeJobForStage(stage) + if (jobId.isDefined) { + logDebug("submitStage(" + stage + ")") + if (!waiting(stage) && !running(stage) && !failed(stage)) { + val missing = getMissingParentStages(stage).sortBy(_.id) + logDebug("missing: " + missing) + if (missing == Nil) { + logInfo("Submitting " + stage + " (" + stage.rdd + "), which has no missing parents") + submitMissingTasks(stage, jobId.get) + running += stage + } else { + for (parent <- missing) { + submitStage(parent) + } + waiting += stage } - waiting += stage } + } else { + abortStage(stage, "No active job for stage " + stage.id) } } + /** Called when stage's parents are available and we can now do its task. */ - private def submitMissingTasks(stage: Stage) { + private def submitMissingTasks(stage: Stage, jobId: Int) { logDebug("submitMissingTasks(" + stage + ")") // Get our pending tasks and remember them in our pendingTasks entry val myPending = pendingTasks.getOrElseUpdate(stage, new HashSet) @@ -620,7 +759,7 @@ class DAGScheduler( } } - val properties = if (idToActiveJob.contains(stage.jobId)) { + val properties = if (idToActiveJob.contains(jobId)) { idToActiveJob(stage.jobId).properties } else { //this stage will be assigned to "default" pool @@ -703,6 +842,7 @@ class DAGScheduler( resultStageToJob -= stage markStageAsFinished(stage) listenerBus.post(SparkListenerJobEnd(job, JobSucceeded)) + jobIdToStageIdsRemove(job.jobId) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -738,7 +878,7 @@ class DAGScheduler( changeEpoch = true) } clearCacheLocs() - if (stage.outputLocs.count(_ == Nil) != 0) { + if (stage.outputLocs.exists(_ == Nil)) { // Some tasks had failed; let's resubmit this stage // TODO: Lower-level scheduler should also deal with this logInfo("Resubmitting " + stage + " (" + stage.name + @@ -755,9 +895,12 @@ class DAGScheduler( } waiting --= newlyRunnable running ++= newlyRunnable - for (stage <- newlyRunnable.sortBy(_.id)) { + for { + stage <- newlyRunnable.sortBy(_.id) + jobId <- activeJobForStage(stage) + } { logInfo("Submitting " + stage + " (" + stage.rdd + "), which is now runnable") - submitMissingTasks(stage) + submitMissingTasks(stage, jobId) } } } @@ -841,11 +984,31 @@ class DAGScheduler( } } + private def handleJobCancellation(jobId: Int) { + if (!jobIdToStageIds.contains(jobId)) { + logDebug("Trying to cancel unregistered job " + jobId) + } else { + forIndependentStagesOfRemovedJob(jobId) { stageId => + taskSched.cancelTasks(stageId) + removeStage(stageId) + } + val error = new SparkException("Job %d cancelled".format(jobId)) + val job = idToActiveJob(jobId) + job.listener.jobFailed(error) + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage)))) + jobIdToStageIds -= jobId + } + } + /** * Aborts all jobs depending on a particular Stage. This is called in response to a task set * being canceled by the TaskScheduler. Use taskSetFailed() to inject this event from outside. */ private def abortStage(failedStage: Stage, reason: String) { + if (!stageIdToStage.contains(failedStage.id)) { + // Skip all the actions if the stage has been removed. + return + } val dependentStages = resultStageToJob.keys.filter(x => stageDependsOn(x, failedStage)).toSeq stageToInfos(failedStage).completionTime = Some(System.currentTimeMillis()) for (resultStage <- dependentStages) { @@ -853,6 +1016,7 @@ class DAGScheduler( val error = new SparkException("Job aborted: " + reason) job.listener.jobFailed(error) listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) + jobIdToStageIdsRemove(job.jobId) idToActiveJob -= resultStage.jobId activeJobs -= job resultStageToJob -= resultStage @@ -926,21 +1090,18 @@ class DAGScheduler( } private def cleanup(cleanupTime: Long) { - var sizeBefore = stageIdToStage.size - stageIdToStage.clearOldValues(cleanupTime) - logInfo("stageIdToStage " + sizeBefore + " --> " + stageIdToStage.size) - - sizeBefore = shuffleToMapStage.size - shuffleToMapStage.clearOldValues(cleanupTime) - logInfo("shuffleToMapStage " + sizeBefore + " --> " + shuffleToMapStage.size) - - sizeBefore = pendingTasks.size - pendingTasks.clearOldValues(cleanupTime) - logInfo("pendingTasks " + sizeBefore + " --> " + pendingTasks.size) - - sizeBefore = stageToInfos.size - stageToInfos.clearOldValues(cleanupTime) - logInfo("stageToInfos " + sizeBefore + " --> " + stageToInfos.size) + Map( + "stageIdToStage" -> stageIdToStage, + "shuffleToMapStage" -> shuffleToMapStage, + "pendingTasks" -> pendingTasks, + "stageToInfos" -> stageToInfos, + "jobIdToStageIds" -> jobIdToStageIds, + "stageIdToJobIds" -> stageIdToJobIds). + foreach { case(s, t) => { + val sizeBefore = t.size + t.clearOldValues(cleanupTime) + logInfo("%s %d --> %d".format(s, sizeBefore, t.size)) + }} } def stop() { diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index 5353cd24dc..bf8dfb5ac7 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -65,8 +65,9 @@ private[scheduler] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent -private[scheduler] -case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent +private[scheduler] case class LocalJobCompleted(stage: Stage) extends DAGSchedulerEvent + +private[scheduler] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala index a35081f7b1..3841b5616d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala @@ -37,7 +37,7 @@ case class SparkListenerTaskGettingResult( case class SparkListenerTaskEnd(task: Task[_], reason: TaskEndReason, taskInfo: TaskInfo, taskMetrics: TaskMetrics) extends SparkListenerEvents -case class SparkListenerJobStart(job: ActiveJob, properties: Properties = null) +case class SparkListenerJobStart(job: ActiveJob, stageIds: Array[Int], properties: Properties = null) extends SparkListenerEvents case class SparkListenerJobEnd(job: ActiveJob, jobResult: JobResult) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index c1e65a3c48..bd0a39b4d2 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -173,7 +173,9 @@ private[spark] class ClusterScheduler(val sc: SparkContext) backend.killTask(tid, execId) } } - tsm.error("Stage %d was cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) + tsm.removeAllRunningTasks() + taskSetFinished(tsm) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 8884ea85a3..94961790df 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -574,7 +574,7 @@ private[spark] class ClusterTaskSetManager( runningTasks = runningTasksSet.size } - private def removeAllRunningTasks() { + private[cluster] def removeAllRunningTasks() { val numRunningTasks = runningTasksSet.size runningTasksSet.clear() if (parent != null) { diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 5af51164f7..01e95162c0 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -144,7 +144,8 @@ private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val localActor ! KillTask(tid) } } - tsm.error("Stage %d was cancelled".format(stageId)) + logInfo("Stage %d was cancelled".format(stageId)) + taskSetFinished(tsm) } } @@ -192,17 +193,19 @@ private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val synchronized { taskIdToTaskSetId.get(taskId) match { case Some(taskSetId) => - val taskSetManager = activeTaskSets(taskSetId) - taskSetTaskIds(taskSetId) -= taskId - - state match { - case TaskState.FINISHED => - taskSetManager.taskEnded(taskId, state, serializedData) - case TaskState.FAILED => - taskSetManager.taskFailed(taskId, state, serializedData) - case TaskState.KILLED => - taskSetManager.error("Task %d was killed".format(taskId)) - case _ => {} + val taskSetManager = activeTaskSets.get(taskSetId) + taskSetManager.foreach { tsm => + taskSetTaskIds(taskSetId) -= taskId + + state match { + case TaskState.FINISHED => + tsm.taskEnded(taskId, state, serializedData) + case TaskState.FAILED => + tsm.taskFailed(taskId, state, serializedData) + case TaskState.KILLED => + tsm.error("Task %d was killed".format(taskId)) + case _ => {} + } } case None => logInfo("Ignoring update from TID " + taskId + " because its task set is gone") diff --git a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala index d8a0e983b2..1121e06e2e 100644 --- a/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala +++ b/core/src/test/scala/org/apache/spark/JobCancellationSuite.scala @@ -114,7 +114,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf // Once A is cancelled, job B should finish fairly quickly. assert(jobB.get() === 100) } - +/* test("two jobs sharing the same stage") { // sem1: make sure cancel is issued after some tasks are launched // sem2: make sure the first stage is not finished until cancel is issued @@ -148,7 +148,7 @@ class JobCancellationSuite extends FunSuite with ShouldMatchers with BeforeAndAf intercept[SparkException] { f1.get() } intercept[SparkException] { f2.get() } } - + */ def testCount() { // Cancel before launching any tasks { diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index a4d41ebbff..8ce8c68af3 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -206,6 +206,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont submit(rdd, Array(0)) complete(taskSets(0), List((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("local job") { @@ -218,7 +219,10 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont } val jobId = scheduler.nextJobId.getAndIncrement() runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) + assert(scheduler.stageToInfos.size === 1) + runEvent(LocalJobCompleted(scheduler.stageToInfos.keys.head)) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("run trivial job w/ dependency") { @@ -227,6 +231,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont submit(finalRdd, Array(0)) complete(taskSets(0), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("cache location preferences w/ dependency") { @@ -239,12 +244,14 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assertLocations(taskSet, Seq(Seq("hostA", "hostB"))) complete(taskSet, Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("trivial job failure") { submit(makeRdd(1, Nil), Array(0)) failed(taskSets(0), "some failure") assert(failure.getMessage === "Job aborted: some failure") + assertDataStructuresEmpty } test("run trivial shuffle") { @@ -260,6 +267,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont Array(makeBlockManagerId("hostA"), makeBlockManagerId("hostB"))) complete(taskSets(1), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("run trivial shuffle with fetch failure") { @@ -285,6 +293,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1.host) === Array("hostA", "hostB")) complete(taskSets(3), Seq((Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty } test("ignore late map task completions") { @@ -313,6 +322,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont Array(makeBlockManagerId("hostB"), makeBlockManagerId("hostA"))) complete(taskSets(1), Seq((Success, 42), (Success, 43))) assert(results === Map(0 -> 42, 1 -> 43)) + assertDataStructuresEmpty } test("run trivial shuffle with out-of-band failure and retry") { @@ -329,15 +339,16 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(0), Seq( (Success, makeMapStatus("hostA", 1)), (Success, makeMapStatus("hostB", 1)))) - // have hostC complete the resubmitted task - complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) - assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === - Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) - complete(taskSets(2), Seq((Success, 42))) - assert(results === Map(0 -> 42)) - } - - test("recursive shuffle failures") { + // have hostC complete the resubmitted task + complete(taskSets(1), Seq((Success, makeMapStatus("hostC", 1)))) + assert(mapOutputTracker.getServerStatuses(shuffleId, 0).map(_._1) === + Array(makeBlockManagerId("hostC"), makeBlockManagerId("hostB"))) + complete(taskSets(2), Seq((Success, 42))) + assert(results === Map(0 -> 42)) + assertDataStructuresEmpty + } + + test("recursive shuffle failures") { val shuffleOneRdd = makeRdd(2, Nil) val shuffleDepOne = new ShuffleDependency(shuffleOneRdd, null) val shuffleTwoRdd = makeRdd(2, List(shuffleDepOne)) @@ -363,6 +374,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(4), Seq((Success, makeMapStatus("hostA", 1)))) complete(taskSets(5), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } test("cached post-shuffle") { @@ -394,6 +406,7 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont complete(taskSets(3), Seq((Success, makeMapStatus("hostD", 1)))) complete(taskSets(4), Seq((Success, 42))) assert(results === Map(0 -> 42)) + assertDataStructuresEmpty } /** @@ -413,4 +426,18 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont private def makeBlockManagerId(host: String): BlockManagerId = BlockManagerId("exec-" + host, host, 12345, 0) + private def assertDataStructuresEmpty = { + assert(scheduler.pendingTasks.isEmpty) + assert(scheduler.activeJobs.isEmpty) + assert(scheduler.failed.isEmpty) + assert(scheduler.idToActiveJob.isEmpty) + assert(scheduler.jobIdToStageIds.isEmpty) + assert(scheduler.stageIdToJobIds.isEmpty) + assert(scheduler.stageIdToStage.isEmpty) + assert(scheduler.stageToInfos.isEmpty) + assert(scheduler.resultStageToJob.isEmpty) + assert(scheduler.running.isEmpty) + assert(scheduler.shuffleToMapStage.isEmpty) + assert(scheduler.waiting.isEmpty) + } } -- cgit v1.2.3 From 6f8359b5ad6c069c6105631a6c74e225b866cfce Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Tue, 19 Nov 2013 10:16:48 -0800 Subject: Actor instead of eventQueue for LocalJobCompleted --- core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 10417b9343..ad436f854c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -694,7 +694,7 @@ class DAGScheduler( case e: Exception => job.listener.jobFailed(e) } finally { - eventQueue.put(LocalJobCompleted(job.finalStage)) + eventProcessActor ! LocalJobCompleted(job.finalStage) } } -- cgit v1.2.3 From 982797dcbafa4c1149ad354b0c5a07e3f74fe005 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Tue, 19 Nov 2013 16:59:42 -0800 Subject: Fixed intended side-effects --- core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index ad436f854c..bf5827d011 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -568,8 +568,8 @@ class DAGScheduler( val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach { handleJobCancellation } - activeJobs -- activeInGroup - idToActiveJob -- jobIds + activeJobs --= activeInGroup + idToActiveJob --= jobIds case AllJobsCancelled => // Cancel all running jobs. -- cgit v1.2.3 From 94087c463b41a92a9462b954f1f6452614569fe5 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Wed, 20 Nov 2013 15:47:30 -0800 Subject: Removed redundant residual re: reverted refactoring. --- core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index bf5827d011..be46f74f7c 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -384,7 +384,7 @@ class DAGScheduler( .format(jobId, stageId)) } else { jobSet -= jobId - if ((jobSet - jobId).isEmpty) { // no other job needs this stage + if (jobSet.isEmpty) { // no other job needs this stage p(stageId) } } -- cgit v1.2.3 From 205566e56e2891245b2d7820bfb3629945a2dcd9 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Wed, 20 Nov 2013 14:49:09 -0800 Subject: Improved comment --- core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index be46f74f7c..6f9d4d52a4 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -243,10 +243,9 @@ class DAGScheduler( } /** - * Create a Stage for the given RDD, either as a shuffle map stage (for a ShuffleDependency) or - * as a result stage for the final RDD used directly in an action. The stage will also be - * associated with the provided jobId.. Shuffle map stages, whose shuffleId may have previously - * been registered in the MapOutputTracker, should be (re)-created using newOrUsedStage. + * Create a Stage -- either directly for use as a result stage, or as part of the (re)-creation + * of a shuffle map stage in newOrUsedStage. The stage will be associated with the provided + * jobId. Production of shuffle map stages should always use newOrUsedStage, not newStage directly. */ private def newStage( rdd: RDD[_], -- cgit v1.2.3 From 686a420ddc33407050d9019711cbe801fc352fa3 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Fri, 22 Nov 2013 10:20:09 -0800 Subject: Refactoring to make job removal, stage removal, task cancellation clearer --- .../org/apache/spark/scheduler/DAGScheduler.scala | 76 +++++++++++----------- 1 file changed, 39 insertions(+), 37 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 6f9d4d52a4..b8b3ac0b43 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -370,9 +370,11 @@ class DAGScheduler( } } - // Removes job and applies p to any stages that aren't needed by any other jobs - private def forIndependentStagesOfRemovedJob(jobId: Int)(p: Int => Unit) { + // Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that + // were removed and whose associated tasks may need to be cancelled. + private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { val registeredStages = jobIdToStageIds(jobId) + val independentStages = new HashSet[Int]() if (registeredStages.isEmpty) { logError("No stages registered for job " + jobId) } else { @@ -382,49 +384,51 @@ class DAGScheduler( logError("Job %d not registered for stage %d even though that stage was registered for the job" .format(jobId, stageId)) } else { + def removeStage(stageId: Int) { + // data structures based on Stage + stageIdToStage.get(stageId).foreach { s => + if (running.contains(s)) { + logDebug("Removing running stage %d".format(stageId)) + running -= s + } + stageToInfos -= s + shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove) + if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) { + logDebug("Removing pending status for stage %d".format(stageId)) + } + pendingTasks -= s + if (waiting.contains(s)) { + logDebug("Removing stage %d from waiting set.".format(stageId)) + waiting -= s + } + if (failed.contains(s)) { + logDebug("Removing stage %d from failed set.".format(stageId)) + failed -= s + } + } + // data structures based on StageId + stageIdToStage -= stageId + stageIdToJobIds -= stageId + + logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size)) + } + jobSet -= jobId if (jobSet.isEmpty) { // no other job needs this stage - p(stageId) + independentStages += stageId + removeStage(stageId) } } } } - } - - private def removeStage(stageId: Int) { - // data structures based on Stage - stageIdToStage.get(stageId).foreach { s => - if (running.contains(s)) { - logDebug("Removing running stage %d".format(stageId)) - running -= s - } - stageToInfos -= s - shuffleToMapStage.keys.filter(shuffleToMapStage(_) == s).foreach(shuffleToMapStage.remove(_)) - if (pendingTasks.contains(s) && !pendingTasks(s).isEmpty) { - logDebug("Removing pending status for stage %d".format(stageId)) - } - pendingTasks -= s - if (waiting.contains(s)) { - logDebug("Removing stage %d from waiting set.".format(stageId)) - waiting -= s - } - if (failed.contains(s)) { - logDebug("Removing stage %d from failed set.".format(stageId)) - failed -= s - } - } - // data structures based on StageId - stageIdToStage -= stageId - stageIdToJobIds -= stageId - - logDebug("After removal of stage %d, remaining stages = %d".format(stageId, stageIdToStage.size)) + independentStages.toSet } private def jobIdToStageIdsRemove(jobId: Int) { if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to remove unregistered job " + jobId) } else { - forIndependentStagesOfRemovedJob(jobId) { removeStage } + removeJobAndIndependentStages(jobId) jobIdToStageIds -= jobId } } @@ -987,10 +991,8 @@ class DAGScheduler( if (!jobIdToStageIds.contains(jobId)) { logDebug("Trying to cancel unregistered job " + jobId) } else { - forIndependentStagesOfRemovedJob(jobId) { stageId => - taskSched.cancelTasks(stageId) - removeStage(stageId) - } + val independentStages = removeJobAndIndependentStages(jobId) + independentStages.foreach { taskSched.cancelTasks } val error = new SparkException("Job %d cancelled".format(jobId)) val job = idToActiveJob(jobId) job.listener.jobFailed(error) -- cgit v1.2.3 From 27c45e523620d801d547f167a5a33d71ee3af7b5 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Fri, 22 Nov 2013 11:14:39 -0800 Subject: Cleaned up job cancellation handling --- .../main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 12 +++++------- 1 file changed, 5 insertions(+), 7 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b8b3ac0b43..aeac14ad7b 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -371,7 +371,7 @@ class DAGScheduler( } // Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that - // were removed and whose associated tasks may need to be cancelled. + // were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation. private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { val registeredStages = jobIdToStageIds(jobId) val independentStages = new HashSet[Int]() @@ -562,8 +562,6 @@ class DAGScheduler( case JobCancelled(jobId) => handleJobCancellation(jobId) - idToActiveJob.get(jobId).foreach(job => activeJobs -= job) - idToActiveJob -= jobId case JobGroupCancelled(groupId) => // Cancel all jobs belonging to this job group. @@ -571,14 +569,12 @@ class DAGScheduler( val activeInGroup = activeJobs.filter(groupId == _.properties.get(SparkContext.SPARK_JOB_GROUP_ID)) val jobIds = activeInGroup.map(_.jobId) jobIds.foreach { handleJobCancellation } - activeJobs --= activeInGroup - idToActiveJob --= jobIds case AllJobsCancelled => // Cancel all running jobs. running.map(_.jobId).foreach { handleJobCancellation } - activeJobs.clear() - idToActiveJob.clear() + activeJobs.clear() // These should already be empty by this point, + idToActiveJob.clear() // but just in case we lost track of some jobs... case ExecutorGained(execId, host) => handleExecutorGained(execId, host) @@ -998,6 +994,8 @@ class DAGScheduler( job.listener.jobFailed(error) listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage)))) jobIdToStageIds -= jobId + activeJobs -= job + idToActiveJob -= jobId } } -- cgit v1.2.3 From 9ae2d094a967782e3f5a624dd854059a40430ee6 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Fri, 22 Nov 2013 13:14:26 -0800 Subject: Tightly couple stageIdToJobIds and jobIdToStageIds --- .../org/apache/spark/scheduler/DAGScheduler.scala | 29 +++++++++------------- 1 file changed, 12 insertions(+), 17 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index aeac14ad7b..01c5133e6e 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -259,7 +259,7 @@ class DAGScheduler( val stage = new Stage(id, rdd, numTasks, shuffleDep, getParentStages(rdd, jobId), jobId, callSite) stageIdToStage(id) = stage - registerJobIdWithStages(jobId, stage) + updateJobIdStageIdMaps(jobId, stage) stageToInfos(stage) = new StageInfo(stage) stage } @@ -348,30 +348,24 @@ class DAGScheduler( * Registers the given jobId among the jobs that need the given stage and * all of that stage's ancestors. */ - private def registerJobIdWithStages(jobId: Int, stage: Stage) { - def registerJobIdWithStageList(stages: List[Stage]) { + private def updateJobIdStageIdMaps(jobId: Int, stage: Stage) { + def updateJobIdStageIdMapsList(stages: List[Stage]) { if (!stages.isEmpty) { val s = stages.head stageIdToJobIds.getOrElseUpdate(s.id, new HashSet[Int]()) += jobId + jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) += s.id val parents = getParentStages(s.rdd, jobId) val parentsWithoutThisJobId = parents.filter(p => !stageIdToJobIds.get(p.id).exists(_.contains(jobId))) - registerJobIdWithStageList(parentsWithoutThisJobId ++ stages.tail) + updateJobIdStageIdMapsList(parentsWithoutThisJobId ++ stages.tail) } } - registerJobIdWithStageList(List(stage)) + updateJobIdStageIdMapsList(List(stage)) } - private def jobIdToStageIdsAdd(jobId: Int) { - val stageSet = jobIdToStageIds.getOrElseUpdate(jobId, new HashSet[Int]()) - stageIdToJobIds.foreach { case (stageId, jobSet) => - if (jobSet.contains(jobId)) { - stageSet += stageId - } - } - } - - // Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that - // were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation. + /** + * Removes job and any stages that are not needed by any other job. Returns the set of ids for stages that + * were removed. The associated tasks for those stages need to be cancelled if we got here via job cancellation. + */ private def removeJobAndIndependentStages(jobId: Int): Set[Int] = { val registeredStages = jobIdToStageIds(jobId) val independentStages = new HashSet[Int]() @@ -555,7 +549,6 @@ class DAGScheduler( idToActiveJob(jobId) = job activeJobs += job resultStageToJob(finalStage) = job - jobIdToStageIdsAdd(jobId) listenerBus.post(SparkListenerJobStart(job, jobIdToStageIds(jobId).toArray, properties)) submitStage(finalStage) } @@ -605,9 +598,11 @@ class DAGScheduler( handleTaskCompletion(completion) case LocalJobCompleted(stage) => + val jobId = stageIdToJobIds(stage.id).head stageIdToJobIds -= stage.id // clean up data structures that were populated for a local job, stageIdToStage -= stage.id // but that won't get cleaned up via the normal paths through stageToInfos -= stage // completion events or stage abort + jobIdToStageIds -= jobId case TaskSetFailed(taskSet, reason) => stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) } -- cgit v1.2.3 From c9fcd909d0f86b08935a132409888b30e989bca4 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Sun, 24 Nov 2013 17:49:14 -0800 Subject: Local jobs post SparkListenerJobEnd, and DAGScheduler data structure cleanup always occurs before any posting of SparkListenerJobEnd. --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 17 ++++++++++------- .../org/apache/spark/scheduler/DAGSchedulerEvent.scala | 2 +- 2 files changed, 11 insertions(+), 8 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 01c5133e6e..b371a2412f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -597,12 +597,13 @@ class DAGScheduler( listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics)) handleTaskCompletion(completion) - case LocalJobCompleted(stage) => - val jobId = stageIdToJobIds(stage.id).head + case LocalJobCompleted(job, result) => + val stage = job.finalStage stageIdToJobIds -= stage.id // clean up data structures that were populated for a local job, stageIdToStage -= stage.id // but that won't get cleaned up via the normal paths through stageToInfos -= stage // completion events or stage abort - jobIdToStageIds -= jobId + jobIdToStageIds -= job.jobId + listenerBus.post(SparkListenerJobEnd(job, result)) case TaskSetFailed(taskSet, reason) => stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) } @@ -672,6 +673,7 @@ class DAGScheduler( // Broken out for easier testing in DAGSchedulerSuite. protected def runLocallyWithinThread(job: ActiveJob) { + var jobResult: JobResult = JobSucceeded try { SparkEnv.set(env) val rdd = job.finalStage.rdd @@ -686,9 +688,10 @@ class DAGScheduler( } } catch { case e: Exception => + jobResult = JobFailed(e, Some(job.finalStage)) job.listener.jobFailed(e) } finally { - eventProcessActor ! LocalJobCompleted(job.finalStage) + eventProcessActor ! LocalJobCompleted(job, jobResult) } } @@ -835,8 +838,8 @@ class DAGScheduler( activeJobs -= job resultStageToJob -= stage markStageAsFinished(stage) - listenerBus.post(SparkListenerJobEnd(job, JobSucceeded)) jobIdToStageIdsRemove(job.jobId) + listenerBus.post(SparkListenerJobEnd(job, JobSucceeded)) } job.listener.taskSucceeded(rt.outputId, event.result) } @@ -987,10 +990,10 @@ class DAGScheduler( val error = new SparkException("Job %d cancelled".format(jobId)) val job = idToActiveJob(jobId) job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage)))) jobIdToStageIds -= jobId activeJobs -= job idToActiveJob -= jobId + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(job.finalStage)))) } } @@ -1009,11 +1012,11 @@ class DAGScheduler( val job = resultStageToJob(resultStage) val error = new SparkException("Job aborted: " + reason) job.listener.jobFailed(error) - listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) jobIdToStageIdsRemove(job.jobId) idToActiveJob -= resultStage.jobId activeJobs -= job resultStageToJob -= resultStage + listenerBus.post(SparkListenerJobEnd(job, JobFailed(error, Some(failedStage)))) } if (dependentStages.isEmpty) { logInfo("Ignoring failure of " + failedStage + " because all jobs depending on it are done") diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index bf8dfb5ac7..aa496b7ac6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -65,7 +65,7 @@ private[scheduler] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent -private[scheduler] case class LocalJobCompleted(stage: Stage) extends DAGSchedulerEvent +private[scheduler] case class LocalJobCompleted(job: ActiveJob, result: JobResult) extends DAGSchedulerEvent private[scheduler] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent -- cgit v1.2.3 From f55d0b935d7c148f49b15932938e91150b64466f Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Tue, 26 Nov 2013 14:06:59 -0800 Subject: Synchronous, inline cleanup after runLocally --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 15 ++++++--------- .../org/apache/spark/scheduler/DAGSchedulerEvent.scala | 2 -- .../org/apache/spark/scheduler/DAGSchedulerSuite.scala | 2 -- 3 files changed, 6 insertions(+), 13 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b371a2412f..b849867519 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -597,14 +597,6 @@ class DAGScheduler( listenerBus.post(SparkListenerTaskEnd(task, reason, taskInfo, taskMetrics)) handleTaskCompletion(completion) - case LocalJobCompleted(job, result) => - val stage = job.finalStage - stageIdToJobIds -= stage.id // clean up data structures that were populated for a local job, - stageIdToStage -= stage.id // but that won't get cleaned up via the normal paths through - stageToInfos -= stage // completion events or stage abort - jobIdToStageIds -= job.jobId - listenerBus.post(SparkListenerJobEnd(job, result)) - case TaskSetFailed(taskSet, reason) => stageIdToStage.get(taskSet.stageId).foreach { abortStage(_, reason) } @@ -691,7 +683,12 @@ class DAGScheduler( jobResult = JobFailed(e, Some(job.finalStage)) job.listener.jobFailed(e) } finally { - eventProcessActor ! LocalJobCompleted(job, jobResult) + val s = job.finalStage + stageIdToJobIds -= s.id // clean up data structures that were populated for a local job, + stageIdToStage -= s.id // but that won't get cleaned up via the normal paths through + stageToInfos -= s // completion events or stage abort + jobIdToStageIds -= job.jobId + listenerBus.post(SparkListenerJobEnd(job, jobResult)) } } diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala index aa496b7ac6..add1187613 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGSchedulerEvent.scala @@ -65,8 +65,6 @@ private[scheduler] case class CompletionEvent( taskMetrics: TaskMetrics) extends DAGSchedulerEvent -private[scheduler] case class LocalJobCompleted(job: ActiveJob, result: JobResult) extends DAGSchedulerEvent - private[scheduler] case class ExecutorGained(execId: String, host: String) extends DAGSchedulerEvent private[scheduler] case class ExecutorLost(execId: String) extends DAGSchedulerEvent diff --git a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala index 8ce8c68af3..706d84a58b 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/DAGSchedulerSuite.scala @@ -219,8 +219,6 @@ class DAGSchedulerSuite extends FunSuite with BeforeAndAfter with LocalSparkCont } val jobId = scheduler.nextJobId.getAndIncrement() runEvent(JobSubmitted(jobId, rdd, jobComputeFunc, Array(0), true, null, listener)) - assert(scheduler.stageToInfos.size === 1) - runEvent(LocalJobCompleted(scheduler.stageToInfos.keys.head)) assert(results === Map(0 -> 42)) assertDataStructuresEmpty } -- cgit v1.2.3 From 403234dd0d63a7e89f3304d7bb31e3675d405a13 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Tue, 26 Nov 2013 22:25:20 -0800 Subject: SparkListenerJobStart posted from local jobs --- core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala | 1 + 1 file changed, 1 insertion(+) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index b849867519..f9cd021dd3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -544,6 +544,7 @@ class DAGScheduler( logInfo("Missing parents: " + getMissingParentStages(finalStage)) if (allowLocal && finalStage.parents.size == 0 && partitions.length == 1) { // Compute very short actions like first() or take() with no parent stages locally. + listenerBus.post(SparkListenerJobStart(job, Array(), properties)) runLocally(job) } else { idToActiveJob(jobId) = job -- cgit v1.2.3 From 974a69d79c4bf64fd9f27b65c1c464d33e647e20 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Tue, 3 Dec 2013 11:34:38 -0800 Subject: Marked doCheckpointCalled as transient. --- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 96e4841c78..893708f8f2 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -940,7 +940,7 @@ abstract class RDD[T: ClassManifest]( private var storageLevel: StorageLevel = StorageLevel.NONE /** Record user function generating this RDD. */ - private[spark] val origin = Utils.formatSparkCallSite + @transient private[spark] val origin = Utils.formatSparkCallSite private[spark] def elementClassManifest: ClassManifest[T] = classManifest[T] @@ -955,7 +955,7 @@ abstract class RDD[T: ClassManifest]( def context = sc // Avoid handling doCheckpoint multiple times to prevent excessive recursion - private var doCheckpointCalled = false + @transient private var doCheckpointCalled = false /** * Performs the checkpointing of this RDD by saving this. It is called by the DAGScheduler -- cgit v1.2.3 From 217611680d09efcf6a218179081ee71c0a8d5c12 Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Wed, 4 Dec 2013 11:29:20 -0800 Subject: Add missing space after "Serialized" in StorageLevel Current code creates outputs like: scala> res0.getStorageLevel.description res2: String = Serialized1x Replicated --- core/src/main/scala/org/apache/spark/storage/StorageLevel.scala | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala index 632ff047d1..b5596dffd3 100644 --- a/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala +++ b/core/src/main/scala/org/apache/spark/storage/StorageLevel.scala @@ -101,7 +101,7 @@ class StorageLevel private( var result = "" result += (if (useDisk) "Disk " else "") result += (if (useMemory) "Memory " else "") - result += (if (deserialized) "Deserialized " else "Serialized") + result += (if (deserialized) "Deserialized " else "Serialized ") result += "%sx Replicated".format(replication) result } -- cgit v1.2.3 From 380b90b9b360db9cb6a4edc1312704afe11eb31d Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 4 Dec 2013 14:41:48 -0800 Subject: Fix small bug in web UI and minor clean-up. There was a bug where sorting order didn't work correctly for write time metrics. I also cleaned up some earlier code that fixed the same issue for read and write bytes. --- .../scala/org/apache/spark/ui/jobs/StagePage.scala | 29 ++++++++++------------ 1 file changed, 13 insertions(+), 16 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala index fbd822867f..baccc4281a 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala @@ -152,21 +152,18 @@ private[spark] class StagePage(parent: JobProgressUI) { else metrics.map(m => parent.formatDuration(m.executorRunTime)).getOrElse("") val gcTime = metrics.map(m => m.jvmGCTime).getOrElse(0L) - var shuffleReadSortable: String = "" - var shuffleReadReadable: String = "" - if (shuffleRead) { - shuffleReadSortable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead}.toString() - shuffleReadReadable = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => - Utils.bytesToString(s.remoteBytesRead)}.getOrElse("") - } + val maybeShuffleRead = metrics.flatMap{m => m.shuffleReadMetrics}.map{s => s.remoteBytesRead} + val shuffleReadSortable = maybeShuffleRead.map(_.toString).getOrElse("") + val shuffleReadReadable = maybeShuffleRead.map{Utils.bytesToString(_)}.getOrElse("") - var shuffleWriteSortable: String = "" - var shuffleWriteReadable: String = "" - if (shuffleWrite) { - shuffleWriteSortable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten}.toString() - shuffleWriteReadable = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - Utils.bytesToString(s.shuffleBytesWritten)}.getOrElse("") - } + val maybeShuffleWrite = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleBytesWritten} + val shuffleWriteSortable = maybeShuffleWrite.map(_.toString).getOrElse("") + val shuffleWriteReadable = maybeShuffleWrite.map{Utils.bytesToString(_)}.getOrElse("") + + val maybeWriteTime = metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => s.shuffleWriteTime} + val writeTimeSortable = maybeWriteTime.map(_.toString).getOrElse("") + val writeTimeReadable = maybeWriteTime.map{ t => t / (1000 * 1000)}.map{ ms => + if (ms == 0) "" else parent.formatDuration(ms)}.getOrElse("") {info.index} @@ -187,8 +184,8 @@ private[spark] class StagePage(parent: JobProgressUI) { }} {if (shuffleWrite) { - {metrics.flatMap{m => m.shuffleWriteMetrics}.map{s => - parent.formatDuration(s.shuffleWriteTime / (1000 * 1000))}.getOrElse("")} + + {writeTimeReadable} {shuffleWriteReadable} -- cgit v1.2.3 From b1c6fa1584099b3a1e0615c100f10ea90b1ad2c9 Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Wed, 4 Dec 2013 18:39:34 -0800 Subject: Document missing configs and set shuffle consolidation to false. --- .../apache/spark/storage/ShuffleBlockManager.scala | 2 +- docs/configuration.md | 37 +++++++++++++++++++++- 2 files changed, 37 insertions(+), 2 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 2f1b049ce4..e828e1d1c5 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -62,7 +62,7 @@ class ShuffleBlockManager(blockManager: BlockManager) { // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. // TODO: Remove this once the shuffle file consolidation feature is stable. val consolidateShuffleFiles = - System.getProperty("spark.shuffle.consolidateFiles", "true").toBoolean + System.getProperty("spark.shuffle.consolidateFiles", "false").toBoolean private val bufferSize = System.getProperty("spark.shuffle.file.buffer.kb", "100").toInt * 1024 diff --git a/docs/configuration.md b/docs/configuration.md index 97183bafdb..1a3eef345c 100644 --- a/docs/configuration.md +++ b/docs/configuration.md @@ -327,7 +327,42 @@ Apart from these, the following properties are also available, and may be useful Too large a value decreases parallelism during broadcast (makes it slower); however, if it is too small, BlockManager might take a performance hit. - + + spark.shuffle.consolidateFiles + false + + If set to "true", consolidates intermediate files created during a shuffle. + + + + + spark.speculation + false + + If set to "true", performs speculative execution of tasks. This means if one or more tasks are running slowly in a stage, they will be re-launched. + + + + spark.speculation.interval + 100 + + How often Spark will check for tasks to speculate, in seconds. + + + + spark.speculation.quantile + 0.75 + + Percentage of tasks which must be complete before speculation is enabled for a particular stage. + + + + spark.speculation.multiplier + 1.5 + + How many times slower a task is than the median to be considered for speculation. + + # Environment Variables -- cgit v1.2.3 From aebb123fd3b4bf0d57d867f33ca0325340ee42e4 Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Thu, 5 Dec 2013 17:16:44 -0800 Subject: jobWaiter.synchronized before jobWaiter.wait --- core/src/main/scala/org/apache/spark/FutureAction.scala | 2 +- core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala | 1 + 2 files changed, 2 insertions(+), 1 deletion(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/FutureAction.scala b/core/src/main/scala/org/apache/spark/FutureAction.scala index 1ad9240cfa..c6b4ac5192 100644 --- a/core/src/main/scala/org/apache/spark/FutureAction.scala +++ b/core/src/main/scala/org/apache/spark/FutureAction.scala @@ -99,7 +99,7 @@ class SimpleFutureAction[T] private[spark](jobWaiter: JobWaiter[_], resultFunc: override def ready(atMost: Duration)(implicit permit: CanAwait): SimpleFutureAction.this.type = { if (!atMost.isFinite()) { awaitResult() - } else { + } else jobWaiter.synchronized { val finishTime = System.currentTimeMillis() + atMost.toMillis while (!isCompleted) { val time = System.currentTimeMillis() diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala index 58f238d8cf..b026f860a8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobWaiter.scala @@ -31,6 +31,7 @@ private[spark] class JobWaiter[T]( private var finishedTasks = 0 // Is the job as a whole finished (succeeded or failed)? + @volatile private var _jobFinished = totalTasks == 0 def jobFinished = _jobFinished -- cgit v1.2.3 From 1cb259cb577bfd3385cca6bb187d7fee18bd2c24 Mon Sep 17 00:00:00 2001 From: Henry Saputra Date: Thu, 5 Dec 2013 18:50:26 -0800 Subject: Change the name of input ragument in ClusterScheduler#initialize from context to backend. The SchedulerBackend used to be called ClusterSchedulerContext so just want to make small change of the input param in the ClusterScheduler#initialize to reflect this. --- .../scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) (limited to 'core/src/main/scala/org') diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index c1e65a3c48..f475d000bd 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -100,8 +100,8 @@ private[spark] class ClusterScheduler(val sc: SparkContext) this.dagScheduler = dagScheduler } - def initialize(context: SchedulerBackend) { - backend = context + def initialize(backend: SchedulerBackend) { + this.backend = backend // temporarily set rootPool name to empty rootPool = new Pool("", schedulingMode, 0, 0) schedulableBuilder = { -- cgit v1.2.3