From 2a37235825cecd3f75286d11456c6e3cb13d4327 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Fri, 18 Oct 2013 00:07:49 -0700 Subject: Initial commit of adding histogram functionality to the DoubleRDDFunctions. --- .../org/apache/spark/api/java/JavaDoubleRDD.scala | 32 +++ .../org/apache/spark/rdd/DoubleRDDFunctions.scala | 134 ++++++++++++ .../org/apache/spark/rdd/DoubleRDDSuite.scala | 233 +++++++++++++++++++++ 3 files changed, 399 insertions(+) create mode 100644 core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 5fd1fab580..d2a2818e59 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -26,6 +26,8 @@ import org.apache.spark.storage.StorageLevel import java.lang.Double import org.apache.spark.Partitioner +import scala.collection.JavaConverters._ + class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, JavaDoubleRDD] { override val classManifest: ClassManifest[Double] = implicitly[ClassManifest[Double]] @@ -158,6 +160,36 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav /** (Experimental) Approximate operation to return the sum within a timeout. */ def sumApprox(timeout: Long): PartialResult[BoundedDouble] = srdd.sumApprox(timeout) + + /** + * Compute a histogram of the data using bucketCount number of buckets evenly + * spaced between the minimum and maximum of the RDD. For example if the min + * value is 0 and the max is 100 and there are two buckets the resulting + * buckets will be [0,50) [50,100]. bucketCount must be at least 1 + * If the RDD contains infinity, NaN throws an exception + * If the elements in RDD do not vary (max == min) throws an exception + */ + def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + val result = srdd.histogram(bucketCount) + (result._1.map(scala.Double.box(_)), result._2) + } + /** + * Compute a histogram using the provided buckets. The buckets are all open + * to the left except for the last which is closed + * e.g. for the array + * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * + * Note: if your histogram is evenly spaced (e.g. [0,10,20,30]) this switches + * from an O(log n) inseration to O(1) per element. (where n = # buckets) + * buckets must be sorted and not contain any duplicates. + * buckets array must be at least two elements + * All NaN entries are treated the same. + */ + def histogram(buckets: Array[Double]): Array[Long] = { + srdd.histogram(buckets.map(_.toDouble)) + } } object JavaDoubleRDD { diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index a4bec41752..776a83cefe 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -24,6 +24,8 @@ import org.apache.spark.partial.SumEvaluator import org.apache.spark.util.StatCounter import org.apache.spark.{TaskContext, Logging} +import scala.collection.immutable.NumericRange + /** * Extra functions available on RDDs of Doubles through an implicit conversion. * Import `org.apache.spark.SparkContext._` at the top of your program to use these functions. 
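To make the bucket semantics above concrete, here is a minimal driver-side sketch (assuming a live SparkContext named `sc` and the implicit conversion pulled in by `import org.apache.spark.SparkContext._`; it is illustrative only and not part of the patch):

    // Left-closed buckets [1,10) [10,20) [20,50]; the last bucket also
    // includes its upper boundary, so an exact match on 50.0 is counted there.
    import org.apache.spark.SparkContext._
    val data = sc.parallelize(Seq(1.0, 12.0, 50.0))
    val counts = data.histogram(Array(1.0, 10.0, 20.0, 50.0))
    // counts: Array[Long] = Array(1, 1, 1)

The same Scala method is what the JavaDoubleRDD wrapper above delegates to after converting its boxed bucket array to primitive doubles.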
@@ -76,4 +78,136 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { val evaluator = new SumEvaluator(self.partitions.size, confidence) self.context.runApproximateJob(self, processPartition, evaluator, timeout) } + + /** + * Compute a histogram of the data using bucketCount number of buckets evenly + * spaced between the minimum and maximum of the RDD. For example if the min + * value is 0 and the max is 100 and there are two buckets the resulting + * buckets will be [0,50) [50,100]. bucketCount must be at least 1 + * If the RDD contains infinity, NaN throws an exception + * If the elements in RDD do not vary (max == min) throws an exception + */ + def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + // Compute the minimum and the maxium + val (max: Double, min: Double) = self.mapPartitions { items => + Iterator(items.foldRight(-1/0.0, Double.NaN)((e: Double, x: Pair[Double, Double]) => + (x._1.max(e),x._2.min(e)))) + }.reduce { (maxmin1, maxmin2) => + (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) + } + if (max.isNaN() || max.isInfinity || min.isInfinity ) { + throw new UnsupportedOperationException("Histogram on either an empty RDD or RDD containing +-infinity or NaN") + } + if (max == min) { + throw new UnsupportedOperationException("Histogram with no range in elements") + } + val increment: Double = (max-min)/bucketCount.toDouble + val range = Range.Double.inclusive(min, max, increment) + val buckets: Array[Double] = range.toArray + (buckets,histogram(buckets)) + } + /** + * Compute a histogram using the provided buckets. The buckets are all open + * to the left except for the last which is closed + * e.g. for the array + * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] + * e.g 1<=x<10 , 10<=x<20, 20<=x<50 + * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * + * Note: if your histogram is evenly spaced (e.g. [0,10,20,30]) this switches + * from an O(log n) inseration to O(1) per element. (where n = # buckets) + * buckets must be sorted and not contain any duplicates. + * buckets array must be at least two elements + * All NaN entries are treated the same. + */ + def histogram(buckets: Array[Double]): Array[Long] = { + if (buckets.length < 2) { + throw new IllegalArgumentException("buckets array must have at least two elements") + } + // The histogramPartition function computes the partail histogram for a given + // partition. The provided bucketFunction determines which bucket in the array + // to increment or returns None if there is no bucket. This is done so we can + // specialize for uniformly distributed buckets and save the O(log n) binary + // search cost. + def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]): Iterator[Array[Long]] = { + val counters = new Array[Long](buckets.length-1) + while (iter.hasNext) { + bucketFunction(iter.next()) match { + case Some(x: Int) => {counters(x)+=1} + case _ => {} + } + } + Iterator(counters) + } + // Merge the counters. + def mergeCounters(a1: Array[Long], a2: Array[Long]): Array[Long] = { + a1.indices.foreach(i => a1(i) += a2(i)) + a1 + } + // Basic bucket function. This works using Java's built in Array + // binary search. 
Takes log(size(buckets)) + def basicBucketFunction(e: Double): Option[Int] = { + val location = java.util.Arrays.binarySearch(buckets, e) + if (location < 0) { + // If the location is less than 0 then the insertion point in the array + // to keep it sorted is -location-1 + val insertionPoint = -location-1 + // If we have to insert before the first element or after the last one + // its out of bounds. + // We do this rather than buckets.lengthCompare(insertionPoint) + // because Array[Double] fails to override it (for now). + if (insertionPoint > 0 && insertionPoint < buckets.length) { + Some(insertionPoint-1) + } else { + None + } + } else if (location < buckets.length-1) { + // Exact match, just insert here + Some(location) + } else { + // Exact match to the last element + Some(location-1) + } + } + // Determine the bucket function in constant time. Requires that buckets are evenly spaced + def fastBucketFunction(min: Double, increment: Double, count: Int)(e: Double): Option[Int] = { + // If our input is not a number unless the increment is also NaN then we fail fast + if (e.isNaN()) { + return None + } + val bucketNumber = (e-min)/(increment) + // We do this rather than buckets.lengthCompare(bucketNumber) + // because Array[Double] fails to override it (for now). + if (bucketNumber > count || bucketNumber < 0) { + None + } else { + Some(bucketNumber.toInt.min(count-1)) + } + } + def evenlySpaced(buckets: Array[Double]): Boolean = { + val delta = buckets(1)-buckets(0) + // Technically you could have an evenly spaced bucket with NaN + // increments but then its a single bucket and this makes the + // fastBucketFunction simpler. + if (delta.isNaN() || delta.isInfinite()) { + return false + } + for (i <- 1 to buckets.length-1) { + if (buckets(i)-buckets(i-1) != delta) { + return false + } + } + true + } + // Decide which bucket function to pass to histogramPartition. We decide here + // rather than having a general function so that the decission need only be made + // once rather than once per shard + val bucketFunction = if (evenlySpaced(buckets)) { + fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ + } else { + basicBucketFunction _ + } + self.mapPartitions(histogramPartition(bucketFunction)).reduce(mergeCounters) + } + } diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala new file mode 100644 index 0000000000..2ec7173511 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -0,0 +1,233 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.rdd + +import scala.math.abs +import scala.collection.mutable.ArrayBuffer + +import org.scalatest.FunSuite + +import org.apache.spark.SparkContext._ +import org.apache.spark.rdd._ +import org.apache.spark._ + +class DoubleRDDSuite extends FunSuite with SharedSparkContext { + // Verify tests on the histogram functionality. We test with both evenly + // and non-evenly spaced buckets as the bucket lookup function changes. + test("WorksOnEmpty") { + // Make sure that it works on an empty input + val rdd: RDD[Double] = sc.parallelize(Seq()) + val buckets: Array[Double] = Array(0.0, 10.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(0) + assert(histogramResults === expectedHistogramResults) + } + test("WorksWithOutOfRangeWithOneBucket") { + // Verify that if all of the elements are out of range the counts are zero + val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01)) + val buckets: Array[Double] = Array(0.0, 10.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(0) + assert(histogramResults === expectedHistogramResults) + } + test("WorksInRangeWithOneBucket") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val buckets: Array[Double] = Array(0.0, 10.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(4) + assert(histogramResults === expectedHistogramResults) + } + test("WorksInRangeWithOneBucketExactMatch") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val buckets: Array[Double] = Array(1.0, 4.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(4) + assert(histogramResults === expectedHistogramResults) + } + test("WorksWithOutOfRangeWithTwoBuckets") { + // Verify that out of range works with two buckets + val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01)) + val buckets: Array[Double] = Array(0.0, 5.0, 10.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(0,0) + assert(histogramResults === expectedHistogramResults) + } + test("WorksWithOutOfRangeWithTwoUnEvenBuckets") { + // Verify that out of range works with two un even buckets + val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01)) + val buckets: Array[Double] = Array(0.0, 4.0, 10.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(0,0) + assert(histogramResults === expectedHistogramResults) + } + test("WorksInRangeWithTwoBuckets") { + // Make sure that it works with two equally spaced buckets and elements in each + val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,5,6)) + val buckets: Array[Double] = Array(0.0, 5.0, 10.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(3,2) + assert(histogramResults === expectedHistogramResults) + } + test("WorksInRangeWithTwoBucketsAndNaN") { + // Make sure that it works with two equally spaced buckets and elements in each + val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,5,6,Double.NaN)) + val buckets: Array[Double] = Array(0.0, 5.0, 10.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = 
Array(3,2) + assert(histogramResults === expectedHistogramResults) + } + test("WorksInRangeWithTwoUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,5,6)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(3,2) + assert(histogramResults === expectedHistogramResults) + } + test("WorksMixedRangeWithTwoUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.0,11.01)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(4,3) + assert(histogramResults === expectedHistogramResults) + } + test("WorksMixedRangeWithFourUnevenBuckets") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(4,2,1,3) + assert(histogramResults === expectedHistogramResults) + } + test("WorksMixedRangeWithUnevenBucketsAndNaN") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1,Double.NaN)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(4,2,1,3) + assert(histogramResults === expectedHistogramResults) + } + // Make sure this works with a NaN end bucket + test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1,Double.NaN)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0,Double.NaN) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(4,2,1,2,3) + assert(histogramResults === expectedHistogramResults) + } + // Make sure this works with a NaN end bucket and an inifity + test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") { + // Make sure that it works with two unequally spaced buckets and elements in each + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1,1.0/0.0,-1.0/0.0,Double.NaN)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0,Double.NaN) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(4,2,1,2,4) + assert(histogramResults === expectedHistogramResults) + } + test("WorksWithOutOfRangeWithInfiniteBuckets") { + // Verify that out of range works with two buckets + val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01,Double.NaN)) + val buckets: Array[Double] = Array(-1.0/0.0 ,0.0, 1.0/0.0) + val histogramResults: Array[Long] = rdd.histogram(buckets) + val expectedHistogramResults: Array[Long] = Array(1,1) + assert(histogramResults === expectedHistogramResults) + } + // Test the failure mode with an invalid bucket array + 
test("ThrowsExceptionOnInvalidBucketArray") { + val rdd: RDD[Double] = sc.parallelize(Seq(1.0)) + // Empty array + intercept[IllegalArgumentException]{ + val buckets: Array[Double] = Array.empty[Double] + val result = rdd.histogram(buckets) + } + // Single element array + intercept[IllegalArgumentException] + { + val buckets: Array[Double] = Array(1.0) + val result = rdd.histogram(buckets) + } + } + + // Test automatic histogram function + test("WorksWithoutBucketsBasic") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults: Array[Long] = Array(4) + val expectedHistogramBuckets: Array[Double] = Array(1.0,4.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + test("WorksWithoutBucketsBasicTwo") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val (histogramBuckets, histogramResults) = rdd.histogram(2) + val expectedHistogramResults: Array[Long] = Array(2,2) + val expectedHistogramBuckets: Array[Double] = Array(1.0,2.5,4.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + test("WorksWithoutBucketsWithMoreRequestedThanElements") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd: RDD[Double] = sc.parallelize(Seq(1,2)) + val (histogramBuckets, histogramResults) = rdd.histogram(10) + val expectedHistogramResults: Array[Long] = + Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1) + val expectedHistogramBuckets: Array[Double] = + Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test the failure mode with an invalid RDD + test("ThrowsExceptionOnInvalidRDDs") { + // infinity + intercept[UnsupportedOperationException]{ + val rdd: RDD[Double] = sc.parallelize(Seq(1,1.0/0.0)) + val result = rdd.histogram(1) + } + // NaN + intercept[UnsupportedOperationException] + { + val rdd: RDD[Double] = sc.parallelize(Seq(1,Double.NaN)) + val result = rdd.histogram(1) + } + // Empty + intercept[UnsupportedOperationException] + { + val rdd: RDD[Double] = sc.parallelize(Seq()) + val result = rdd.histogram(1) + } + // Single element + intercept[UnsupportedOperationException] + { + val rdd: RDD[Double] = sc.parallelize(Seq(1)) + val result = rdd.histogram(1) + } + // No Range + intercept[UnsupportedOperationException] + { + val rdd: RDD[Double] = sc.parallelize(Seq(1,1,1)) + val result = rdd.histogram(1) + } + } + +} -- cgit v1.2.3 From e58c69d955ef8faacb794a0c1666b21c1606453e Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Sun, 20 Oct 2013 01:17:13 -0700 Subject: Add tests for the Java implementation. 
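Both the Scala suites above and the Java test added here exercise the generated-bucket path, so it helps to see how the boundaries come out. Below is a rough standalone sketch of the boundary computation, mirroring the Range.Double.inclusive call in DoubleRDDFunctions (illustrative only, no Spark required):

    // Evenly spaced boundaries between min and max; bucketCount buckets
    // means bucketCount + 1 boundaries.
    def makeBuckets(min: Double, max: Double, bucketCount: Int): Array[Double] = {
      val increment = (max - min) / bucketCount.toDouble
      Range.Double.inclusive(min, max, increment).toArray
    }

    makeBuckets(1.0, 4.0, 2)   // Array(1.0, 2.5, 4.0), the buckets the Java test below expects
    makeBuckets(1.0, 2.0, 10)  // 1.0, 1.1, ..., 2.0, as in WorksWithoutBucketsWithMoreRequestedThanElements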
--- core/src/test/scala/org/apache/spark/JavaAPISuite.java | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 591c1d498d..8a9c6e63e0 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -364,6 +364,20 @@ public class JavaAPISuite implements Serializable { List take = rdd.take(5); } + @Test + public void javaDoubleRDDHistoGram() { + JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); + // Test using generated buckets + Tuple2 results = rdd.histogram(2); + Double[] expected_buckets = {1.0, 2.5, 4.0}; + long[] expected_counts = {2, 2}; + Assert.assertArrayEquals(expected_buckets, results._1); + Assert.assertArrayEquals(expected_counts, results._2); + // Test with provided buckets + long[] histogram = rdd.histogram(expected_buckets); + Assert.assertArrayEquals(expected_counts, histogram); + } + @Test public void map() { JavaRDD rdd = sc.parallelize(Arrays.asList(1, 2, 3, 4, 5)); -- cgit v1.2.3 From 699f7d28c0347cb516fa17f94b53d7bc50f18346 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Oct 2013 00:10:03 -0700 Subject: CR feedback --- .../org/apache/spark/api/java/JavaDoubleRDD.scala | 18 ++- .../org/apache/spark/rdd/DoubleRDDFunctions.scala | 68 +++++----- .../org/apache/spark/rdd/DoubleRDDSuite.scala | 140 ++++++++++++--------- 3 files changed, 125 insertions(+), 101 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index d2a2818e59..b002468442 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -167,12 +167,13 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav * value is 0 and the max is 100 and there are two buckets the resulting * buckets will be [0,50) [50,100]. bucketCount must be at least 1 * If the RDD contains infinity, NaN throws an exception - * If the elements in RDD do not vary (max == min) throws an exception + * If the elements in RDD do not vary (max == min) always returns a single bucket. */ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { val result = srdd.histogram(bucketCount) (result._1.map(scala.Double.box(_)), result._2) } + /** * Compute a histogram using the provided buckets. The buckets are all open * to the left except for the last which is closed @@ -181,14 +182,21 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav * e.g 1<=x<10 , 10<=x<20, 20<=x<50 * And on the input of 1 and 50 we would have a histogram of 1,0,0 * - * Note: if your histogram is evenly spaced (e.g. [0,10,20,30]) this switches - * from an O(log n) inseration to O(1) per element. (where n = # buckets) + * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * to true. * buckets must be sorted and not contain any duplicates. * buckets array must be at least two elements - * All NaN entries are treated the same. + * All NaN entries are treated the same. If you have a NaN bucket it must be + * the maximum value of the last position and all NaN entries will be counted + * in that bucket. 
*/ def histogram(buckets: Array[Double]): Array[Long] = { - srdd.histogram(buckets.map(_.toDouble)) + srdd.histogram(buckets.map(_.toDouble), false) + } + + def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = { + srdd.histogram(buckets.map(_.toDouble), evenBuckets) } } diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 776a83cefe..33738ee094 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -83,44 +83,50 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { * Compute a histogram of the data using bucketCount number of buckets evenly * spaced between the minimum and maximum of the RDD. For example if the min * value is 0 and the max is 100 and there are two buckets the resulting - * buckets will be [0,50) [50,100]. bucketCount must be at least 1 + * buckets will be [0, 50) [50, 100]. bucketCount must be at least 1 * If the RDD contains infinity, NaN throws an exception - * If the elements in RDD do not vary (max == min) throws an exception + * If the elements in RDD do not vary (max == min) always returns a single bucket. */ def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { // Compute the minimum and the maxium val (max: Double, min: Double) = self.mapPartitions { items => Iterator(items.foldRight(-1/0.0, Double.NaN)((e: Double, x: Pair[Double, Double]) => - (x._1.max(e),x._2.min(e)))) + (x._1.max(e), x._2.min(e)))) }.reduce { (maxmin1, maxmin2) => (maxmin1._1.max(maxmin2._1), maxmin1._2.min(maxmin2._2)) } if (max.isNaN() || max.isInfinity || min.isInfinity ) { - throw new UnsupportedOperationException("Histogram on either an empty RDD or RDD containing +-infinity or NaN") - } - if (max == min) { - throw new UnsupportedOperationException("Histogram with no range in elements") + throw new UnsupportedOperationException( + "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") } val increment: Double = (max-min)/bucketCount.toDouble - val range = Range.Double.inclusive(min, max, increment) + val range = if (increment != 0) { + Range.Double.inclusive(min, max, increment) + } else { + List(min, min) + } val buckets: Array[Double] = range.toArray - (buckets,histogram(buckets)) + (buckets, histogram(buckets, true)) } + /** * Compute a histogram using the provided buckets. The buckets are all open * to the left except for the last which is closed * e.g. for the array - * [1,10,20,50] the buckets are [1,10) [10,20) [20,50] + * [1, 10, 20, 50] the buckets are [1, 10) [10, 20) [20, 50] * e.g 1<=x<10 , 10<=x<20, 20<=x<50 - * And on the input of 1 and 50 we would have a histogram of 1,0,0 + * And on the input of 1 and 50 we would have a histogram of 1, 0, 0 * - * Note: if your histogram is evenly spaced (e.g. [0,10,20,30]) this switches - * from an O(log n) inseration to O(1) per element. (where n = # buckets) + * Note: if your histogram is evenly spaced (e.g. [0, 10, 20, 30]) this can be switched + * from an O(log n) inseration to O(1) per element. (where n = # buckets) if you set evenBuckets + * to true. * buckets must be sorted and not contain any duplicates. * buckets array must be at least two elements - * All NaN entries are treated the same. + * All NaN entries are treated the same. 
If you have a NaN bucket it must be + * the maximum value of the last position and all NaN entries will be counted + * in that bucket. */ - def histogram(buckets: Array[Double]): Array[Long] = { + def histogram(buckets: Array[Double], evenBuckets: Boolean = false): Array[Long] = { if (buckets.length < 2) { throw new IllegalArgumentException("buckets array must have at least two elements") } @@ -129,11 +135,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { // to increment or returns None if there is no bucket. This is done so we can // specialize for uniformly distributed buckets and save the O(log n) binary // search cost. - def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]): Iterator[Array[Long]] = { - val counters = new Array[Long](buckets.length-1) + def histogramPartition(bucketFunction: (Double) => Option[Int])(iter: Iterator[Double]): + Iterator[Array[Long]] = { + val counters = new Array[Long](buckets.length - 1) while (iter.hasNext) { bucketFunction(iter.next()) match { - case Some(x: Int) => {counters(x)+=1} + case Some(x: Int) => {counters(x) += 1} case _ => {} } } @@ -161,12 +168,12 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { } else { None } - } else if (location < buckets.length-1) { + } else if (location < buckets.length - 1) { // Exact match, just insert here Some(location) } else { // Exact match to the last element - Some(location-1) + Some(location - 1) } } // Determine the bucket function in constant time. Requires that buckets are evenly spaced @@ -175,34 +182,19 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { if (e.isNaN()) { return None } - val bucketNumber = (e-min)/(increment) + val bucketNumber = (e - min)/(increment) // We do this rather than buckets.lengthCompare(bucketNumber) // because Array[Double] fails to override it (for now). if (bucketNumber > count || bucketNumber < 0) { None } else { - Some(bucketNumber.toInt.min(count-1)) - } - } - def evenlySpaced(buckets: Array[Double]): Boolean = { - val delta = buckets(1)-buckets(0) - // Technically you could have an evenly spaced bucket with NaN - // increments but then its a single bucket and this makes the - // fastBucketFunction simpler. - if (delta.isNaN() || delta.isInfinite()) { - return false - } - for (i <- 1 to buckets.length-1) { - if (buckets(i)-buckets(i-1) != delta) { - return false - } + Some(bucketNumber.toInt.min(count - 1)) } - true } // Decide which bucket function to pass to histogramPartition. 
We decide here // rather than having a general function so that the decission need only be made // once rather than once per shard - val bucketFunction = if (evenlySpaced(buckets)) { + val bucketFunction = if (evenBuckets) { fastBucketFunction(buckets(0), buckets(1)-buckets(0), buckets.length-1) _ } else { basicBucketFunction _ diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 2ec7173511..071084485a 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -34,134 +34,151 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { val rdd: RDD[Double] = sc.parallelize(Seq()) val buckets: Array[Double] = Array(0.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) val expectedHistogramResults: Array[Long] = Array(0) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksWithOutOfRangeWithOneBucket") { // Verify that if all of the elements are out of range the counts are zero - val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01)) + val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01)) val buckets: Array[Double] = Array(0.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) val expectedHistogramResults: Array[Long] = Array(0) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithOneBucket") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) val buckets: Array[Double] = Array(0.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) val expectedHistogramResults: Array[Long] = Array(4) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithOneBucketExactMatch") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) val buckets: Array[Double] = Array(1.0, 4.0) val histogramResults: Array[Long] = rdd.histogram(buckets) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) val expectedHistogramResults: Array[Long] = Array(4) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksWithOutOfRangeWithTwoBuckets") { // Verify that out of range works with two buckets - val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01)) + val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01)) val buckets: Array[Double] = Array(0.0, 5.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(0,0) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) + val expectedHistogramResults: Array[Long] = Array(0, 0) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksWithOutOfRangeWithTwoUnEvenBuckets") { // Verify that out of range works with two un even buckets - val 
rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01)) + val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01)) val buckets: Array[Double] = Array(0.0, 4.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(0,0) + val expectedHistogramResults: Array[Long] = Array(0, 0) assert(histogramResults === expectedHistogramResults) } test("WorksInRangeWithTwoBuckets") { // Make sure that it works with two equally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,5,6)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 5, 6)) val buckets: Array[Double] = Array(0.0, 5.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(3,2) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) + val expectedHistogramResults: Array[Long] = Array(3, 2) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithTwoBucketsAndNaN") { // Make sure that it works with two equally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,5,6,Double.NaN)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN)) val buckets: Array[Double] = Array(0.0, 5.0, 10.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(3,2) + val histogramResults2: Array[Long] = rdd.histogram(buckets, true) + val expectedHistogramResults: Array[Long] = Array(3, 2) assert(histogramResults === expectedHistogramResults) + assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithTwoUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,5,6)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 5, 6)) val buckets: Array[Double] = Array(0.0, 5.0, 11.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(3,2) + val expectedHistogramResults: Array[Long] = Array(3, 2) assert(histogramResults === expectedHistogramResults) } test("WorksMixedRangeWithTwoUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.0,11.01)) + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01)) val buckets: Array[Double] = Array(0.0, 5.0, 11.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4,3) + val expectedHistogramResults: Array[Long] = Array(4, 3) assert(histogramResults === expectedHistogramResults) } test("WorksMixedRangeWithFourUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0) + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4,2,1,3) + val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 3) assert(histogramResults === expectedHistogramResults) } test("WorksMixedRangeWithUnevenBucketsAndNaN") { // Make 
sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1,Double.NaN)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0) + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, Double.NaN)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4,2,1,3) + val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 3) assert(histogramResults === expectedHistogramResults) } // Make sure this works with a NaN end bucket test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1,Double.NaN)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0,Double.NaN) + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, Double.NaN)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4,2,1,2,3) + val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 2, 3) assert(histogramResults === expectedHistogramResults) } // Make sure this works with a NaN end bucket and an inifity test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01,0.0,1,2,3,5,6,11.01,12.0,199.0,200.0,200.1,1.0/0.0,-1.0/0.0,Double.NaN)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0,12.0,200.0,Double.NaN) + val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN)) + val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4,2,1,2,4) + val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 2, 4) assert(histogramResults === expectedHistogramResults) } test("WorksWithOutOfRangeWithInfiniteBuckets") { // Verify that out of range works with two buckets - val rdd: RDD[Double] = sc.parallelize(Seq(10.01,-0.01,Double.NaN)) - val buckets: Array[Double] = Array(-1.0/0.0 ,0.0, 1.0/0.0) + val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01, Double.NaN)) + val buckets: Array[Double] = Array(-1.0/0.0 , 0.0, 1.0/0.0) val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(1,1) + val expectedHistogramResults: Array[Long] = Array(1, 1) assert(histogramResults === expectedHistogramResults) } // Test the failure mode with an invalid bucket array test("ThrowsExceptionOnInvalidBucketArray") { val rdd: RDD[Double] = sc.parallelize(Seq(1.0)) // Empty array - intercept[IllegalArgumentException]{ + intercept[IllegalArgumentException] { val buckets: Array[Double] = Array.empty[Double] val result = rdd.histogram(buckets) } // Single element array - intercept[IllegalArgumentException] - { + intercept[IllegalArgumentException] { val buckets: Array[Double] = Array(1.0) val result = rdd.histogram(buckets) } @@ -170,25 +187,45 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { // Test 
automatic histogram function test("WorksWithoutBucketsBasic") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) val (histogramBuckets, histogramResults) = rdd.histogram(1) val expectedHistogramResults: Array[Long] = Array(4) - val expectedHistogramBuckets: Array[Double] = Array(1.0,4.0) + val expectedHistogramBuckets: Array[Double] = Array(1.0, 4.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test automatic histogram function with a single element + test("WorksWithoutBucketsBasicSingleElement") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd: RDD[Double] = sc.parallelize(Seq(1)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults: Array[Long] = Array(1) + val expectedHistogramBuckets: Array[Double] = Array(1.0, 1.0) + assert(histogramResults === expectedHistogramResults) + assert(histogramBuckets === expectedHistogramBuckets) + } + // Test automatic histogram function with a single element + test("WorksWithoutBucketsBasicNoRange") { + // Verify the basic case of one bucket and all elements in that bucket works + val rdd: RDD[Double] = sc.parallelize(Seq(1, 1, 1, 1)) + val (histogramBuckets, histogramResults) = rdd.histogram(1) + val expectedHistogramResults: Array[Long] = Array(4) + val expectedHistogramBuckets: Array[Double] = Array(1.0, 1.0) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } test("WorksWithoutBucketsBasicTwo") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1,2,3,4)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) val (histogramBuckets, histogramResults) = rdd.histogram(2) - val expectedHistogramResults: Array[Long] = Array(2,2) - val expectedHistogramBuckets: Array[Double] = Array(1.0,2.5,4.0) + val expectedHistogramResults: Array[Long] = Array(2, 2) + val expectedHistogramBuckets: Array[Double] = Array(1.0, 2.5, 4.0) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } test("WorksWithoutBucketsWithMoreRequestedThanElements") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1,2)) + val rdd: RDD[Double] = sc.parallelize(Seq(1, 2)) val (histogramBuckets, histogramResults) = rdd.histogram(10) val expectedHistogramResults: Array[Long] = Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1) @@ -197,37 +234,24 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } + // Test the failure mode with an invalid RDD test("ThrowsExceptionOnInvalidRDDs") { // infinity - intercept[UnsupportedOperationException]{ - val rdd: RDD[Double] = sc.parallelize(Seq(1,1.0/0.0)) + intercept[UnsupportedOperationException] { + val rdd: RDD[Double] = sc.parallelize(Seq(1, 1.0/0.0)) val result = rdd.histogram(1) } // NaN - intercept[UnsupportedOperationException] - { - val rdd: RDD[Double] = sc.parallelize(Seq(1,Double.NaN)) + intercept[UnsupportedOperationException] { + val rdd: RDD[Double] = sc.parallelize(Seq(1, Double.NaN)) val result = rdd.histogram(1) } // Empty - intercept[UnsupportedOperationException] - { + 
intercept[UnsupportedOperationException] { val rdd: RDD[Double] = sc.parallelize(Seq()) val result = rdd.histogram(1) } - // Single element - intercept[UnsupportedOperationException] - { - val rdd: RDD[Double] = sc.parallelize(Seq(1)) - val result = rdd.histogram(1) - } - // No Range - intercept[UnsupportedOperationException] - { - val rdd: RDD[Double] = sc.parallelize(Seq(1,1,1)) - val result = rdd.histogram(1) - } } } -- cgit v1.2.3 From 092b87e7c8f723a0c4ecf1dfb5379cad4c39d37f Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Oct 2013 00:20:15 -0700 Subject: Remove extranious type definitions from inside of tests --- .../org/apache/spark/rdd/DoubleRDDSuite.scala | 172 ++++++++++----------- 1 file changed, 86 insertions(+), 86 deletions(-) diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 071084485a..0d8ac19024 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -32,154 +32,154 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { test("WorksOnEmpty") { // Make sure that it works on an empty input val rdd: RDD[Double] = sc.parallelize(Seq()) - val buckets: Array[Double] = Array(0.0, 10.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val histogramResults2: Array[Long] = rdd.histogram(buckets, true) - val expectedHistogramResults: Array[Long] = Array(0) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0) assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } test("WorksWithOutOfRangeWithOneBucket") { // Verify that if all of the elements are out of range the counts are zero - val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01)) - val buckets: Array[Double] = Array(0.0, 10.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val histogramResults2: Array[Long] = rdd.histogram(buckets, true) - val expectedHistogramResults: Array[Long] = Array(0) + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0) assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithOneBucket") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) - val buckets: Array[Double] = Array(0.0, 10.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val histogramResults2: Array[Long] = rdd.histogram(buckets, true) - val expectedHistogramResults: Array[Long] = Array(4) + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val buckets = Array(0.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(4) assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithOneBucketExactMatch") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) - val buckets: Array[Double] = Array(1.0, 4.0) - val histogramResults: Array[Long] = 
rdd.histogram(buckets) - val histogramResults2: Array[Long] = rdd.histogram(buckets, true) - val expectedHistogramResults: Array[Long] = Array(4) + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) + val buckets = Array(1.0, 4.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(4) assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } test("WorksWithOutOfRangeWithTwoBuckets") { // Verify that out of range works with two buckets - val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01)) - val buckets: Array[Double] = Array(0.0, 5.0, 10.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val histogramResults2: Array[Long] = rdd.histogram(buckets, true) - val expectedHistogramResults: Array[Long] = Array(0, 0) + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(0, 0) assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } test("WorksWithOutOfRangeWithTwoUnEvenBuckets") { // Verify that out of range works with two un even buckets - val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01)) - val buckets: Array[Double] = Array(0.0, 4.0, 10.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(0, 0) + val rdd = sc.parallelize(Seq(10.01, -0.01)) + val buckets = Array(0.0, 4.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(0, 0) assert(histogramResults === expectedHistogramResults) } test("WorksInRangeWithTwoBuckets") { // Make sure that it works with two equally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 5, 6)) - val buckets: Array[Double] = Array(0.0, 5.0, 10.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val histogramResults2: Array[Long] = rdd.histogram(buckets, true) - val expectedHistogramResults: Array[Long] = Array(3, 2) + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(3, 2) assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithTwoBucketsAndNaN") { // Make sure that it works with two equally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN)) - val buckets: Array[Double] = Array(0.0, 5.0, 10.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val histogramResults2: Array[Long] = rdd.histogram(buckets, true) - val expectedHistogramResults: Array[Long] = Array(3, 2) + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN)) + val buckets = Array(0.0, 5.0, 10.0) + val histogramResults = rdd.histogram(buckets) + val histogramResults2 = rdd.histogram(buckets, true) + val expectedHistogramResults = Array(3, 2) assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } test("WorksInRangeWithTwoUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 5, 6)) - val buckets: Array[Double] 
= Array(0.0, 5.0, 11.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(3, 2) + val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6)) + val buckets = Array(0.0, 5.0, 11.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(3, 2) assert(histogramResults === expectedHistogramResults) } test("WorksMixedRangeWithTwoUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4, 3) + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01)) + val buckets = Array(0.0, 5.0, 11.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 3) assert(histogramResults === expectedHistogramResults) } test("WorksMixedRangeWithFourUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 3) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 3) assert(histogramResults === expectedHistogramResults) } test("WorksMixedRangeWithUnevenBucketsAndNaN") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1, Double.NaN)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 3) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 3) assert(histogramResults === expectedHistogramResults) } // Make sure this works with a NaN end bucket test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRange") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1, Double.NaN)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 2, 3) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 2, 3) assert(histogramResults === expectedHistogramResults) } // Make sure this works with a NaN end bucket and an inifity test("WorksMixedRangeWithUnevenBucketsAndNaNAndNaNRangeAndInfity") { // Make sure that it works with two unequally spaced buckets and elements in each - val rdd: RDD[Double] = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 
11.01, 12.0, 199.0, + val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, 200.0, 200.1, 1.0/0.0, -1.0/0.0, Double.NaN)) - val buckets: Array[Double] = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(4, 2, 1, 2, 4) + val buckets = Array(0.0, 5.0, 11.0, 12.0, 200.0, Double.NaN) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(4, 2, 1, 2, 4) assert(histogramResults === expectedHistogramResults) } test("WorksWithOutOfRangeWithInfiniteBuckets") { // Verify that out of range works with two buckets - val rdd: RDD[Double] = sc.parallelize(Seq(10.01, -0.01, Double.NaN)) - val buckets: Array[Double] = Array(-1.0/0.0 , 0.0, 1.0/0.0) - val histogramResults: Array[Long] = rdd.histogram(buckets) - val expectedHistogramResults: Array[Long] = Array(1, 1) + val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN)) + val buckets = Array(-1.0/0.0 , 0.0, 1.0/0.0) + val histogramResults = rdd.histogram(buckets) + val expectedHistogramResults = Array(1, 1) assert(histogramResults === expectedHistogramResults) } // Test the failure mode with an invalid bucket array test("ThrowsExceptionOnInvalidBucketArray") { - val rdd: RDD[Double] = sc.parallelize(Seq(1.0)) + val rdd = sc.parallelize(Seq(1.0)) // Empty array intercept[IllegalArgumentException] { - val buckets: Array[Double] = Array.empty[Double] + val buckets = Array.empty[Double] val result = rdd.histogram(buckets) } // Single element array intercept[IllegalArgumentException] { - val buckets: Array[Double] = Array(1.0) + val buckets = Array(1.0) val result = rdd.histogram(buckets) } } @@ -187,49 +187,49 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { // Test automatic histogram function test("WorksWithoutBucketsBasic") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) val (histogramBuckets, histogramResults) = rdd.histogram(1) - val expectedHistogramResults: Array[Long] = Array(4) - val expectedHistogramBuckets: Array[Double] = Array(1.0, 4.0) + val expectedHistogramResults = Array(4) + val expectedHistogramBuckets = Array(1.0, 4.0) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } // Test automatic histogram function with a single element test("WorksWithoutBucketsBasicSingleElement") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1)) + val rdd = sc.parallelize(Seq(1)) val (histogramBuckets, histogramResults) = rdd.histogram(1) - val expectedHistogramResults: Array[Long] = Array(1) - val expectedHistogramBuckets: Array[Double] = Array(1.0, 1.0) + val expectedHistogramResults = Array(1) + val expectedHistogramBuckets = Array(1.0, 1.0) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } // Test automatic histogram function with a single element test("WorksWithoutBucketsBasicNoRange") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1, 1, 1, 1)) + val rdd = sc.parallelize(Seq(1, 1, 1, 1)) val (histogramBuckets, histogramResults) = rdd.histogram(1) - val expectedHistogramResults: Array[Long] = Array(4) - val expectedHistogramBuckets: Array[Double] = Array(1.0, 1.0) + val 
expectedHistogramResults = Array(4) + val expectedHistogramBuckets = Array(1.0, 1.0) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } test("WorksWithoutBucketsBasicTwo") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1, 2, 3, 4)) + val rdd = sc.parallelize(Seq(1, 2, 3, 4)) val (histogramBuckets, histogramResults) = rdd.histogram(2) - val expectedHistogramResults: Array[Long] = Array(2, 2) - val expectedHistogramBuckets: Array[Double] = Array(1.0, 2.5, 4.0) + val expectedHistogramResults = Array(2, 2) + val expectedHistogramBuckets = Array(1.0, 2.5, 4.0) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } test("WorksWithoutBucketsWithMoreRequestedThanElements") { // Verify the basic case of one bucket and all elements in that bucket works - val rdd: RDD[Double] = sc.parallelize(Seq(1, 2)) + val rdd = sc.parallelize(Seq(1, 2)) val (histogramBuckets, histogramResults) = rdd.histogram(10) - val expectedHistogramResults: Array[Long] = + val expectedHistogramResults = Array(1, 0, 0, 0, 0, 0, 0, 0, 0, 1) - val expectedHistogramBuckets: Array[Double] = + val expectedHistogramBuckets = Array(1.0, 1.1, 1.2, 1.3, 1.4, 1.5, 1.6, 1.7, 1.8, 1.9, 2.0) assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) @@ -239,12 +239,12 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { test("ThrowsExceptionOnInvalidRDDs") { // infinity intercept[UnsupportedOperationException] { - val rdd: RDD[Double] = sc.parallelize(Seq(1, 1.0/0.0)) + val rdd = sc.parallelize(Seq(1, 1.0/0.0)) val result = rdd.histogram(1) } // NaN intercept[UnsupportedOperationException] { - val rdd: RDD[Double] = sc.parallelize(Seq(1, Double.NaN)) + val rdd = sc.parallelize(Seq(1, Double.NaN)) val result = rdd.histogram(1) } // Empty -- cgit v1.2.3 From 20b33bc4b5de1addd943c7a1e6d5d2366d9cd445 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 21 Oct 2013 00:21:37 -0700 Subject: Remove extranious type declerations --- core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala index 33738ee094..02d75eccc5 100644 --- a/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala +++ b/core/src/main/scala/org/apache/spark/rdd/DoubleRDDFunctions.scala @@ -99,13 +99,13 @@ class DoubleRDDFunctions(self: RDD[Double]) extends Logging with Serializable { throw new UnsupportedOperationException( "Histogram on either an empty RDD or RDD containing +/-infinity or NaN") } - val increment: Double = (max-min)/bucketCount.toDouble + val increment = (max-min)/bucketCount.toDouble val range = if (increment != 0) { Range.Double.inclusive(min, max, increment) } else { List(min, min) } - val buckets: Array[Double] = range.toArray + val buckets = range.toArray (buckets, histogram(buckets, true)) } -- cgit v1.2.3 From a48d88d206fae348720ab077a624b3c57293374f Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sat, 2 Nov 2013 21:13:18 -0700 Subject: Replace magic lengths with constants in PySpark. Write the length of the accumulators section up-front rather than terminating it with a negative length. I find this easier to read. 
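As an aside on the stream protocol this commit describes, the following is a minimal, hypothetical Scala sketch (not part of the patch) of writing and then reading a data section that ends with a special negative length code, followed by a count-prefixed accumulator section. Only the END_OF_DATA_SECTION value mirrors the SpecialLengths object added below; the records, stream, and object name are invented for illustration.

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

object SpecialLengthsSketch {
  // Negative "lengths" double as control codes; real record lengths are >= 0.
  val END_OF_DATA_SECTION = -1

  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    // Data section: each record is a 32-bit length followed by that many bytes.
    for (record <- Seq("a", "bb")) {
      val payload = record.getBytes("UTF-8")
      out.writeInt(payload.length)
      out.write(payload)
    }
    // End-of-data marker, then the number of accumulator updates written up-front.
    out.writeInt(END_OF_DATA_SECTION)
    out.writeInt(1)
    val update = "update".getBytes("UTF-8")
    out.writeInt(update.length)
    out.write(update)
    out.flush()

    val in = new DataInputStream(new ByteArrayInputStream(buffer.toByteArray))
    var done = false
    while (!done) {
      in.readInt() match {
        case END_OF_DATA_SECTION =>
          val numUpdates = in.readInt()   // count is known in advance, no sentinel scan
          for (_ <- 1 to numUpdates) {
            val len = in.readInt()
            val bytes = new Array[Byte](len)
            in.readFully(bytes)
            println("accumulator update: " + new String(bytes, "UTF-8"))
          }
          done = true
        case length =>
          val bytes = new Array[Byte](length)
          in.readFully(bytes)
          println("record: " + new String(bytes, "UTF-8"))
      }
    }
  }
}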
--- .../org/apache/spark/api/python/PythonRDD.scala | 26 +++++++++++++--------- python/pyspark/serializers.py | 6 +++++ python/pyspark/worker.py | 13 ++++++----- 3 files changed, 29 insertions(+), 16 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 12b4d94a56..0d5913ec60 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -132,7 +132,7 @@ private[spark] class PythonRDD[T: ClassManifest]( val obj = new Array[Byte](length) stream.readFully(obj) obj - case -3 => + case SpecialLengths.TIMING_DATA => // Timing data from worker val bootTime = stream.readLong() val initTime = stream.readLong() @@ -143,24 +143,24 @@ private[spark] class PythonRDD[T: ClassManifest]( val total = finishTime - startTime logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) read - case -2 => + case SpecialLengths.PYTHON_EXCEPTION_THROWN => // Signals that an exception has been thrown in python val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) throw new PythonException(new String(obj)) - case -1 => + case SpecialLengths.END_OF_DATA_SECTION => // We've finished the data section of the output, but we can still - // read some accumulator updates; let's do that, breaking when we - // get a negative length record. - var len2 = stream.readInt() - while (len2 >= 0) { - val update = new Array[Byte](len2) + // read some accumulator updates: + val numAccumulatorUpdates = stream.readInt() + (1 to numAccumulatorUpdates).foreach { _ => + val updateLen = stream.readInt() + val update = new Array[Byte](updateLen) stream.readFully(update) accumulator += Collections.singletonList(update) - len2 = stream.readInt() + } - new Array[Byte](0) + Array.empty[Byte] } } catch { case eof: EOFException => { @@ -197,6 +197,12 @@ private class PairwiseRDD(prev: RDD[Array[Byte]]) extends val asJavaPairRDD : JavaPairRDD[Long, Array[Byte]] = JavaPairRDD.fromRDD(this) } +private object SpecialLengths { + val END_OF_DATA_SECTION = -1 + val PYTHON_EXCEPTION_THROWN = -2 + val TIMING_DATA = -3 +} + private[spark] object PythonRDD { /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 54fed1c9c7..fbc280fd37 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -19,6 +19,12 @@ import struct import cPickle +class SpecialLengths(object): + END_OF_DATA_SECTION = -1 + PYTHON_EXCEPTION_THROWN = -2 + TIMING_DATA = -3 + + class Batch(object): """ Used to store multiple RDD entries as a single Java object. 
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index d63c2aaef7..7696df9d1c 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,7 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file + read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \ + SpecialLengths def load_obj(infile): @@ -39,7 +40,7 @@ def load_obj(infile): def report_times(outfile, boot, init, finish): - write_int(-3, outfile) + write_int(SpecialLengths.TIMING_DATA, outfile) write_long(1000 * boot, outfile) write_long(1000 * init, outfile) write_long(1000 * finish, outfile) @@ -82,16 +83,16 @@ def main(infile, outfile): for obj in func(split_index, iterator): write_with_length(dumps(obj), outfile) except Exception as e: - write_int(-2, outfile) + write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) write_with_length(traceback.format_exc(), outfile) sys.exit(-1) finish_time = time.time() report_times(outfile, boot_time, init_time, finish_time) # Mark the beginning of the accumulators section of the output - write_int(-1, outfile) - for aid, accum in _accumulatorRegistry.items(): + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + write_int(len(_accumulatorRegistry), outfile) + for (aid, accum) in _accumulatorRegistry.items(): write_with_length(dump_pickle((aid, accum._value)), outfile) - write_int(-1, outfile) if __name__ == '__main__': -- cgit v1.2.3 From 7d68a81a8ed5f49fefb3bd0fa0b9d3835cc7d86e Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 3 Nov 2013 11:03:02 -0800 Subject: Remove Pickle-wrapping of Java objects in PySpark. If we support custom serializers, the Python worker will know what type of input to expect, so we won't need to wrap Tuple2 and Strings into pickled tuples and strings. 
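To make the framing change concrete, here is a small, hypothetical Scala sketch (not the patch's own code) of writing a byte array, or a pair of byte arrays, as plain length-prefixed frames with no pickle opcodes; the worker is assumed to know from the stage's input type whether to read single frames or pairs of frames. All names below are invented for the example.

import java.io.{ByteArrayOutputStream, DataOutputStream}

object RawFramingSketch {
  // Each value becomes one or two (length, bytes) frames, nothing more.
  def writeFrames(elem: Any, out: DataOutputStream): Unit = elem match {
    case bytes: Array[Byte] =>
      out.writeInt(bytes.length)
      out.write(bytes)
    case (a: Array[Byte], b: Array[Byte]) =>
      // A pair is simply two consecutive frames.
      out.writeInt(a.length)
      out.write(a)
      out.writeInt(b.length)
      out.write(b)
    case other =>
      throw new IllegalArgumentException("Unexpected element type " + other.getClass)
  }

  def main(args: Array[String]): Unit = {
    val buffer = new ByteArrayOutputStream()
    val out = new DataOutputStream(buffer)
    writeFrames("key".getBytes("UTF-8") -> "value".getBytes("UTF-8"), out)
    out.flush()
    // Two frames: (4 + 3) + (4 + 5) = 16 bytes in total.
    println(buffer.size())
  }
}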
--- .../org/apache/spark/api/python/PythonRDD.scala | 106 ++++++++------------- python/pyspark/context.py | 10 +- python/pyspark/rdd.py | 11 ++- python/pyspark/serializers.py | 18 ++++ python/pyspark/worker.py | 14 ++- 5 files changed, 78 insertions(+), 81 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index 0d5913ec60..eb0b0db0cc 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -75,7 +75,7 @@ private[spark] class PythonRDD[T: ClassManifest]( // Partition index dataOut.writeInt(split.index) // sparkFilesDir - PythonRDD.writeAsPickle(SparkFiles.getRootDirectory, dataOut) + dataOut.writeUTF(SparkFiles.getRootDirectory) // Broadcast variables dataOut.writeInt(broadcastVars.length) for (broadcast <- broadcastVars) { @@ -85,9 +85,7 @@ private[spark] class PythonRDD[T: ClassManifest]( } // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.length) - for (f <- pythonIncludes) { - PythonRDD.writeAsPickle(f, dataOut) - } + pythonIncludes.foreach(dataOut.writeUTF) dataOut.flush() // Serialized user code for (elem <- command) { @@ -96,7 +94,7 @@ private[spark] class PythonRDD[T: ClassManifest]( printOut.flush() // Data values for (elem <- parent.iterator(split, context)) { - PythonRDD.writeAsPickle(elem, dataOut) + PythonRDD.writeToStream(elem, dataOut) } dataOut.flush() printOut.flush() @@ -205,60 +203,7 @@ private object SpecialLengths { private[spark] object PythonRDD { - /** Strips the pickle PROTO and STOP opcodes from the start and end of a pickle */ - def stripPickle(arr: Array[Byte]) : Array[Byte] = { - arr.slice(2, arr.length - 1) - } - - /** - * Write strings, pickled Python objects, or pairs of pickled objects to a data output stream. - * The data format is a 32-bit integer representing the pickled object's length (in bytes), - * followed by the pickled data. - * - * Pickle module: - * - * http://docs.python.org/2/library/pickle.html - * - * The pickle protocol is documented in the source of the `pickle` and `pickletools` modules: - * - * http://hg.python.org/cpython/file/2.6/Lib/pickle.py - * http://hg.python.org/cpython/file/2.6/Lib/pickletools.py - * - * @param elem the object to write - * @param dOut a data output stream - */ - def writeAsPickle(elem: Any, dOut: DataOutputStream) { - if (elem.isInstanceOf[Array[Byte]]) { - val arr = elem.asInstanceOf[Array[Byte]] - dOut.writeInt(arr.length) - dOut.write(arr) - } else if (elem.isInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]]) { - val t = elem.asInstanceOf[scala.Tuple2[Array[Byte], Array[Byte]]] - val length = t._1.length + t._2.length - 3 - 3 + 4 // stripPickle() removes 3 bytes - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(PythonRDD.stripPickle(t._1)) - dOut.write(PythonRDD.stripPickle(t._2)) - dOut.writeByte(Pickle.TUPLE2) - dOut.writeByte(Pickle.STOP) - } else if (elem.isInstanceOf[String]) { - // For uniformity, strings are wrapped into Pickles. 
- val s = elem.asInstanceOf[String].getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } else { - throw new SparkException("Unexpected RDD type") - } - } - - def readRDDFromPickleFile(sc: JavaSparkContext, filename: String, parallelism: Int) : + def readRDDFromFile(sc: JavaSparkContext, filename: String, parallelism: Int): JavaRDD[Array[Byte]] = { val file = new DataInputStream(new FileInputStream(filename)) val objs = new collection.mutable.ArrayBuffer[Array[Byte]] @@ -276,15 +221,46 @@ private[spark] object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def writeIteratorToPickleFile[T](items: java.util.Iterator[T], filename: String) { + def writeStringAsPickle(elem: String, dOut: DataOutputStream) { + val s = elem.getBytes("UTF-8") + val length = 2 + 1 + 4 + s.length + 1 + dOut.writeInt(length) + dOut.writeByte(Pickle.PROTO) + dOut.writeByte(Pickle.TWO) + dOut.write(Pickle.BINUNICODE) + dOut.writeInt(Integer.reverseBytes(s.length)) + dOut.write(s) + dOut.writeByte(Pickle.STOP) + } + + def writeToStream(elem: Any, dataOut: DataOutputStream) { + elem match { + case bytes: Array[Byte] => + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + case pair: (Array[Byte], Array[Byte]) => + dataOut.writeInt(pair._1.length) + dataOut.write(pair._1) + dataOut.writeInt(pair._2.length) + dataOut.write(pair._2) + case str: String => + // Until we've implemented full custom serializer support, we need to return + // strings as Pickles to properly support union() and cartesian(): + writeStringAsPickle(str, dataOut) + case other => + throw new SparkException("Unexpected element type " + other.getClass) + } + } + + def writeToFile[T](items: java.util.Iterator[T], filename: String) { import scala.collection.JavaConverters._ - writeIteratorToPickleFile(items.asScala, filename) + writeToFile(items.asScala, filename) } - def writeIteratorToPickleFile[T](items: Iterator[T], filename: String) { + def writeToFile[T](items: Iterator[T], filename: String) { val file = new DataOutputStream(new FileOutputStream(filename)) for (item <- items) { - writeAsPickle(item, file) + writeToStream(item, file) } file.close() } @@ -300,10 +276,6 @@ private object Pickle { val TWO: Byte = 0x02.toByte val BINUNICODE: Byte = 'X' val STOP: Byte = '.' 
- val TUPLE2: Byte = 0x86.toByte - val EMPTY_LIST: Byte = ']' - val MARK: Byte = '(' - val APPENDS: Byte = 'e' } private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { diff --git a/python/pyspark/context.py b/python/pyspark/context.py index a7ca8bc888..0fec1a6bf6 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -42,7 +42,7 @@ class SparkContext(object): _gateway = None _jvm = None - _writeIteratorToPickleFile = None + _writeToFile = None _takePartition = None _next_accum_id = 0 _active_spark_context = None @@ -125,8 +125,8 @@ class SparkContext(object): if not SparkContext._gateway: SparkContext._gateway = launch_gateway() SparkContext._jvm = SparkContext._gateway.jvm - SparkContext._writeIteratorToPickleFile = \ - SparkContext._jvm.PythonRDD.writeIteratorToPickleFile + SparkContext._writeToFile = \ + SparkContext._jvm.PythonRDD.writeToFile SparkContext._takePartition = \ SparkContext._jvm.PythonRDD.takePartition @@ -190,8 +190,8 @@ class SparkContext(object): for x in c: write_with_length(dump_pickle(x), tempFile) tempFile.close() - readRDDFromPickleFile = self._jvm.PythonRDD.readRDDFromPickleFile - jrdd = readRDDFromPickleFile(self._jsc, tempFile.name, numSlices) + readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile + jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) return RDD(jrdd, self) def textFile(self, name, minSplits=None): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 7019fb8bee..d3c4d13a1e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -54,6 +54,7 @@ class RDD(object): self.is_checkpointed = False self.ctx = ctx self._partitionFunc = None + self._stage_input_is_pairs = False @property def context(self): @@ -344,6 +345,7 @@ class RDD(object): yield pair else: yield pair + java_cartesian._stage_input_is_pairs = True return java_cartesian.flatMap(unpack_batches) def groupBy(self, f, numPartitions=None): @@ -391,8 +393,8 @@ class RDD(object): """ Return a list that contains all of the elements in this RDD. """ - picklesInJava = self._jrdd.collect().iterator() - return list(self._collect_iterator_through_file(picklesInJava)) + bytesInJava = self._jrdd.collect().iterator() + return list(self._collect_iterator_through_file(bytesInJava)) def _collect_iterator_through_file(self, iterator): # Transferring lots of data through Py4J can be slow because @@ -400,7 +402,7 @@ class RDD(object): # file and read it back. 
tempFile = NamedTemporaryFile(delete=False, dir=self.ctx._temp_dir) tempFile.close() - self.ctx._writeIteratorToPickleFile(iterator, tempFile.name) + self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: for item in read_from_pickle_file(tempFile): @@ -941,6 +943,7 @@ class PipelinedRDD(RDD): self.func = func self.preservesPartitioning = preservesPartitioning self._prev_jrdd = prev._jrdd + self._stage_input_is_pairs = prev._stage_input_is_pairs self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx @@ -959,7 +962,7 @@ class PipelinedRDD(RDD): def batched_func(split, iterator): return batched(oldfunc(split, iterator), batchSize) func = batched_func - cmds = [func, self._bypass_serializer] + cmds = [func, self._bypass_serializer, self._stage_input_is_pairs] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fbc280fd37..fd02e1ee8f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -93,6 +93,14 @@ def write_with_length(obj, stream): stream.write(obj) +def read_mutf8(stream): + """ + Read a string written with Java's DataOutputStream.writeUTF() method. + """ + length = struct.unpack('>H', stream.read(2))[0] + return stream.read(length).decode('utf8') + + def read_with_length(stream): length = read_int(stream) obj = stream.read(length) @@ -112,3 +120,13 @@ def read_from_pickle_file(stream): yield obj except EOFError: return + + +def read_pairs_from_pickle_file(stream): + try: + while True: + a = load_pickle(read_with_length(stream)) + b = load_pickle(read_with_length(stream)) + yield (a, b) + except EOFError: + return \ No newline at end of file diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 7696df9d1c..4e64557fc4 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -31,8 +31,8 @@ from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file \ - SpecialLengths + read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \ + SpecialLengths, read_mutf8, read_pairs_from_pickle_file def load_obj(infile): @@ -53,7 +53,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = load_pickle(read_with_length(infile)) + spark_files_dir = read_mutf8(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -68,17 +68,21 @@ def main(infile, outfile): sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - sys.path.append(os.path.join(spark_files_dir, load_pickle(read_with_length(infile)))) + sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile))) # now load function func = load_obj(infile) bypassSerializer = load_obj(infile) + stageInputIsPairs = load_obj(infile) if bypassSerializer: dumps = lambda x: x else: dumps = dump_pickle init_time = time.time() - iterator = read_from_pickle_file(infile) + if stageInputIsPairs: + iterator = read_pairs_from_pickle_file(infile) + else: + iterator = 
read_from_pickle_file(infile) try: for obj in func(split_index, iterator): write_with_length(dumps(obj), outfile) -- cgit v1.2.3 From cbb7f04aef2220ece93dea9f3fa98b5db5f270d6 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 5 Nov 2013 17:52:39 -0800 Subject: Add custom serializer support to PySpark. For now, this only adds MarshalSerializer, but it lays the groundwork for other supporting custom serializers. Many of these mechanisms can also be used to support deserialization of different data formats sent by Java, such as data encoded by MsgPack. This also fixes a bug in SparkContext.union(). --- .../org/apache/spark/api/python/PythonRDD.scala | 23 +- python/epydoc.conf | 2 +- python/pyspark/accumulators.py | 6 +- python/pyspark/context.py | 61 ++-- python/pyspark/rdd.py | 86 +++--- python/pyspark/serializers.py | 310 ++++++++++++++++----- python/pyspark/tests.py | 3 +- python/pyspark/worker.py | 41 ++- python/run-tests | 1 + 9 files changed, 363 insertions(+), 170 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index eb0b0db0cc..ef9bf4db9b 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -221,18 +221,6 @@ private[spark] object PythonRDD { JavaRDD.fromRDD(sc.sc.parallelize(objs, parallelism)) } - def writeStringAsPickle(elem: String, dOut: DataOutputStream) { - val s = elem.getBytes("UTF-8") - val length = 2 + 1 + 4 + s.length + 1 - dOut.writeInt(length) - dOut.writeByte(Pickle.PROTO) - dOut.writeByte(Pickle.TWO) - dOut.write(Pickle.BINUNICODE) - dOut.writeInt(Integer.reverseBytes(s.length)) - dOut.write(s) - dOut.writeByte(Pickle.STOP) - } - def writeToStream(elem: Any, dataOut: DataOutputStream) { elem match { case bytes: Array[Byte] => @@ -244,9 +232,7 @@ private[spark] object PythonRDD { dataOut.writeInt(pair._2.length) dataOut.write(pair._2) case str: String => - // Until we've implemented full custom serializer support, we need to return - // strings as Pickles to properly support union() and cartesian(): - writeStringAsPickle(str, dataOut) + dataOut.writeUTF(str) case other => throw new SparkException("Unexpected element type " + other.getClass) } @@ -271,13 +257,6 @@ private[spark] object PythonRDD { } } -private object Pickle { - val PROTO: Byte = 0x80.toByte - val TWO: Byte = 0x02.toByte - val BINUNICODE: Byte = 'X' - val STOP: Byte = '.' 
-} - private class BytesToString extends org.apache.spark.api.java.function.Function[Array[Byte], String] { override def call(arr: Array[Byte]) : String = new String(arr, "UTF-8") } diff --git a/python/epydoc.conf b/python/epydoc.conf index 1d0d002d36..0b42e729f8 100644 --- a/python/epydoc.conf +++ b/python/epydoc.conf @@ -32,6 +32,6 @@ target: docs/ private: no -exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.serializers +exclude: pyspark.cloudpickle pyspark.worker pyspark.join pyspark.java_gateway pyspark.examples pyspark.shell pyspark.test pyspark.rddsampler pyspark.daemon diff --git a/python/pyspark/accumulators.py b/python/pyspark/accumulators.py index da3d96689a..2204e9c9ca 100644 --- a/python/pyspark/accumulators.py +++ b/python/pyspark/accumulators.py @@ -90,9 +90,11 @@ import struct import SocketServer import threading from pyspark.cloudpickle import CloudPickler -from pyspark.serializers import read_int, read_with_length, load_pickle +from pyspark.serializers import read_int, PickleSerializer +pickleSer = PickleSerializer() + # Holds accumulators registered on the current machine, keyed by ID. This is then used to send # the local accumulator updates back to the driver program at the end of a task. _accumulatorRegistry = {} @@ -211,7 +213,7 @@ class _UpdateRequestHandler(SocketServer.StreamRequestHandler): from pyspark.accumulators import _accumulatorRegistry num_updates = read_int(self.rfile) for _ in range(num_updates): - (aid, update) = load_pickle(read_with_length(self.rfile)) + (aid, update) = pickleSer._read_with_length(self.rfile) _accumulatorRegistry[aid] += update # Write a byte in acknowledgement self.wfile.write(struct.pack("!b", 1)) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 0fec1a6bf6..6bb1c6c3a1 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -26,7 +26,7 @@ from pyspark.accumulators import Accumulator from pyspark.broadcast import Broadcast from pyspark.files import SparkFiles from pyspark.java_gateway import launch_gateway -from pyspark.serializers import dump_pickle, write_with_length, batched +from pyspark.serializers import PickleSerializer, BatchedSerializer, MUTF8Deserializer from pyspark.storagelevel import StorageLevel from pyspark.rdd import RDD @@ -51,7 +51,7 @@ class SparkContext(object): def __init__(self, master, jobName, sparkHome=None, pyFiles=None, - environment=None, batchSize=1024): + environment=None, batchSize=1024, serializer=PickleSerializer()): """ Create a new SparkContext. @@ -67,6 +67,7 @@ class SparkContext(object): @param batchSize: The number of Python objects represented as a single Java object. Set 1 to disable batching or -1 to use an unlimited batch size. + @param serializer: The serializer for RDDs. 
>>> from pyspark.context import SparkContext @@ -83,7 +84,13 @@ class SparkContext(object): self.jobName = jobName self.sparkHome = sparkHome or None # None becomes null in Py4J self.environment = environment or {} - self.batchSize = batchSize # -1 represents a unlimited batch size + self._batchSize = batchSize # -1 represents an unlimited batch size + self._unbatched_serializer = serializer + if batchSize == 1: + self.serializer = self._unbatched_serializer + else: + self.serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) # Create the Java SparkContext through Py4J empty_string_array = self._gateway.new_array(self._jvm.String, 0) @@ -184,15 +191,17 @@ class SparkContext(object): # Make sure we distribute data evenly if it's smaller than self.batchSize if "__len__" not in dir(c): c = list(c) # Make it a list so we can compute its length - batchSize = min(len(c) // numSlices, self.batchSize) + batchSize = min(len(c) // numSlices, self._batchSize) if batchSize > 1: - c = batched(c, batchSize) - for x in c: - write_with_length(dump_pickle(x), tempFile) + serializer = BatchedSerializer(self._unbatched_serializer, + batchSize) + else: + serializer = self._unbatched_serializer + serializer.dump_stream(c, tempFile) tempFile.close() readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile jrdd = readRDDFromFile(self._jsc, tempFile.name, numSlices) - return RDD(jrdd, self) + return RDD(jrdd, self, serializer) def textFile(self, name, minSplits=None): """ @@ -201,21 +210,39 @@ class SparkContext(object): RDD of Strings. """ minSplits = minSplits or min(self.defaultParallelism, 2) - jrdd = self._jsc.textFile(name, minSplits) - return RDD(jrdd, self) + return RDD(self._jsc.textFile(name, minSplits), self, + MUTF8Deserializer()) - def _checkpointFile(self, name): + def _checkpointFile(self, name, input_deserializer): jrdd = self._jsc.checkpointFile(name) - return RDD(jrdd, self) + return RDD(jrdd, self, input_deserializer) def union(self, rdds): """ Build the union of a list of RDDs. + + This supports unions() of RDDs with different serialized formats, + although this forces them to be reserialized using the default + serializer: + + >>> path = os.path.join(tempdir, "union-text.txt") + >>> with open(path, "w") as testFile: + ... testFile.write("Hello") + >>> textFile = sc.textFile(path) + >>> textFile.collect() + [u'Hello'] + >>> parallelized = sc.parallelize(["World!"]) + >>> sorted(sc.union([textFile, parallelized]).collect()) + [u'Hello', 'World!'] """ + first_jrdd_deserializer = rdds[0]._jrdd_deserializer + if any(x._jrdd_deserializer != first_jrdd_deserializer for x in rdds): + rdds = [x._reserialize() for x in rdds] first = rdds[0]._jrdd rest = [x._jrdd for x in rdds[1:]] - rest = ListConverter().convert(rest, self.gateway._gateway_client) - return RDD(self._jsc.union(first, rest), self) + rest = ListConverter().convert(rest, self._gateway._gateway_client) + return RDD(self._jsc.union(first, rest), self, + rdds[0]._jrdd_deserializer) def broadcast(self, value): """ @@ -223,7 +250,9 @@ class SparkContext(object): object for reading it in distributed functions. The variable will be sent to each cluster only once. 
""" - jbroadcast = self._jsc.broadcast(bytearray(dump_pickle(value))) + pickleSer = PickleSerializer() + pickled = pickleSer._dumps(value) + jbroadcast = self._jsc.broadcast(bytearray(pickled)) return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) @@ -235,7 +264,7 @@ class SparkContext(object): and floating-point numbers if you do not provide one. For other types, a custom AccumulatorParam can be used. """ - if accum_param == None: + if accum_param is None: if isinstance(value, int): accum_param = accumulators.INT_ACCUMULATOR_PARAM elif isinstance(value, float): diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index d3c4d13a1e..6691c30519 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -18,7 +18,7 @@ from base64 import standard_b64encode as b64enc import copy from collections import defaultdict -from itertools import chain, ifilter, imap, product +from itertools import chain, ifilter, imap import operator import os import sys @@ -28,8 +28,8 @@ from tempfile import NamedTemporaryFile from threading import Thread from pyspark import cloudpickle -from pyspark.serializers import batched, Batch, dump_pickle, load_pickle, \ - read_from_pickle_file, pack_long +from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ + BatchedSerializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -48,13 +48,12 @@ class RDD(object): operated on in parallel. """ - def __init__(self, jrdd, ctx): + def __init__(self, jrdd, ctx, jrdd_deserializer): self._jrdd = jrdd self.is_cached = False self.is_checkpointed = False self.ctx = ctx - self._partitionFunc = None - self._stage_input_is_pairs = False + self._jrdd_deserializer = jrdd_deserializer @property def context(self): @@ -248,7 +247,23 @@ class RDD(object): >>> rdd.union(rdd).collect() [1, 1, 2, 3, 1, 1, 2, 3] """ - return RDD(self._jrdd.union(other._jrdd), self.ctx) + if self._jrdd_deserializer == other._jrdd_deserializer: + rdd = RDD(self._jrdd.union(other._jrdd), self.ctx, + self._jrdd_deserializer) + return rdd + else: + # These RDDs contain data in different serialized formats, so we + # must normalize them to the default serializer. + self_copy = self._reserialize() + other_copy = other._reserialize() + return RDD(self_copy._jrdd.union(other_copy._jrdd), self.ctx, + self.ctx.serializer) + + def _reserialize(self): + if self._jrdd_deserializer == self.ctx.serializer: + return self + else: + return self.map(lambda x: x, preservesPartitioning=True) def __add__(self, other): """ @@ -335,18 +350,9 @@ class RDD(object): [(1, 1), (1, 2), (2, 1), (2, 2)] """ # Due to batching, we can't use the Java cartesian method. 
- java_cartesian = RDD(self._jrdd.cartesian(other._jrdd), self.ctx) - def unpack_batches(pair): - (x, y) = pair - if type(x) == Batch or type(y) == Batch: - xs = x.items if type(x) == Batch else [x] - ys = y.items if type(y) == Batch else [y] - for pair in product(xs, ys): - yield pair - else: - yield pair - java_cartesian._stage_input_is_pairs = True - return java_cartesian.flatMap(unpack_batches) + deserializer = CartesianDeserializer(self._jrdd_deserializer, + other._jrdd_deserializer) + return RDD(self._jrdd.cartesian(other._jrdd), self.ctx, deserializer) def groupBy(self, f, numPartitions=None): """ @@ -405,7 +411,7 @@ class RDD(object): self.ctx._writeToFile(iterator, tempFile.name) # Read the data into Python and deserialize it: with open(tempFile.name, 'rb') as tempFile: - for item in read_from_pickle_file(tempFile): + for item in self._jrdd_deserializer.load_stream(tempFile): yield item os.unlink(tempFile.name) @@ -573,7 +579,7 @@ class RDD(object): items = [] for partition in range(mapped._jrdd.splits().size()): iterator = self.ctx._takePartition(mapped._jrdd.rdd(), partition) - items.extend(self._collect_iterator_through_file(iterator)) + items.extend(mapped._collect_iterator_through_file(iterator)) if len(items) >= num: break return items[:num] @@ -737,6 +743,7 @@ class RDD(object): # Transferring O(n) objects to Java is too expensive. Instead, we'll # form the hash buckets in Python, transferring O(numPartitions) objects # to Java. Each object is a (splitNumber, [objects]) pair. + outputSerializer = self.ctx._unbatched_serializer def add_shuffle_key(split, iterator): buckets = defaultdict(list) @@ -745,14 +752,14 @@ class RDD(object): buckets[partitionFunc(k) % numPartitions].append((k, v)) for (split, items) in buckets.iteritems(): yield pack_long(split) - yield dump_pickle(Batch(items)) + yield outputSerializer._dumps(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() partitioner = self.ctx._jvm.PythonPartitioner(numPartitions, id(partitionFunc)) jrdd = pairRDD.partitionBy(partitioner).values() - rdd = RDD(jrdd, self.ctx) + rdd = RDD(jrdd, self.ctx, BatchedSerializer(outputSerializer)) # This is required so that id(partitionFunc) remains unique, even if # partitionFunc is a lambda: rdd._partitionFunc = partitionFunc @@ -789,7 +796,8 @@ class RDD(object): numPartitions = self.ctx.defaultParallelism def combineLocally(iterator): combiners = {} - for (k, v) in iterator: + for x in iterator: + (k, v) = x if k not in combiners: combiners[k] = createCombiner(v) else: @@ -931,38 +939,38 @@ class PipelinedRDD(RDD): 20 """ def __init__(self, prev, func, preservesPartitioning=False): - if isinstance(prev, PipelinedRDD) and prev._is_pipelinable(): + if not isinstance(prev, PipelinedRDD) or not prev._is_pipelinable(): + # This transformation is the first in its stage: + self.func = func + self.preservesPartitioning = preservesPartitioning + self._prev_jrdd = prev._jrdd + self._prev_jrdd_deserializer = prev._jrdd_deserializer + else: prev_func = prev.func def pipeline_func(split, iterator): return func(split, prev_func(split, iterator)) self.func = pipeline_func self.preservesPartitioning = \ prev.preservesPartitioning and preservesPartitioning - self._prev_jrdd = prev._prev_jrdd - else: - self.func = func - self.preservesPartitioning = preservesPartitioning - self._prev_jrdd = prev._jrdd - self._stage_input_is_pairs = prev._stage_input_is_pairs + self._prev_jrdd = prev._prev_jrdd # 
maintain the pipeline + self._prev_jrdd_deserializer = prev._prev_jrdd_deserializer self.is_cached = False self.is_checkpointed = False self.ctx = prev.ctx self.prev = prev self._jrdd_val = None + self._jrdd_deserializer = self.ctx.serializer self._bypass_serializer = False @property def _jrdd(self): if self._jrdd_val: return self._jrdd_val - func = self.func - if not self._bypass_serializer and self.ctx.batchSize != 1: - oldfunc = self.func - batchSize = self.ctx.batchSize - def batched_func(split, iterator): - return batched(oldfunc(split, iterator), batchSize) - func = batched_func - cmds = [func, self._bypass_serializer, self._stage_input_is_pairs] + if self._bypass_serializer: + serializer = NoOpSerializer() + else: + serializer = self.ctx.serializer + cmds = [self.func, self._prev_jrdd_deserializer, serializer] pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index fd02e1ee8f..4fb444443f 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -15,8 +15,58 @@ # limitations under the License. # -import struct +""" +PySpark supports custom serializers for transferring data; this can improve +performance. + +By default, PySpark uses L{PickleSerializer} to serialize objects using Python's +C{cPickle} serializer, which can serialize nearly any Python object. +Other serializers, like L{MarshalSerializer}, support fewer datatypes but can be +faster. + +The serializer is chosen when creating L{SparkContext}: + +>>> from pyspark.context import SparkContext +>>> from pyspark.serializers import MarshalSerializer +>>> sc = SparkContext('local', 'test', serializer=MarshalSerializer()) +>>> sc.parallelize(list(range(1000))).map(lambda x: 2 * x).take(10) +[0, 2, 4, 6, 8, 10, 12, 14, 16, 18] +>>> sc.stop() + +By default, PySpark serialize objects in batches; the batch size can be +controlled through SparkContext's C{batchSize} parameter +(the default size is 1024 objects): + +>>> sc = SparkContext('local', 'test', batchSize=2) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) + +Behind the scenes, this creates a JavaRDD with four partitions, each of +which contains two batches of two objects: + +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +8L +>>> sc.stop() + +A batch size of -1 uses an unlimited batch size, and a size of 1 disables +batching: + +>>> sc = SparkContext('local', 'test', batchSize=1) +>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x) +>>> rdd.glom().collect() +[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]] +>>> rdd._jrdd.count() +16L +""" + import cPickle +from itertools import chain, izip, product +import marshal +import struct + + +__all__ = ["PickleSerializer", "MarshalSerializer"] class SpecialLengths(object): @@ -25,41 +75,206 @@ class SpecialLengths(object): TIMING_DATA = -3 -class Batch(object): +class Serializer(object): + + def dump_stream(self, iterator, stream): + """ + Serialize an iterator of objects to the output stream. + """ + raise NotImplementedError + + def load_stream(self, stream): + """ + Return an iterator of deserialized objects from the input stream. 
+ """ + raise NotImplementedError + + + def _load_stream_without_unbatching(self, stream): + return self.load_stream(stream) + + # Note: our notion of "equality" is that output generated by + # equal serializers can be deserialized using the same serializer. + + # This default implementation handles the simple cases; + # subclasses should override __eq__ as appropriate. + + def __eq__(self, other): + return isinstance(other, self.__class__) + + def __ne__(self, other): + return not self.__eq__(other) + + +class FramedSerializer(Serializer): + """ + Serializer that writes objects as a stream of (length, data) pairs, + where C{length} is a 32-bit integer and data is C{length} bytes. + """ + + def dump_stream(self, iterator, stream): + for obj in iterator: + self._write_with_length(obj, stream) + + def load_stream(self, stream): + while True: + try: + yield self._read_with_length(stream) + except EOFError: + return + + def _write_with_length(self, obj, stream): + serialized = self._dumps(obj) + write_int(len(serialized), stream) + stream.write(serialized) + + def _read_with_length(self, stream): + length = read_int(stream) + obj = stream.read(length) + if obj == "": + raise EOFError + return self._loads(obj) + + def _dumps(self, obj): + """ + Serialize an object into a byte array. + When batching is used, this will be called with an array of objects. + """ + raise NotImplementedError + + def _loads(self, obj): + """ + Deserialize an object from a byte array. + """ + raise NotImplementedError + + +class BatchedSerializer(Serializer): + """ + Serializes a stream of objects in batches by calling its wrapped + Serializer with streams of objects. + """ + + UNLIMITED_BATCH_SIZE = -1 + + def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE): + self.serializer = serializer + self.batchSize = batchSize + + def _batched(self, iterator): + if self.batchSize == self.UNLIMITED_BATCH_SIZE: + yield list(iterator) + else: + items = [] + count = 0 + for item in iterator: + items.append(item) + count += 1 + if count == self.batchSize: + yield items + items = [] + count = 0 + if items: + yield items + + def dump_stream(self, iterator, stream): + if isinstance(iterator, basestring): + iterator = [iterator] + self.serializer.dump_stream(self._batched(iterator), stream) + + def load_stream(self, stream): + return chain.from_iterable(self._load_stream_without_unbatching(stream)) + + def _load_stream_without_unbatching(self, stream): + return self.serializer.load_stream(stream) + + def __eq__(self, other): + return isinstance(other, BatchedSerializer) and \ + other.serializer == self.serializer + + def __str__(self): + return "BatchedSerializer<%s>" % str(self.serializer) + + +class CartesianDeserializer(FramedSerializer): """ - Used to store multiple RDD entries as a single Java object. + Deserializes the JavaRDD cartesian() of two PythonRDDs. 
+ """ + + def __init__(self, key_ser, val_ser): + self.key_ser = key_ser + self.val_ser = val_ser + + def load_stream(self, stream): + key_stream = self.key_ser._load_stream_without_unbatching(stream) + val_stream = self.val_ser._load_stream_without_unbatching(stream) + key_is_batched = isinstance(self.key_ser, BatchedSerializer) + val_is_batched = isinstance(self.val_ser, BatchedSerializer) + for (keys, vals) in izip(key_stream, val_stream): + keys = keys if key_is_batched else [keys] + vals = vals if val_is_batched else [vals] + for pair in product(keys, vals): + yield pair + + def __eq__(self, other): + return isinstance(other, CartesianDeserializer) and \ + self.key_ser == other.key_ser and self.val_ser == other.val_ser + + def __str__(self): + return "CartesianDeserializer<%s, %s>" % \ + (str(self.key_ser), str(self.val_ser)) + + +class NoOpSerializer(FramedSerializer): + + def _loads(self, obj): return obj + def _dumps(self, obj): return obj + + +class PickleSerializer(FramedSerializer): + """ + Serializes objects using Python's cPickle serializer: + + http://docs.python.org/2/library/pickle.html + + This serializer supports nearly any Python object, but may + not be as fast as more specialized serializers. + """ + + def _dumps(self, obj): return cPickle.dumps(obj, 2) + _loads = cPickle.loads + - This relieves us from having to explicitly track whether an RDD - is stored as batches of objects and avoids problems when processing - the union() of batched and unbatched RDDs (e.g. the union() of textFile() - with another RDD). +class MarshalSerializer(FramedSerializer): """ - def __init__(self, items): - self.items = items + Serializes objects using Python's Marshal serializer: + http://docs.python.org/2/library/marshal.html -def batched(iterator, batchSize): - if batchSize == -1: # unlimited batch size - yield Batch(list(iterator)) - else: - items = [] - count = 0 - for item in iterator: - items.append(item) - count += 1 - if count == batchSize: - yield Batch(items) - items = [] - count = 0 - if items: - yield Batch(items) + This serializer is faster than PickleSerializer but supports fewer datatypes. + """ + + _dumps = marshal.dumps + _loads = marshal.loads -def dump_pickle(obj): - return cPickle.dumps(obj, 2) +class MUTF8Deserializer(Serializer): + """ + Deserializes streams written by Java's DataOutputStream.writeUTF(). + """ + def _loads(self, stream): + length = struct.unpack('>H', stream.read(2))[0] + return stream.read(length).decode('utf8') -load_pickle = cPickle.loads + def load_stream(self, stream): + while True: + try: + yield self._loads(stream) + except struct.error: + return + except EOFError: + return def read_long(stream): @@ -90,43 +305,4 @@ def write_int(value, stream): def write_with_length(obj, stream): write_int(len(obj), stream) - stream.write(obj) - - -def read_mutf8(stream): - """ - Read a string written with Java's DataOutputStream.writeUTF() method. 
- """ - length = struct.unpack('>H', stream.read(2))[0] - return stream.read(length).decode('utf8') - - -def read_with_length(stream): - length = read_int(stream) - obj = stream.read(length) - if obj == "": - raise EOFError - return obj - - -def read_from_pickle_file(stream): - try: - while True: - obj = load_pickle(read_with_length(stream)) - if type(obj) == Batch: # We don't care about inheritance - for item in obj.items: - yield item - else: - yield obj - except EOFError: - return - - -def read_pairs_from_pickle_file(stream): - try: - while True: - a = load_pickle(read_with_length(stream)) - b = load_pickle(read_with_length(stream)) - yield (a, b) - except EOFError: - return \ No newline at end of file + stream.write(obj) \ No newline at end of file diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 29d6a128f6..621e1cb58c 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -86,7 +86,8 @@ class TestCheckpoint(PySparkTestCase): time.sleep(1) # 1 second self.assertTrue(flatMappedRDD.getCheckpointFile() is not None) - recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile()) + recovered = self.sc._checkpointFile(flatMappedRDD.getCheckpointFile(), + flatMappedRDD._jrdd_deserializer) self.assertEquals([1, 2, 3, 4], recovered.collect()) diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 4e64557fc4..5b16d5db7e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -30,13 +30,17 @@ from pyspark.accumulators import _accumulatorRegistry from pyspark.broadcast import Broadcast, _broadcastRegistry from pyspark.cloudpickle import CloudPickler from pyspark.files import SparkFiles -from pyspark.serializers import write_with_length, read_with_length, write_int, \ - read_long, write_long, read_int, dump_pickle, load_pickle, read_from_pickle_file, \ - SpecialLengths, read_mutf8, read_pairs_from_pickle_file +from pyspark.serializers import write_with_length, write_int, read_long, \ + write_long, read_int, SpecialLengths, MUTF8Deserializer, PickleSerializer + + +pickleSer = PickleSerializer() +mutf8_deserializer = MUTF8Deserializer() def load_obj(infile): - return load_pickle(standard_b64decode(infile.readline().strip())) + decoded = standard_b64decode(infile.readline().strip()) + return pickleSer._loads(decoded) def report_times(outfile, boot, init, finish): @@ -53,7 +57,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = read_mutf8(infile) + spark_files_dir = mutf8_deserializer._loads(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -61,31 +65,24 @@ def main(infile, outfile): num_broadcast_variables = read_int(infile) for _ in range(num_broadcast_variables): bid = read_long(infile) - value = read_with_length(infile) - _broadcastRegistry[bid] = Broadcast(bid, load_pickle(value)) + value = pickleSer._read_with_length(infile) + _broadcastRegistry[bid] = Broadcast(bid, value) # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - sys.path.append(os.path.join(spark_files_dir, read_mutf8(infile))) + filename = mutf8_deserializer._loads(infile) + sys.path.append(os.path.join(spark_files_dir, filename)) - # now load function + # Load this stage's function and serializer: func = load_obj(infile) - bypassSerializer = load_obj(infile) - stageInputIsPairs = load_obj(infile) - if 
bypassSerializer: - dumps = lambda x: x - else: - dumps = dump_pickle + deserializer = load_obj(infile) + serializer = load_obj(infile) init_time = time.time() - if stageInputIsPairs: - iterator = read_pairs_from_pickle_file(infile) - else: - iterator = read_from_pickle_file(infile) try: - for obj in func(split_index, iterator): - write_with_length(dumps(obj), outfile) + iterator = deserializer.load_stream(infile) + serializer.dump_stream(func(split_index, iterator), outfile) except Exception as e: write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) write_with_length(traceback.format_exc(), outfile) @@ -96,7 +93,7 @@ def main(infile, outfile): write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) write_int(len(_accumulatorRegistry), outfile) for (aid, accum) in _accumulatorRegistry.items(): - write_with_length(dump_pickle((aid, accum._value)), outfile) + pickleSer._write_with_length((aid, accum._value), outfile) if __name__ == '__main__': diff --git a/python/run-tests b/python/run-tests index cbc554ea9d..d4dad672d2 100755 --- a/python/run-tests +++ b/python/run-tests @@ -37,6 +37,7 @@ run_test "pyspark/rdd.py" run_test "pyspark/context.py" run_test "-m doctest pyspark/broadcast.py" run_test "-m doctest pyspark/accumulators.py" +run_test "-m doctest pyspark/serializers.py" run_test "pyspark/tests.py" if [[ $FAILED != 0 ]]; then -- cgit v1.2.3 From ffa5bedf46fbc89ad5c5658f3b423dfff49b70f0 Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 Nov 2013 12:58:28 -0800 Subject: Send PySpark commands as bytes insetad of strings. --- .../org/apache/spark/api/python/PythonRDD.scala | 24 ++++------------------ python/pyspark/rdd.py | 12 +++++------ python/pyspark/serializers.py | 5 +++++ python/pyspark/worker.py | 12 ++--------- 4 files changed, 17 insertions(+), 36 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index ef9bf4db9b..132e4fb0d2 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -27,13 +27,12 @@ import org.apache.spark.api.java.{JavaSparkContext, JavaPairRDD, JavaRDD} import org.apache.spark.broadcast.Broadcast import org.apache.spark._ import org.apache.spark.rdd.RDD -import org.apache.spark.rdd.PipedRDD import org.apache.spark.util.Utils private[spark] class PythonRDD[T: ClassManifest]( parent: RDD[T], - command: Seq[String], + command: Array[Byte], envVars: JMap[String, String], pythonIncludes: JList[String], preservePartitoning: Boolean, @@ -44,21 +43,10 @@ private[spark] class PythonRDD[T: ClassManifest]( val bufferSize = System.getProperty("spark.buffer.size", "65536").toInt - // Similar to Runtime.exec(), if we are given a single string, split it into words - // using a standard StringTokenizer (i.e. 
by spaces) - def this(parent: RDD[T], command: String, envVars: JMap[String, String], - pythonIncludes: JList[String], - preservePartitoning: Boolean, pythonExec: String, - broadcastVars: JList[Broadcast[Array[Byte]]], - accumulator: Accumulator[JList[Array[Byte]]]) = - this(parent, PipedRDD.tokenize(command), envVars, pythonIncludes, preservePartitoning, pythonExec, - broadcastVars, accumulator) - override def getPartitions = parent.partitions override val partitioner = if (preservePartitoning) parent.partitioner else None - override def compute(split: Partition, context: TaskContext): Iterator[Array[Byte]] = { val startTime = System.currentTimeMillis val env = SparkEnv.get @@ -71,7 +59,6 @@ private[spark] class PythonRDD[T: ClassManifest]( SparkEnv.set(env) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) - val printOut = new PrintWriter(stream) // Partition index dataOut.writeInt(split.index) // sparkFilesDir @@ -87,17 +74,14 @@ private[spark] class PythonRDD[T: ClassManifest]( dataOut.writeInt(pythonIncludes.length) pythonIncludes.foreach(dataOut.writeUTF) dataOut.flush() - // Serialized user code - for (elem <- command) { - printOut.println(elem) - } - printOut.flush() + // Serialized command: + dataOut.writeInt(command.length) + dataOut.write(command) // Data values for (elem <- parent.iterator(split, context)) { PythonRDD.writeToStream(elem, dataOut) } dataOut.flush() - printOut.flush() worker.shutdownOutput() } catch { case e: IOException => diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 6691c30519..062f44f81e 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -27,9 +27,8 @@ from subprocess import Popen, PIPE from tempfile import NamedTemporaryFile from threading import Thread -from pyspark import cloudpickle from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \ - BatchedSerializer, pack_long + BatchedSerializer, CloudPickleSerializer, pack_long from pyspark.join import python_join, python_left_outer_join, \ python_right_outer_join, python_cogroup from pyspark.statcounter import StatCounter @@ -970,8 +969,8 @@ class PipelinedRDD(RDD): serializer = NoOpSerializer() else: serializer = self.ctx.serializer - cmds = [self.func, self._prev_jrdd_deserializer, serializer] - pipe_command = ' '.join(b64enc(cloudpickle.dumps(f)) for f in cmds) + command = (self.func, self._prev_jrdd_deserializer, serializer) + pickled_command = CloudPickleSerializer()._dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) @@ -982,8 +981,9 @@ class PipelinedRDD(RDD): includes = ListConverter().convert(self.ctx._python_includes, self.ctx._gateway._gateway_client) python_rdd = self.ctx._jvm.PythonRDD(self._prev_jrdd.rdd(), - pipe_command, env, includes, self.preservesPartitioning, self.ctx.pythonExec, - broadcast_vars, self.ctx._javaAccumulator, class_manifest) + bytearray(pickled_command), env, includes, self.preservesPartitioning, + self.ctx.pythonExec, broadcast_vars, self.ctx._javaAccumulator, + class_manifest) self._jrdd_val = python_rdd.asJavaRDD() return self._jrdd_val diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 4fb444443f..b23804b33c 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -64,6 +64,7 @@ import cPickle from itertools import chain, izip, product import marshal import struct +from pyspark import cloudpickle __all__ 
= ["PickleSerializer", "MarshalSerializer"] @@ -244,6 +245,10 @@ class PickleSerializer(FramedSerializer): def _dumps(self, obj): return cPickle.dumps(obj, 2) _loads = cPickle.loads +class CloudPickleSerializer(PickleSerializer): + + def _dumps(self, obj): return cloudpickle.dumps(obj, 2) + class MarshalSerializer(FramedSerializer): """ diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 5b16d5db7e..2751f1239e 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -23,7 +23,6 @@ import sys import time import socket import traceback -from base64 import standard_b64decode # CloudPickler needs to be imported so that depicklers are registered using the # copy_reg module. from pyspark.accumulators import _accumulatorRegistry @@ -38,11 +37,6 @@ pickleSer = PickleSerializer() mutf8_deserializer = MUTF8Deserializer() -def load_obj(infile): - decoded = standard_b64decode(infile.readline().strip()) - return pickleSer._loads(decoded) - - def report_times(outfile, boot, init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) write_long(1000 * boot, outfile) @@ -75,10 +69,8 @@ def main(infile, outfile): filename = mutf8_deserializer._loads(infile) sys.path.append(os.path.join(spark_files_dir, filename)) - # Load this stage's function and serializer: - func = load_obj(infile) - deserializer = load_obj(infile) - serializer = load_obj(infile) + command = pickleSer._read_with_length(infile) + (func, deserializer, serializer) = command init_time = time.time() try: iterator = deserializer.load_stream(infile) -- cgit v1.2.3 From 13122ceb8c74dc0c4ad37902a3d1b30bf273cc6a Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Sun, 10 Nov 2013 17:48:27 -0800 Subject: FramedSerializer: _dumps => dumps, _loads => loads. --- python/pyspark/context.py | 2 +- python/pyspark/rdd.py | 4 ++-- python/pyspark/serializers.py | 26 +++++++++++++------------- python/pyspark/worker.py | 4 ++-- 4 files changed, 18 insertions(+), 18 deletions(-) diff --git a/python/pyspark/context.py b/python/pyspark/context.py index 6bb1c6c3a1..cbd41e58c4 100644 --- a/python/pyspark/context.py +++ b/python/pyspark/context.py @@ -251,7 +251,7 @@ class SparkContext(object): sent to each cluster only once. 
""" pickleSer = PickleSerializer() - pickled = pickleSer._dumps(value) + pickled = pickleSer.dumps(value) jbroadcast = self._jsc.broadcast(bytearray(pickled)) return Broadcast(jbroadcast.id(), value, jbroadcast, self._pickled_broadcast_vars) diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py index 062f44f81e..957f3f89c0 100644 --- a/python/pyspark/rdd.py +++ b/python/pyspark/rdd.py @@ -751,7 +751,7 @@ class RDD(object): buckets[partitionFunc(k) % numPartitions].append((k, v)) for (split, items) in buckets.iteritems(): yield pack_long(split) - yield outputSerializer._dumps(items) + yield outputSerializer.dumps(items) keyed = PipelinedRDD(self, add_shuffle_key) keyed._bypass_serializer = True pairRDD = self.ctx._jvm.PairwiseRDD(keyed._jrdd.rdd()).asJavaPairRDD() @@ -970,7 +970,7 @@ class PipelinedRDD(RDD): else: serializer = self.ctx.serializer command = (self.func, self._prev_jrdd_deserializer, serializer) - pickled_command = CloudPickleSerializer()._dumps(command) + pickled_command = CloudPickleSerializer().dumps(command) broadcast_vars = ListConverter().convert( [x._jbroadcast for x in self.ctx._pickled_broadcast_vars], self.ctx._gateway._gateway_client) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index b23804b33c..9338df69ff 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -125,7 +125,7 @@ class FramedSerializer(Serializer): return def _write_with_length(self, obj, stream): - serialized = self._dumps(obj) + serialized = self.dumps(obj) write_int(len(serialized), stream) stream.write(serialized) @@ -134,16 +134,16 @@ class FramedSerializer(Serializer): obj = stream.read(length) if obj == "": raise EOFError - return self._loads(obj) + return self.loads(obj) - def _dumps(self, obj): + def dumps(self, obj): """ Serialize an object into a byte array. When batching is used, this will be called with an array of objects. """ raise NotImplementedError - def _loads(self, obj): + def loads(self, obj): """ Deserialize an object from a byte array. """ @@ -228,8 +228,8 @@ class CartesianDeserializer(FramedSerializer): class NoOpSerializer(FramedSerializer): - def _loads(self, obj): return obj - def _dumps(self, obj): return obj + def loads(self, obj): return obj + def dumps(self, obj): return obj class PickleSerializer(FramedSerializer): @@ -242,12 +242,12 @@ class PickleSerializer(FramedSerializer): not be as fast as more specialized serializers. """ - def _dumps(self, obj): return cPickle.dumps(obj, 2) - _loads = cPickle.loads + def dumps(self, obj): return cPickle.dumps(obj, 2) + loads = cPickle.loads class CloudPickleSerializer(PickleSerializer): - def _dumps(self, obj): return cloudpickle.dumps(obj, 2) + def dumps(self, obj): return cloudpickle.dumps(obj, 2) class MarshalSerializer(FramedSerializer): @@ -259,8 +259,8 @@ class MarshalSerializer(FramedSerializer): This serializer is faster than PickleSerializer but supports fewer datatypes. """ - _dumps = marshal.dumps - _loads = marshal.loads + dumps = marshal.dumps + loads = marshal.loads class MUTF8Deserializer(Serializer): @@ -268,14 +268,14 @@ class MUTF8Deserializer(Serializer): Deserializes streams written by Java's DataOutputStream.writeUTF(). 
""" - def _loads(self, stream): + def loads(self, stream): length = struct.unpack('>H', stream.read(2))[0] return stream.read(length).decode('utf8') def load_stream(self, stream): while True: try: - yield self._loads(stream) + yield self.loads(stream) except struct.error: return except EOFError: diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py index 2751f1239e..f2b3f3c142 100644 --- a/python/pyspark/worker.py +++ b/python/pyspark/worker.py @@ -51,7 +51,7 @@ def main(infile, outfile): return # fetch name of workdir - spark_files_dir = mutf8_deserializer._loads(infile) + spark_files_dir = mutf8_deserializer.loads(infile) SparkFiles._root_directory = spark_files_dir SparkFiles._is_running_on_worker = True @@ -66,7 +66,7 @@ def main(infile, outfile): sys.path.append(spark_files_dir) # *.py files that were added will be copied here num_python_includes = read_int(infile) for _ in range(num_python_includes): - filename = mutf8_deserializer._loads(infile) + filename = mutf8_deserializer.loads(infile) sys.path.append(os.path.join(spark_files_dir, filename)) command = pickleSer._read_with_length(infile) -- cgit v1.2.3 From 7de180fd13fda2e5d4486dfca9e2a9997ec7f4d0 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 18 Nov 2013 20:05:05 -0800 Subject: Remove explicit boxing --- core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index b002468442..70f7f01d2b 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -169,9 +169,9 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav * If the RDD contains infinity, NaN throws an exception * If the elements in RDD do not vary (max == min) always returns a single bucket. 
*/ - def histogram(bucketCount: Int): Pair[Array[Double], Array[Long]] = { + def histogram(bucketCount: Int): Pair[Array[scala.Double], Array[Long]] = { val result = srdd.histogram(bucketCount) - (result._1.map(scala.Double.box(_)), result._2) + (result._1, result._2) } /** -- cgit v1.2.3 From e163e31c2003558d304ba5ac7b67361956037041 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 18 Nov 2013 20:13:25 -0800 Subject: Add spaces --- .../test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala index 0d8ac19024..7f50a5a47c 100644 --- a/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala +++ b/core/src/test/scala/org/apache/spark/rdd/DoubleRDDSuite.scala @@ -39,6 +39,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } + test("WorksWithOutOfRangeWithOneBucket") { // Verify that if all of the elements are out of range the counts are zero val rdd = sc.parallelize(Seq(10.01, -0.01)) @@ -49,6 +50,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } + test("WorksInRangeWithOneBucket") { // Verify the basic case of one bucket and all elements in that bucket works val rdd = sc.parallelize(Seq(1, 2, 3, 4)) @@ -59,6 +61,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } + test("WorksInRangeWithOneBucketExactMatch") { // Verify the basic case of one bucket and all elements in that bucket works val rdd = sc.parallelize(Seq(1, 2, 3, 4)) @@ -69,6 +72,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } + test("WorksWithOutOfRangeWithTwoBuckets") { // Verify that out of range works with two buckets val rdd = sc.parallelize(Seq(10.01, -0.01)) @@ -79,6 +83,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } + test("WorksWithOutOfRangeWithTwoUnEvenBuckets") { // Verify that out of range works with two un even buckets val rdd = sc.parallelize(Seq(10.01, -0.01)) @@ -87,6 +92,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { val expectedHistogramResults = Array(0, 0) assert(histogramResults === expectedHistogramResults) } + test("WorksInRangeWithTwoBuckets") { // Make sure that it works with two equally spaced buckets and elements in each val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6)) @@ -97,6 +103,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } + test("WorksInRangeWithTwoBucketsAndNaN") { // Make sure that it works with two equally spaced buckets and elements in each val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6, Double.NaN)) @@ -107,6 +114,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramResults2 === expectedHistogramResults) } + 
test("WorksInRangeWithTwoUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each val rdd = sc.parallelize(Seq(1, 2, 3, 5, 6)) @@ -115,6 +123,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { val expectedHistogramResults = Array(3, 2) assert(histogramResults === expectedHistogramResults) } + test("WorksMixedRangeWithTwoUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.0, 11.01)) @@ -123,6 +132,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { val expectedHistogramResults = Array(4, 3) assert(histogramResults === expectedHistogramResults) } + test("WorksMixedRangeWithFourUnevenBuckets") { // Make sure that it works with two unequally spaced buckets and elements in each val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, @@ -132,6 +142,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { val expectedHistogramResults = Array(4, 2, 1, 3) assert(histogramResults === expectedHistogramResults) } + test("WorksMixedRangeWithUnevenBucketsAndNaN") { // Make sure that it works with two unequally spaced buckets and elements in each val rdd = sc.parallelize(Seq(-0.01, 0.0, 1, 2, 3, 5, 6, 11.01, 12.0, 199.0, @@ -161,6 +172,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { val expectedHistogramResults = Array(4, 2, 1, 2, 4) assert(histogramResults === expectedHistogramResults) } + test("WorksWithOutOfRangeWithInfiniteBuckets") { // Verify that out of range works with two buckets val rdd = sc.parallelize(Seq(10.01, -0.01, Double.NaN)) @@ -214,6 +226,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithoutBucketsBasicTwo") { // Verify the basic case of one bucket and all elements in that bucket works val rdd = sc.parallelize(Seq(1, 2, 3, 4)) @@ -223,6 +236,7 @@ class DoubleRDDSuite extends FunSuite with SharedSparkContext { assert(histogramResults === expectedHistogramResults) assert(histogramBuckets === expectedHistogramBuckets) } + test("WorksWithoutBucketsWithMoreRequestedThanElements") { // Verify the basic case of one bucket and all elements in that bucket works val rdd = sc.parallelize(Seq(1, 2)) -- cgit v1.2.3 From ab3cefde5349d0de85b23b49feef493ff0b2d1ed Mon Sep 17 00:00:00 2001 From: Raymond Liu Date: Wed, 23 Oct 2013 09:42:25 +0800 Subject: Add YarnClientClusterScheduler and Backend. With this scheduler, the user application is launched locally, While the executor will be launched by YARN on remote nodes. This enables spark-shell to run upon YARN. 
--- .../main/scala/org/apache/spark/SparkContext.scala | 25 +++ docs/running-on-yarn.md | 27 ++- .../org/apache/spark/deploy/yarn/Client.scala | 13 +- .../apache/spark/deploy/yarn/ClientArguments.scala | 40 ++-- .../apache/spark/deploy/yarn/WorkerLauncher.scala | 246 +++++++++++++++++++++ .../cluster/YarnClientClusterScheduler.scala | 47 ++++ .../cluster/YarnClientSchedulerBackend.scala | 109 +++++++++ 7 files changed, 484 insertions(+), 23 deletions(-) create mode 100644 yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala create mode 100644 yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala create mode 100644 yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 42b2985b50..3a80241daa 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -226,6 +226,31 @@ class SparkContext( scheduler.initialize(backend) scheduler + case "yarn-client" => + val scheduler = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientClusterScheduler") + val cons = clazz.getConstructor(classOf[SparkContext]) + cons.newInstance(this).asInstanceOf[ClusterScheduler] + + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + val backend = try { + val clazz = Class.forName("org.apache.spark.scheduler.cluster.YarnClientSchedulerBackend") + val cons = clazz.getConstructor(classOf[ClusterScheduler], classOf[SparkContext]) + cons.newInstance(scheduler, this).asInstanceOf[CoarseGrainedSchedulerBackend] + } catch { + case th: Throwable => { + throw new SparkException("YARN mode not available ?", th) + } + } + + scheduler.initialize(backend) + scheduler + case MESOS_REGEX(mesosUrl) => MesosNativeLibrary.load() val scheduler = new ClusterScheduler(this) diff --git a/docs/running-on-yarn.md b/docs/running-on-yarn.md index 4056e9c15d..68fd6c2ab1 100644 --- a/docs/running-on-yarn.md +++ b/docs/running-on-yarn.md @@ -45,6 +45,10 @@ System Properties: Ensure that HADOOP_CONF_DIR or YARN_CONF_DIR points to the directory which contains the (client side) configuration files for the hadoop cluster. This would be used to connect to the cluster, write to the dfs and submit jobs to the resource manager. +There are two scheduler mode that can be used to launch spark application on YARN. + +## Launch spark application by YARN Client with yarn-standalone mode. + The command to launch the YARN Client is as follows: SPARK_JAR= ./spark-class org.apache.spark.deploy.yarn.Client \ @@ -52,6 +56,7 @@ The command to launch the YARN Client is as follows: --class \ --args \ --num-workers \ + --master-class --master-memory \ --worker-memory \ --worker-cores \ @@ -85,11 +90,29 @@ For example: $ cat $YARN_APP_LOGS_DIR/$YARN_APP_ID/container*_000001/stdout Pi is roughly 3.13794 -The above starts a YARN Client programs which periodically polls the Application Master for status updates and displays them in the console. The client will exit once your application has finished running. +The above starts a YARN Client programs which start the default Application Master. Then SparkPi will be run as a child thread of Application Master, YARN Client will periodically polls the Application Master for status updates and displays them in the console. 
The client will exit once your application has finished running. + +In this mode, your application actually runs on the remote machine where the Application Master runs, so applications that involve local interaction, such as spark-shell, will not work well. + +## Launch spark application with yarn-client mode. + +With yarn-client mode, the application is launched locally, just like running an application or spark-shell in Local / Mesos / Standalone mode. The launch method is the same; just make sure that wherever you need to specify a master URL, you use "yarn-client" instead. You also need to export the environment variables SPARK_JAR and SPARK_YARN_APP_JAR. + +To tune the worker cores, number of workers, memory, etc., export SPARK_WORKER_CORES, SPARK_WORKER_MEMORY and SPARK_WORKER_INSTANCES, e.g. in ./conf/spark-env.sh. + +For example: + + SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \ + SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ + ./run-example org.apache.spark.examples.SparkPi yarn-client + + + SPARK_JAR=./assembly/target/scala-{{site.SCALA_VERSION}}/spark-assembly-{{site.SPARK_VERSION}}-hadoop2.0.5-alpha.jar \ + SPARK_YARN_APP_JAR=examples/target/scala-{{site.SCALA_VERSION}}/spark-examples-assembly-{{site.SPARK_VERSION}}.jar \ + MASTER=yarn-client ./spark-shell # Important Notes -- When your application instantiates a Spark context it must use a special "yarn-standalone" master url. This starts the scheduler without forcing it to connect to a cluster. A good way to handle this is to pass "yarn-standalone" as an argument to your program, as shown in the example above. - We do not requesting container resources based on the number of cores. Thus the numbers of cores given via command line arguments cannot be guaranteed. - The local directories used for spark will be the local directories configured for YARN (Hadoop Yarn config yarn.nodemanager.local-dirs). If the user specifies spark.local.dir, it will be ignored. - The --files and --archives options support specifying file names with the # similar to Hadoop. For example you can specify: --files localtest.txt#appSees.txt and this will upload the file you have locally named localtest.txt into HDFS but this will be linked to by the name appSees.txt and your application should use the name as appSees.txt to reference it when running on YARN. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 94e353af2e..bb73f6d337 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -54,9 +54,10 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl // staging directory is private! -> rwx-------- val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short) // app files are world-wide readable and owner writable -> rw-r--r-- - val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) + val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) - def run() { + // For client users who want to monitor the app status by themselves.
+ def runApp() = { validateArgs() init(yarnConf) @@ -78,7 +79,11 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl appContext.setUser(UserGroupInformation.getCurrentUser().getShortUserName()) submitApp(appContext) - + appId + } + + def run() { + val appId = runApp() monitorApplication(appId) System.exit(0) } @@ -372,7 +377,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val commands = List[String](javaCommand + " -server " + JAVA_OPTS + - " org.apache.spark.deploy.yarn.ApplicationMaster" + + " " + args.amClass + " --class " + args.userClass + " --jar " + args.userJar + userArgsToString(args) + diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 852dbd7dab..b9dbc3fb87 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -35,6 +35,7 @@ class ClientArguments(val args: Array[String]) { var numWorkers = 2 var amQueue = System.getProperty("QUEUE", "default") var amMemory: Int = 512 + var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster" var appName: String = "Spark" // TODO var inputFormatInfo: List[InputFormatInfo] = null @@ -62,18 +63,22 @@ class ClientArguments(val args: Array[String]) { userArgsBuffer += value args = tail - case ("--master-memory") :: MemoryParam(value) :: tail => - amMemory = value + case ("--master-class") :: value :: tail => + amClass = value args = tail - case ("--num-workers") :: IntParam(value) :: tail => - numWorkers = value + case ("--master-memory") :: MemoryParam(value) :: tail => + amMemory = value args = tail case ("--worker-memory") :: MemoryParam(value) :: tail => workerMemory = value args = tail + case ("--num-workers") :: IntParam(value) :: tail => + numWorkers = value + args = tail + case ("--worker-cores") :: IntParam(value) :: tail => workerCores = value args = tail @@ -119,19 +124,20 @@ class ClientArguments(val args: Array[String]) { System.err.println( "Usage: org.apache.spark.deploy.yarn.Client [options] \n" + "Options:\n" + - " --jar JAR_PATH Path to your application's JAR file (required)\n" + - " --class CLASS_NAME Name of your application's main class (required)\n" + - " --args ARGS Arguments to be passed to your application's main class.\n" + - " Mutliple invocations are possible, each will be passed in order.\n" + - " --num-workers NUM Number of workers to start (Default: 2)\n" + - " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" + - " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" + - " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" + - " --name NAME The name of your application (Default: Spark)\n" + - " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" + - " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" + - " --files files Comma separated list of files to be distributed with the job.\n" + - " --archives archives Comma separated list of archives to be distributed with the job." 
+ " --jar JAR_PATH Path to your application's JAR file (required)\n" + + " --class CLASS_NAME Name of your application's main class (required)\n" + + " --args ARGS Arguments to be passed to your application's main class.\n" + + " Mutliple invocations are possible, each will be passed in order.\n" + + " --num-workers NUM Number of workers to start (Default: 2)\n" + + " --worker-cores NUM Number of cores for the workers (Default: 1). This is unsused right now.\n" + + " --master-class CLASS_NAME Class Name for Master (Default: spark.deploy.yarn.ApplicationMaster)\n" + + " --master-memory MEM Memory for Master (e.g. 1000M, 2G) (Default: 512 Mb)\n" + + " --worker-memory MEM Memory per Worker (e.g. 1000M, 2G) (Default: 1G)\n" + + " --name NAME The name of your application (Default: Spark)\n" + + " --queue QUEUE The hadoop queue to use for allocation requests (Default: 'default')\n" + + " --addJars jars Comma separated list of local jars that want SparkContext.addJar to work with.\n" + + " --files files Comma separated list of files to be distributed with the job.\n" + + " --archives archives Comma separated list of archives to be distributed with the job." ) System.exit(exitCode) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala new file mode 100644 index 0000000000..421a83c87a --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/WorkerLauncher.scala @@ -0,0 +1,246 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.deploy.yarn + +import java.net.Socket +import org.apache.hadoop.conf.Configuration +import org.apache.hadoop.net.NetUtils +import org.apache.hadoop.yarn.api._ +import org.apache.hadoop.yarn.api.records._ +import org.apache.hadoop.yarn.api.protocolrecords._ +import org.apache.hadoop.yarn.conf.YarnConfiguration +import org.apache.hadoop.yarn.ipc.YarnRPC +import org.apache.hadoop.yarn.util.{ConverterUtils, Records} +import akka.actor._ +import akka.remote.{RemoteClientShutdown, RemoteClientDisconnected, RemoteClientLifeCycleEvent} +import akka.remote.RemoteClientShutdown +import akka.actor.Terminated +import akka.remote.RemoteClientDisconnected +import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.util.{Utils, AkkaUtils} +import org.apache.spark.scheduler.cluster.CoarseGrainedSchedulerBackend +import org.apache.spark.scheduler.SplitInfo + +class WorkerLauncher(args: ApplicationMasterArguments, conf: Configuration) extends Logging { + + def this(args: ApplicationMasterArguments) = this(args, new Configuration()) + + private val rpc: YarnRPC = YarnRPC.create(conf) + private var resourceManager: AMRMProtocol = null + private var appAttemptId: ApplicationAttemptId = null + private var reporterThread: Thread = null + private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) + + private var yarnAllocator: YarnAllocationHandler = null + private var driverClosed:Boolean = false + + val actorSystem : ActorSystem = AkkaUtils.createActorSystem("sparkYarnAM", Utils.localHostName, 0)._1 + var actor: ActorRef = null + + // This actor just working as a monitor to watch on Driver Actor. + class MonitorActor(driverUrl: String) extends Actor { + + var driver: ActorRef = null + + override def preStart() { + logInfo("Listen to driver: " + driverUrl) + driver = context.actorFor(driverUrl) + driver ! "hello" + context.system.eventStream.subscribe(self, classOf[RemoteClientLifeCycleEvent]) + context.watch(driver) // Doesn't work with remote actors, but useful for testing + } + + override def receive = { + case Terminated(_) | RemoteClientDisconnected(_, _) | RemoteClientShutdown(_, _) => + logInfo("Driver terminated or disconnected! Shutting down.") + driverClosed = true + } + } + + def run() { + + appAttemptId = getApplicationAttemptId() + resourceManager = registerWithResourceManager() + val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() + + // Compute number of threads for akka + val minimumMemory = appMasterResponse.getMinimumResourceCapability().getMemory() + + if (minimumMemory > 0) { + val mem = args.workerMemory + YarnAllocationHandler.MEMORY_OVERHEAD + val numCore = (mem / minimumMemory) + (if (0 != (mem % minimumMemory)) 1 else 0) + + if (numCore > 0) { + // do not override - hits https://issues.apache.org/jira/browse/HADOOP-8406 + // TODO: Uncomment when hadoop is on a version which has this fixed. + // args.workerCores = numCore + } + } + + waitForSparkMaster() + + // Allocate all containers + allocateWorkers() + + // Launch a progress reporter thread, else app will get killed after expiration (def: 10mins) timeout + // ensure that progress is sent before YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS elapse. + + val timeoutInterval = yarnConf.getInt(YarnConfiguration.RM_AM_EXPIRY_INTERVAL_MS, 120000) + // must be <= timeoutInterval/ 2. + // On other hand, also ensure that we are reasonably responsive without causing too many requests to RM. 
+ // so atleast 1 minute or timeoutInterval / 10 - whichever is higher. + val interval = math.min(timeoutInterval / 2, math.max(timeoutInterval/ 10, 60000L)) + reporterThread = launchReporterThread(interval) + + // Wait for the reporter thread to Finish. + reporterThread.join() + + finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) + actorSystem.shutdown() + + logInfo("Exited") + System.exit(0) + } + + private def getApplicationAttemptId(): ApplicationAttemptId = { + val envs = System.getenv() + val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) + val containerId = ConverterUtils.toContainerId(containerIdString) + val appAttemptId = containerId.getApplicationAttemptId() + logInfo("ApplicationAttemptId: " + appAttemptId) + return appAttemptId + } + + private def registerWithResourceManager(): AMRMProtocol = { + val rmAddress = NetUtils.createSocketAddr(yarnConf.get( + YarnConfiguration.RM_SCHEDULER_ADDRESS, + YarnConfiguration.DEFAULT_RM_SCHEDULER_ADDRESS)) + logInfo("Connecting to ResourceManager at " + rmAddress) + return rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] + } + + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { + logInfo("Registering the ApplicationMaster") + val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) + .asInstanceOf[RegisterApplicationMasterRequest] + appMasterRequest.setApplicationAttemptId(appAttemptId) + // Setting this to master host,port - so that the ApplicationReport at client has some sensible info. + // Users can then monitor stderr/stdout on that node if required. + appMasterRequest.setHost(Utils.localHostName()) + appMasterRequest.setRpcPort(0) + // What do we provide here ? Might make sense to expose something sensible later ? + appMasterRequest.setTrackingUrl("") + return resourceManager.registerApplicationMaster(appMasterRequest) + } + + private def waitForSparkMaster() { + logInfo("Waiting for spark driver to be reachable.") + var driverUp = false + val hostport = args.userArgs(0) + val (driverHost, driverPort) = Utils.parseHostPort(hostport) + while(!driverUp) { + try { + val socket = new Socket(driverHost, driverPort) + socket.close() + logInfo("Master now available: " + driverHost + ":" + driverPort) + driverUp = true + } catch { + case e: Exception => + logError("Failed to connect to driver at " + driverHost + ":" + driverPort) + Thread.sleep(100) + } + } + System.setProperty("spark.driver.host", driverHost) + System.setProperty("spark.driver.port", driverPort.toString) + + val driverUrl = "akka://spark@%s:%s/user/%s".format( + driverHost, driverPort.toString, CoarseGrainedSchedulerBackend.ACTOR_NAME) + + actor = actorSystem.actorOf(Props(new MonitorActor(driverUrl)), name = "YarnAM") + } + + + private def allocateWorkers() { + + // Fixme: should get preferredNodeLocationData from SparkContext, just fake a empty one for now. + val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = scala.collection.immutable.Map() + + yarnAllocator = YarnAllocationHandler.newAllocator(yarnConf, resourceManager, appAttemptId, args, preferredNodeLocationData) + + logInfo("Allocating " + args.numWorkers + " workers.") + // Wait until all containers have finished + // TODO: This is a bit ugly. Can we make it nicer? 
+ // TODO: Handle container failure + while(yarnAllocator.getNumWorkersRunning < args.numWorkers) { + yarnAllocator.allocateContainers(math.max(args.numWorkers - yarnAllocator.getNumWorkersRunning, 0)) + Thread.sleep(100) + } + + logInfo("All workers have launched.") + + } + + // TODO: We might want to extend this to allocate more containers in case they die ! + private def launchReporterThread(_sleepTime: Long): Thread = { + val sleepTime = if (_sleepTime <= 0 ) 0 else _sleepTime + + val t = new Thread { + override def run() { + while (!driverClosed) { + val missingWorkerCount = args.numWorkers - yarnAllocator.getNumWorkersRunning + if (missingWorkerCount > 0) { + logInfo("Allocating " + missingWorkerCount + " containers to make up for (potentially ?) lost containers") + yarnAllocator.allocateContainers(missingWorkerCount) + } + else sendProgress() + Thread.sleep(sleepTime) + } + } + } + // setting to daemon status, though this is usually not a good idea. + t.setDaemon(true) + t.start() + logInfo("Started progress reporter thread - sleep time : " + sleepTime) + return t + } + + private def sendProgress() { + logDebug("Sending progress") + // simulated with an allocate request with no nodes requested ... + yarnAllocator.allocateContainers(0) + } + + def finishApplicationMaster(status: FinalApplicationStatus) { + + logInfo("finish ApplicationMaster with " + status) + val finishReq = Records.newRecord(classOf[FinishApplicationMasterRequest]) + .asInstanceOf[FinishApplicationMasterRequest] + finishReq.setAppAttemptId(appAttemptId) + finishReq.setFinishApplicationStatus(status) + resourceManager.finishApplicationMaster(finishReq) + } + +} + + +object WorkerLauncher { + def main(argStrings: Array[String]) { + val args = new ApplicationMasterArguments(argStrings) + new WorkerLauncher(args).run() + } +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala new file mode 100644 index 0000000000..63a0449e5a --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientClusterScheduler.scala @@ -0,0 +1,47 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.scheduler.cluster + +import org.apache.spark._ +import org.apache.hadoop.conf.Configuration +import org.apache.spark.deploy.yarn.YarnAllocationHandler +import org.apache.spark.util.Utils + +/** + * + * This scheduler launch worker through Yarn - by call into Client to launch WorkerLauncher as AM. 
+ */ +private[spark] class YarnClientClusterScheduler(sc: SparkContext, conf: Configuration) extends ClusterScheduler(sc) { + + def this(sc: SparkContext) = this(sc, new Configuration()) + + // By default, rack is unknown + override def getRackForHost(hostPort: String): Option[String] = { + val host = Utils.parseHostPort(hostPort)._1 + val retval = YarnAllocationHandler.lookupRack(conf, host) + if (retval != null) Some(retval) else None + } + + override def postStartHook() { + + // The yarn application is running, but the worker might not yet ready + // Wait for a few seconds for the slaves to bootstrap and register with master - best case attempt + Thread.sleep(2000L) + logInfo("YarnClientClusterScheduler.postStartHook done") + } +} diff --git a/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala new file mode 100644 index 0000000000..b206780c78 --- /dev/null +++ b/yarn/src/main/scala/org/apache/spark/scheduler/cluster/YarnClientSchedulerBackend.scala @@ -0,0 +1,109 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +package org.apache.spark.scheduler.cluster + +import org.apache.hadoop.yarn.api.records.{ApplicationId, YarnApplicationState} +import org.apache.spark.{SparkException, Logging, SparkContext} +import org.apache.spark.deploy.yarn.{Client, ClientArguments} + +private[spark] class YarnClientSchedulerBackend( + scheduler: ClusterScheduler, + sc: SparkContext) + extends CoarseGrainedSchedulerBackend(scheduler, sc.env.actorSystem) + with Logging { + + var client: Client = null + var appId: ApplicationId = null + + override def start() { + super.start() + + val defalutWorkerCores = "2" + val defalutWorkerMemory = "512m" + val defaultWorkerNumber = "1" + + val userJar = System.getenv("SPARK_YARN_APP_JAR") + var workerCores = System.getenv("SPARK_WORKER_CORES") + var workerMemory = System.getenv("SPARK_WORKER_MEMORY") + var workerNumber = System.getenv("SPARK_WORKER_INSTANCES") + + if (userJar == null) + throw new SparkException("env SPARK_YARN_APP_JAR is not set") + + if (workerCores == null) + workerCores = defalutWorkerCores + if (workerMemory == null) + workerMemory = defalutWorkerMemory + if (workerNumber == null) + workerNumber = defaultWorkerNumber + + val driverHost = System.getProperty("spark.driver.host") + val driverPort = System.getProperty("spark.driver.port") + val hostport = driverHost + ":" + driverPort + + val argsArray = Array[String]( + "--class", "notused", + "--jar", userJar, + "--args", hostport, + "--worker-memory", workerMemory, + "--worker-cores", workerCores, + "--num-workers", workerNumber, + "--master-class", "org.apache.spark.deploy.yarn.WorkerLauncher" + ) + + val args = new ClientArguments(argsArray) + client = new Client(args) + appId = client.runApp() + waitForApp() + } + + def waitForApp() { + + // TODO : need a better way to find out whether the workers are ready or not + // maybe by resource usage report? + while(true) { + val report = client.getApplicationReport(appId) + + logInfo("Application report from ASM: \n" + + "\t appMasterRpcPort: " + report.getRpcPort() + "\n" + + "\t appStartTime: " + report.getStartTime() + "\n" + + "\t yarnAppState: " + report.getYarnApplicationState() + "\n" + ) + + // Ready to go, or already gone. + val state = report.getYarnApplicationState() + if (state == YarnApplicationState.RUNNING) { + return + } else if (state == YarnApplicationState.FINISHED || + state == YarnApplicationState.FAILED || + state == YarnApplicationState.KILLED) { + throw new SparkException("Yarn application already ended," + + "might be killed or not able to launch application master.") + } + + Thread.sleep(1000) + } + } + + override def stop() { + super.stop() + client.stop() + logInfo("Stoped") + } + +} -- cgit v1.2.3 From e9ff13ec72718ada705b85cc10da1b09bcc86dcc Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Sun, 24 Nov 2013 17:56:43 +0800 Subject: Consolidated both mapPartitions related RDDs into a single MapPartitionsRDD. Also changed the semantics of the index parameter in mapPartitionsWithIndex from the partition index of the output partition to the partition index in the current RDD. 
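A small sketch of the new index semantics, assuming a spark-shell session with an existing SparkContext `sc` (values are illustrative only): the index handed to the closure is now the partition index of the RDD being mapped over.

    // Four partitions; the closure sees the partition index of `rdd` itself.
    val rdd = sc.parallelize(1 to 8, 4)
    val tagged = rdd.mapPartitionsWithIndex { (index, iter) =>
      iter.map(x => (index, x))
    }
    tagged.collect()  // Array((0,1), (0,2), (1,3), (1,4), (2,5), (2,6), (3,7), (3,8))
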
--- .../org/apache/spark/rdd/MapPartitionsRDD.scala | 10 +++--- .../spark/rdd/MapPartitionsWithContextRDD.scala | 41 ---------------------- core/src/main/scala/org/apache/spark/rdd/RDD.scala | 39 ++++++++++---------- .../scala/org/apache/spark/CheckpointSuite.scala | 2 -- 4 files changed, 22 insertions(+), 70 deletions(-) delete mode 100644 core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala index 203179c4ea..ae70d55951 100644 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsRDD.scala @@ -20,18 +20,16 @@ package org.apache.spark.rdd import org.apache.spark.{Partition, TaskContext} -private[spark] -class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( +private[spark] class MapPartitionsRDD[U: ClassManifest, T: ClassManifest]( prev: RDD[T], - f: Iterator[T] => Iterator[U], + f: (TaskContext, Int, Iterator[T]) => Iterator[U], // (TaskContext, partition index, iterator) preservesPartitioning: Boolean = false) extends RDD[U](prev) { - override val partitioner = - if (preservesPartitioning) firstParent[T].partitioner else None + override val partitioner = if (preservesPartitioning) firstParent[T].partitioner else None override def getPartitions: Array[Partition] = firstParent[T].partitions override def compute(split: Partition, context: TaskContext) = - f(firstParent[T].iterator(split, context)) + f(context, split.index, firstParent[T].iterator(split, context)) } diff --git a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala b/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala deleted file mode 100644 index aea08ff81b..0000000000 --- a/core/src/main/scala/org/apache/spark/rdd/MapPartitionsWithContextRDD.scala +++ /dev/null @@ -1,41 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.rdd - -import org.apache.spark.{Partition, TaskContext} - - -/** - * A variant of the MapPartitionsRDD that passes the TaskContext into the closure. From the - * TaskContext, the closure can either get access to the interruptible flag or get the index - * of the partition in the RDD. 
- */ -private[spark] -class MapPartitionsWithContextRDD[U: ClassManifest, T: ClassManifest]( - prev: RDD[T], - f: (TaskContext, Iterator[T]) => Iterator[U], - preservesPartitioning: Boolean - ) extends RDD[U](prev) { - - override def getPartitions: Array[Partition] = firstParent[T].partitions - - override val partitioner = if (preservesPartitioning) prev.partitioner else None - - override def compute(split: Partition, context: TaskContext) = - f(context, firstParent[T].iterator(split, context)) -} diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala index 7623c44d88..5b1285307d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala @@ -408,7 +408,6 @@ abstract class RDD[T: ClassManifest]( def pipe(command: String, env: Map[String, String]): RDD[String] = new PipedRDD(this, command, env) - /** * Return an RDD created by piping elements to a forked external process. * The print behavior can be customized by providing two functions. @@ -442,7 +441,8 @@ abstract class RDD[T: ClassManifest]( */ def mapPartitions[U: ClassManifest]( f: Iterator[T] => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - new MapPartitionsRDD(this, sc.clean(f), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -451,8 +451,8 @@ abstract class RDD[T: ClassManifest]( */ def mapPartitionsWithIndex[U: ClassManifest]( f: (Int, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - val func = (context: TaskContext, iter: Iterator[T]) => f(context.partitionId, iter) - new MapPartitionsWithContextRDD(this, sc.clean(func), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(index, iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -462,7 +462,8 @@ abstract class RDD[T: ClassManifest]( def mapPartitionsWithContext[U: ClassManifest]( f: (TaskContext, Iterator[T]) => Iterator[U], preservesPartitioning: Boolean = false): RDD[U] = { - new MapPartitionsWithContextRDD(this, sc.clean(f), preservesPartitioning) + val func = (context: TaskContext, index: Int, iter: Iterator[T]) => f(context, iter) + new MapPartitionsRDD(this, sc.clean(func), preservesPartitioning) } /** @@ -483,11 +484,10 @@ abstract class RDD[T: ClassManifest]( def mapWith[A: ClassManifest, U: ClassManifest] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => U): RDD[U] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.map(t => f(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) + }, preservesPartitioning) } /** @@ -498,11 +498,10 @@ abstract class RDD[T: ClassManifest]( def flatMapWith[A: ClassManifest, U: ClassManifest] (constructA: Int => A, preservesPartitioning: Boolean = false) (f: (T, A) => Seq[U]): RDD[U] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[U] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.flatMap(t => f(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), preservesPartitioning) + }, preservesPartitioning) } /** @@ -511,11 +510,10 @@ abstract class RDD[T: 
ClassManifest]( * partition with the index of that partition. */ def foreachWith[A: ClassManifest](constructA: Int => A)(f: (T, A) => Unit) { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex { (index, iter) => + val a = constructA(index) iter.map(t => {f(t, a); t}) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true).foreach(_ => {}) + }.foreach(_ => {}) } /** @@ -524,11 +522,10 @@ abstract class RDD[T: ClassManifest]( * partition with the index of that partition. */ def filterWith[A: ClassManifest](constructA: Int => A)(p: (T, A) => Boolean): RDD[T] = { - def iterF(context: TaskContext, iter: Iterator[T]): Iterator[T] = { - val a = constructA(context.partitionId) + mapPartitionsWithIndex((index, iter) => { + val a = constructA(index) iter.filter(t => p(t, a)) - } - new MapPartitionsWithContextRDD(this, sc.clean(iterF _), true) + }, preservesPartitioning = true) } /** diff --git a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala index f26c44d3e7..d2226aa5a5 100644 --- a/core/src/test/scala/org/apache/spark/CheckpointSuite.scala +++ b/core/src/test/scala/org/apache/spark/CheckpointSuite.scala @@ -62,8 +62,6 @@ class CheckpointSuite extends FunSuite with LocalSparkContext with Logging { testCheckpointing(_.sample(false, 0.5, 0)) testCheckpointing(_.glom()) testCheckpointing(_.mapPartitions(_.map(_.toString))) - testCheckpointing(r => new MapPartitionsWithContextRDD(r, - (context: TaskContext, iter: Iterator[Int]) => iter.map(_.toString), false )) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).mapValues(_.toString)) testCheckpointing(_.map(x => (x % 2, 1)).reduceByKey(_ + _).flatMapValues(x => 1 to x)) testCheckpointing(_.pipe(Seq("cat"))) -- cgit v1.2.3 From 95c55df1c21c1b8a90962415861b27ef91d3b20e Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 25 Nov 2013 18:27:06 +0800 Subject: Added unit tests for size estimation for specialized hash sets and maps. 
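For reference, a sketch of the arithmetic these tests rely on, using the 1024-entry OpenHashMap[String, Int] case from the suite below:

    // Per slot: a 64-bit reference for the key, a 32-bit Int for the value,
    // and one bit in the occupancy bitset; the assertions allow ~10% slack.
    val capacity = 1024
    val expectedBytes = capacity * (64 + 32 + 1) / 8   // 12416 bytes
    val upperBound = (expectedBytes * 1.1).toLong      // what the test compares against
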
--- .../spark/util/collection/OpenHashMapSuite.scala | 16 +++- .../spark/util/collection/OpenHashSetSuite.scala | 20 +++- .../collection/PrimitiveKeyOpenHashMapSuite.scala | 102 +++++++++++++++++++++ .../collection/PrimitiveKeyOpenHashSetSuite.scala | 90 ------------------ 4 files changed, 135 insertions(+), 93 deletions(-) create mode 100644 core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala delete mode 100644 core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala index ca3f684668..63e874fed3 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashMapSuite.scala @@ -2,8 +2,20 @@ package org.apache.spark.util.collection import scala.collection.mutable.HashSet import org.scalatest.FunSuite - -class OpenHashMapSuite extends FunSuite { +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.SizeEstimator + +class OpenHashMapSuite extends FunSuite with ShouldMatchers { + + test("size for specialized, primitive value (int)") { + val capacity = 1024 + val map = new OpenHashMap[String, Int](capacity) + val actualSize = SizeEstimator.estimate(map) + // 64 bit for pointers, 32 bit for ints, and 1 bit for the bitset. + val expectedSize = capacity * (64 + 32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } test("initialization") { val goodMap1 = new OpenHashMap[String, Int](1) diff --git a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala index 4e11e8a628..4768a1e60b 100644 --- a/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/collection/OpenHashSetSuite.scala @@ -1,9 +1,27 @@ package org.apache.spark.util.collection import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.SizeEstimator -class OpenHashSetSuite extends FunSuite { + +class OpenHashSetSuite extends FunSuite with ShouldMatchers { + + test("size for specialized, primitive int") { + val loadFactor = 0.7 + val set = new OpenHashSet[Int](64, loadFactor) + for (i <- 0 until 1024) { + set.add(i) + } + assert(set.size === 1024) + assert(set.capacity > 1024) + val actualSize = SizeEstimator.estimate(set) + // 32 bits for the ints + 1 bit for the bitset + val expectedSize = set.capacity * (32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. 
+ actualSize should be <= (expectedSize * 1.1).toLong + } test("primitive int") { val set = new OpenHashSet[Int] diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala new file mode 100644 index 0000000000..2220b4f0d5 --- /dev/null +++ b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashMapSuite.scala @@ -0,0 +1,102 @@ +package org.apache.spark.util.collection + +import scala.collection.mutable.HashSet +import org.scalatest.FunSuite +import org.scalatest.matchers.ShouldMatchers +import org.apache.spark.util.SizeEstimator + +class PrimitiveKeyOpenHashMapSuite extends FunSuite with ShouldMatchers { + + test("size for specialized, primitive key, value (int, int)") { + val capacity = 1024 + val map = new PrimitiveKeyOpenHashMap[Int, Int](capacity) + val actualSize = SizeEstimator.estimate(map) + // 32 bit for keys, 32 bit for values, and 1 bit for the bitset. + val expectedSize = capacity * (32 + 32 + 1) / 8 + // Make sure we are not allocating a significant amount of memory beyond our expected. + actualSize should be <= (expectedSize * 1.1).toLong + } + + test("initialization") { + val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1) + assert(goodMap1.size === 0) + val goodMap2 = new PrimitiveKeyOpenHashMap[Int, Int](255) + assert(goodMap2.size === 0) + val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256) + assert(goodMap3.size === 0) + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29 + } + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](-1) + } + intercept[IllegalArgumentException] { + new PrimitiveKeyOpenHashMap[Int, Int](0) + } + } + + test("basic operations") { + val longBase = 1000000L + val map = new PrimitiveKeyOpenHashMap[Long, Int] + + for (i <- 1 to 1000) { + map(i + longBase) = i + assert(map(i + longBase) === i) + } + + assert(map.size === 1000) + + for (i <- 1 to 1000) { + assert(map(i + longBase) === i) + } + + // Test iterator + val set = new HashSet[(Long, Int)] + for ((k, v) <- map) { + set.add((k, v)) + } + assert(set === (1 to 1000).map(x => (x + longBase, x)).toSet) + } + + test("null values") { + val map = new PrimitiveKeyOpenHashMap[Long, String]() + for (i <- 1 to 100) { + map(i.toLong) = null + } + assert(map.size === 100) + assert(map(1.toLong) === null) + } + + test("changeValue") { + val map = new PrimitiveKeyOpenHashMap[Long, String]() + for (i <- 1 to 100) { + map(i.toLong) = i.toString + } + assert(map.size === 100) + for (i <- 1 to 100) { + val res = map.changeValue(i.toLong, { assert(false); "" }, v => { + assert(v === i.toString) + v + "!" + }) + assert(res === i + "!") + } + // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a + // bug where changeValue would return the wrong result when the map grew on that insert + for (i <- 101 to 400) { + val res = map.changeValue(i.toLong, { i + "!" 
}, v => { assert(false); v }) + assert(res === i + "!") + } + assert(map.size === 400) + } + + test("inserting in capacity-1 map") { + val map = new PrimitiveKeyOpenHashMap[Long, String](1) + for (i <- 1 to 100) { + map(i.toLong) = i.toString + } + assert(map.size === 100) + for (i <- 1 to 100) { + assert(map(i.toLong) === i.toString) + } + } +} diff --git a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala b/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala deleted file mode 100644 index dfd6aed2c4..0000000000 --- a/core/src/test/scala/org/apache/spark/util/collection/PrimitiveKeyOpenHashSetSuite.scala +++ /dev/null @@ -1,90 +0,0 @@ -package org.apache.spark.util.collection - -import scala.collection.mutable.HashSet -import org.scalatest.FunSuite - -class PrimitiveKeyOpenHashSetSuite extends FunSuite { - - test("initialization") { - val goodMap1 = new PrimitiveKeyOpenHashMap[Int, Int](1) - assert(goodMap1.size === 0) - val goodMap2 = new PrimitiveKeyOpenHashMap[Int, Int](255) - assert(goodMap2.size === 0) - val goodMap3 = new PrimitiveKeyOpenHashMap[Int, Int](256) - assert(goodMap3.size === 0) - intercept[IllegalArgumentException] { - new PrimitiveKeyOpenHashMap[Int, Int](1 << 30) // Invalid map size: bigger than 2^29 - } - intercept[IllegalArgumentException] { - new PrimitiveKeyOpenHashMap[Int, Int](-1) - } - intercept[IllegalArgumentException] { - new PrimitiveKeyOpenHashMap[Int, Int](0) - } - } - - test("basic operations") { - val longBase = 1000000L - val map = new PrimitiveKeyOpenHashMap[Long, Int] - - for (i <- 1 to 1000) { - map(i + longBase) = i - assert(map(i + longBase) === i) - } - - assert(map.size === 1000) - - for (i <- 1 to 1000) { - assert(map(i + longBase) === i) - } - - // Test iterator - val set = new HashSet[(Long, Int)] - for ((k, v) <- map) { - set.add((k, v)) - } - assert(set === (1 to 1000).map(x => (x + longBase, x)).toSet) - } - - test("null values") { - val map = new PrimitiveKeyOpenHashMap[Long, String]() - for (i <- 1 to 100) { - map(i.toLong) = null - } - assert(map.size === 100) - assert(map(1.toLong) === null) - } - - test("changeValue") { - val map = new PrimitiveKeyOpenHashMap[Long, String]() - for (i <- 1 to 100) { - map(i.toLong) = i.toString - } - assert(map.size === 100) - for (i <- 1 to 100) { - val res = map.changeValue(i.toLong, { assert(false); "" }, v => { - assert(v === i.toString) - v + "!" - }) - assert(res === i + "!") - } - // Iterate from 101 to 400 to make sure the map grows a couple of times, because we had a - // bug where changeValue would return the wrong result when the map grew on that insert - for (i <- 101 to 400) { - val res = map.changeValue(i.toLong, { i + "!" }, v => { assert(false); v }) - assert(res === i + "!") - } - assert(map.size === 400) - } - - test("inserting in capacity-1 map") { - val map = new PrimitiveKeyOpenHashMap[Long, String](1) - for (i <- 1 to 100) { - map(i.toLong) = i.toString - } - assert(map.size === 100) - for (i <- 1 to 100) { - assert(map(i.toLong) === i.toString) - } - } -} -- cgit v1.2.3 From 466fd06475d8ed262c456421ed2dceba54229db1 Mon Sep 17 00:00:00 2001 From: Reynold Xin Date: Mon, 25 Nov 2013 18:27:26 +0800 Subject: Incorporated ideas from pull request #200. - Use Murmur Hash 3 finalization step to scramble the bits of HashCode instead of the simpler version in java.util.HashMap; the latter one had trouble with ranges of consecutive integers. Murmur Hash 3 is used by fastutil. 
- Don't check keys for equality when re-inserting due to growing the table; the keys will already be unique - Remember the grow threshold instead of recomputing it on each insert --- .../apache/spark/util/collection/OpenHashSet.scala | 107 +++++++++++---------- 1 file changed, 57 insertions(+), 50 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala index 4592e4f939..40986e3731 100644 --- a/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala +++ b/core/src/main/scala/org/apache/spark/util/collection/OpenHashSet.scala @@ -79,6 +79,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( protected var _capacity = nextPowerOf2(initialCapacity) protected var _mask = _capacity - 1 protected var _size = 0 + protected var _growThreshold = (loadFactor * _capacity).toInt protected var _bitset = new BitSet(_capacity) @@ -115,7 +116,29 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( * @return The position where the key is placed, plus the highest order bit is set if the key * exists previously. */ - def addWithoutResize(k: T): Int = putInto(_bitset, _data, k) + def addWithoutResize(k: T): Int = { + var pos = hashcode(hasher.hash(k)) & _mask + var i = 1 + while (true) { + if (!_bitset.get(pos)) { + // This is a new key. + _data(pos) = k + _bitset.set(pos) + _size += 1 + return pos | NONEXISTENCE_MASK + } else if (_data(pos) == k) { + // Found an existing key. + return pos + } else { + val delta = i + pos = (pos + delta) & _mask + i += 1 + } + } + // Never reached here + assert(INVALID_POS != INVALID_POS) + INVALID_POS + } /** * Rehash the set if it is overloaded. @@ -126,7 +149,7 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( * to a new position (in the new data array). */ def rehashIfNeeded(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { - if (_size > loadFactor * _capacity) { + if (_size > _growThreshold) { rehash(k, allocateFunc, moveFunc) } } @@ -160,37 +183,6 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( */ def nextPos(fromPos: Int): Int = _bitset.nextSetBit(fromPos) - /** - * Put an entry into the set. Return the position where the key is placed. In addition, the - * highest bit in the returned position is set if the key exists prior to this put. - * - * This function assumes the data array has at least one empty slot. - */ - private def putInto(bitset: BitSet, data: Array[T], k: T): Int = { - val mask = data.length - 1 - var pos = hashcode(hasher.hash(k)) & mask - var i = 1 - while (true) { - if (!bitset.get(pos)) { - // This is a new key. - data(pos) = k - bitset.set(pos) - _size += 1 - return pos | NONEXISTENCE_MASK - } else if (data(pos) == k) { - // Found an existing key. - return pos - } else { - val delta = i - pos = (pos + delta) & mask - i += 1 - } - } - // Never reached here - assert(INVALID_POS != INVALID_POS) - INVALID_POS - } - /** * Double the table's size and re-hash everything. 
We are not really using k, but it is declared * so Scala compiler can specialize this method (which leads to calling the specialized version @@ -204,34 +196,49 @@ class OpenHashSet[@specialized(Long, Int) T: ClassManifest]( */ private def rehash(k: T, allocateFunc: (Int) => Unit, moveFunc: (Int, Int) => Unit) { val newCapacity = _capacity * 2 - require(newCapacity <= (1 << 29), "Can't make capacity bigger than 2^29 elements") - allocateFunc(newCapacity) - val newData = new Array[T](newCapacity) val newBitset = new BitSet(newCapacity) - var pos = 0 - _size = 0 - while (pos < _capacity) { - if (_bitset.get(pos)) { - val newPos = putInto(newBitset, newData, _data(pos)) - moveFunc(pos, newPos & POSITION_MASK) + val newData = new Array[T](newCapacity) + val newMask = newCapacity - 1 + + var oldPos = 0 + while (oldPos < capacity) { + if (_bitset.get(oldPos)) { + val key = _data(oldPos) + var newPos = hashcode(hasher.hash(key)) & newMask + var i = 1 + var keepGoing = true + // No need to check for equality here when we insert so this has one less if branch than + // the similar code path in addWithoutResize. + while (keepGoing) { + if (!newBitset.get(newPos)) { + // Inserting the key at newPos + newData(newPos) = key + newBitset.set(newPos) + moveFunc(oldPos, newPos) + keepGoing = false + } else { + val delta = i + newPos = (newPos + delta) & newMask + i += 1 + } + } } - pos += 1 + oldPos += 1 } + _bitset = newBitset _data = newData _capacity = newCapacity - _mask = newCapacity - 1 + _mask = newMask + _growThreshold = (loadFactor * newCapacity).toInt } /** - * Re-hash a value to deal better with hash functions that don't differ - * in the lower bits, similar to java.util.HashMap + * Re-hash a value to deal better with hash functions that don't differ in the lower bits. + * We use the Murmur Hash 3 finalization step that's also used in fastutil. */ - private def hashcode(h: Int): Int = { - val r = h ^ (h >>> 20) ^ (h >>> 12) - r ^ (r >>> 7) ^ (r >>> 4) - } + private def hashcode(h: Int): Int = it.unimi.dsi.fastutil.HashCommon.murmurHash3(h) private def nextPowerOf2(n: Int): Int = { val highBit = Integer.highestOneBit(n) -- cgit v1.2.3 From 08afef37a07c501b1ba14e3d6da445712852ca1e Mon Sep 17 00:00:00 2001 From: Andrew Ash Date: Mon, 25 Nov 2013 17:08:52 -0800 Subject: Update tuning.md Clarify when serializer is used based on recent user@ mailing list discussion. --- docs/tuning.md | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/docs/tuning.md b/docs/tuning.md index f33fda37eb..a4be188169 100644 --- a/docs/tuning.md +++ b/docs/tuning.md @@ -39,7 +39,8 @@ in your operations) and performance. It provides two serialization libraries: for best performance. You can switch to using Kryo by calling `System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")` -*before* creating your SparkContext. The only reason it is not the default is because of the custom +*before* creating your SparkContext. This setting configures the serializer used for not only shuffling data between worker +nodes but also when serializing RDDs to disk. The only reason Kryo is not the default is because of the custom registration requirement, but we recommend trying it in any network-intensive application. 
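As a quick sketch of the ordering described in the paragraph above (the object name and the toy job are illustrative only, not part of the patch): the property must be set before the SparkContext is constructed, and it then covers both data shuffled between workers and RDDs serialized to disk.

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._

    object KryoEnabledJob {
      def main(args: Array[String]) {
        // Set the serializer before the SparkContext below is created;
        // it then applies to shuffle data as well as RDDs serialized to disk.
        System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer")
        val sc = new SparkContext("local", "kryo-enabled-job")
        // A small shuffle so the configured serializer is actually exercised.
        val counts = sc.parallelize(1 to 1000).map(i => (i % 10, 1)).reduceByKey(_ + _)
        println(counts.collect().mkString(", "))
        sc.stop()
      }
    }
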
Finally, to register your classes with Kryo, create a public class that extends -- cgit v1.2.3 From 7222ee29779c3c5146aa5a3d6d060f3b039c1ff7 Mon Sep 17 00:00:00 2001 From: Holden Karau Date: Mon, 25 Nov 2013 21:06:42 -0800 Subject: Fix the test --- core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala | 4 ++-- core/src/test/scala/org/apache/spark/JavaAPISuite.java | 6 +++--- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala index 70f7f01d2b..dad5c72e1c 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaDoubleRDD.scala @@ -191,8 +191,8 @@ class JavaDoubleRDD(val srdd: RDD[scala.Double]) extends JavaRDDLike[Double, Jav * the maximum value of the last position and all NaN entries will be counted * in that bucket. */ - def histogram(buckets: Array[Double]): Array[Long] = { - srdd.histogram(buckets.map(_.toDouble), false) + def histogram(buckets: Array[scala.Double]): Array[Long] = { + srdd.histogram(buckets, false) } def histogram(buckets: Array[Double], evenBuckets: Boolean): Array[Long] = { diff --git a/core/src/test/scala/org/apache/spark/JavaAPISuite.java b/core/src/test/scala/org/apache/spark/JavaAPISuite.java index 8a9c6e63e0..44483fd4ab 100644 --- a/core/src/test/scala/org/apache/spark/JavaAPISuite.java +++ b/core/src/test/scala/org/apache/spark/JavaAPISuite.java @@ -368,10 +368,10 @@ public class JavaAPISuite implements Serializable { public void javaDoubleRDDHistoGram() { JavaDoubleRDD rdd = sc.parallelizeDoubles(Arrays.asList(1.0, 2.0, 3.0, 4.0)); // Test using generated buckets - Tuple2 results = rdd.histogram(2); - Double[] expected_buckets = {1.0, 2.5, 4.0}; + Tuple2 results = rdd.histogram(2); + double[] expected_buckets = {1.0, 2.5, 4.0}; long[] expected_counts = {2, 2}; - Assert.assertArrayEquals(expected_buckets, results._1); + Assert.assertArrayEquals(expected_buckets, results._1, 0.1); Assert.assertArrayEquals(expected_counts, results._2); // Test with provided buckets long[] histogram = rdd.histogram(expected_buckets); -- cgit v1.2.3 From 297c09d4bb26ba815c7fcb0a0ff04974959f551e Mon Sep 17 00:00:00 2001 From: Patrick Wendell Date: Mon, 25 Nov 2013 22:51:33 -0800 Subject: Improve docs for shuffle instrumentation --- .../org/apache/spark/executor/TaskMetrics.scala | 23 ++++++++++++---------- 1 file changed, 13 insertions(+), 10 deletions(-) diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala index 0b4892f98f..c0ce46e379 100644 --- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala +++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala @@ -61,50 +61,53 @@ object TaskMetrics { class ShuffleReadMetrics extends Serializable { /** - * Time when shuffle finishs + * Absolute time when this task finished reading shuffle data */ var shuffleFinishTime: Long = _ /** - * Total number of blocks fetched in a shuffle (remote or local) + * Number of blocks fetched in this shuffle by this task (remote or local) */ var totalBlocksFetched: Int = _ /** - * Number of remote blocks fetched in a shuffle + * Number of remote blocks fetched in this shuffle by this task */ var remoteBlocksFetched: Int = _ /** - * Local blocks fetched in a shuffle + * Number of local blocks fetched in this shuffle by this task */ var localBlocksFetched: Int = _ 
/** - * Total time that is spent blocked waiting for shuffle to fetch data + * Time the task spent waiting for remote shuffle blocks. This only includes the time + * blocking on shuffle input data. For instance if block B is being fetched while the task is + * still not finished processing block A, it is not considered to be blocking on block B. */ var fetchWaitTime: Long = _ /** - * The total amount of time for all the shuffle fetches. This adds up time from overlapping - * shuffles, so can be longer than task time + * Total time spent fetching remote shuffle blocks. This aggregates the time spent fetching all + * input blocks. Since block fetches are both pipelined and parallelized, this can + * exceed fetchWaitTime and executorRunTime. */ var remoteFetchTime: Long = _ /** - * Total number of remote bytes read from a shuffle + * Total number of remote bytes read from the shuffle by this task */ var remoteBytesRead: Long = _ } class ShuffleWriteMetrics extends Serializable { /** - * Number of bytes written for a shuffle + * Number of bytes written for the shuffle by this task */ var shuffleBytesWritten: Long = _ /** - * Time spent blocking on writes to disk or buffer cache, in nanoseconds. + * Time the task spent blocking on writes to disk or buffer cache, in nanoseconds */ var shuffleWriteTime: Long = _ } -- cgit v1.2.3 From ed7ecb93ce6ce259eae1f5aeb28e9e336fafa30f Mon Sep 17 00:00:00 2001 From: Mark Hamstra Date: Tue, 26 Nov 2013 13:30:17 -0800 Subject: [SPARK-963] Wait for SparkListenerBus eventQueue to be empty before checking jobLogger state --- .../src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala index 984881861c..002368ff55 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/JobLoggerSuite.scala @@ -31,6 +31,7 @@ import org.apache.spark.rdd.RDD class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers { + val WAIT_TIMEOUT_MILLIS = 10000 test("inner method") { sc = new SparkContext("local", "joblogger") @@ -92,6 +93,8 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } rdd.reduceByKey(_+_).collect() + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + val user = System.getProperty("user.name", SparkContext.SPARK_UNKNOWN_USER) joblogger.getLogDir should be ("/tmp/spark-%s".format(user)) @@ -120,7 +123,9 @@ class JobLoggerSuite extends FunSuite with LocalSparkContext with ShouldMatchers sc.addSparkListener(joblogger) val rdd = sc.parallelize(1 to 1e2.toInt, 4).map{ i => (i % 12, 2 * i) } rdd.reduceByKey(_+_).collect() - + + assert(sc.dagScheduler.listenerBus.waitUntilEmpty(WAIT_TIMEOUT_MILLIS)) + joblogger.onJobStartCount should be (1) joblogger.onJobEndCount should be (1) joblogger.onTaskEndCount should be (8) -- cgit v1.2.3 From 57579934f0454f258615c10e69ac2adafc5b9835 Mon Sep 17 00:00:00 2001 From: hhd Date: Mon, 25 Nov 2013 17:17:17 -0500 Subject: Emit warning when task size > 100KB --- .../scala/org/apache/spark/scheduler/DAGScheduler.scala | 15 +++++++++++++++ .../main/scala/org/apache/spark/scheduler/StageInfo.scala | 1 + .../main/scala/org/apache/spark/scheduler/TaskInfo.scala | 2 ++ 
.../spark/scheduler/cluster/ClusterTaskSetManager.scala | 1 + 4 files changed, 19 insertions(+) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 42bb3884c8..4457525ac8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -110,6 +110,9 @@ class DAGScheduler( // resubmit failed stages val POLL_TIMEOUT = 10L + // Warns the user if a stage contains a task with size greater than this value (in KB) + val TASK_SIZE_TO_WARN = 100 + private val eventProcessActor: ActorRef = env.actorSystem.actorOf(Props(new Actor { override def preStart() { context.system.scheduler.schedule(RESUBMIT_TIMEOUT milliseconds, RESUBMIT_TIMEOUT milliseconds) { @@ -430,6 +433,18 @@ class DAGScheduler( handleExecutorLost(execId) case BeginEvent(task, taskInfo) => + for ( + job <- idToActiveJob.get(task.stageId); + stage <- stageIdToStage.get(task.stageId); + stageInfo <- stageToInfos.get(stage) + ) { + if (taskInfo.serializedSize > TASK_SIZE_TO_WARN * 1024 && !stageInfo.emittedTaskSizeWarning) { + stageInfo.emittedTaskSizeWarning = true + logWarning(("Stage %d (%s) contains a task of very large " + + "size (%d KB). The maximum recommended task size is %d KB.").format( + task.stageId, stageInfo.name, taskInfo.serializedSize / 1024, TASK_SIZE_TO_WARN)) + } + } listenerBus.post(SparkListenerTaskStart(task, taskInfo)) case GettingResultEvent(task, taskInfo) => diff --git a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala index 93599dfdc8..e9f2198a00 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/StageInfo.scala @@ -33,4 +33,5 @@ class StageInfo( val name = stage.name val numPartitions = stage.numPartitions val numTasks = stage.numTasks + var emittedTaskSizeWarning = false } diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala index 4bae26f3a6..3c22edd524 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/TaskInfo.scala @@ -46,6 +46,8 @@ class TaskInfo( var failed = false + var serializedSize: Int = 0 + def markGettingResult(time: Long = System.currentTimeMillis) { gettingResultTime = time } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 4c5eca8537..8884ea85a3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -377,6 +377,7 @@ private[spark] class ClusterTaskSetManager( logInfo("Serialized task %s:%d as %d bytes in %d ms".format( taskSet.id, index, serializedTask.limit, timeTaken)) val taskName = "task %s:%d".format(taskSet.id, index) + info.serializedSize = serializedTask.limit if (taskAttempts(index).size == 1) taskStarted(task,info) return Some(new TaskDescription(taskId, execId, taskName, index, serializedTask)) -- cgit v1.2.3 From 1b74a27da026aba7dbe2088ee64974d772feb23d Mon Sep 17 00:00:00 2001 From: Josh Rosen Date: Tue, 26 Nov 2013 14:35:12 -0800 Subject: Removed unused basestring case from dump_stream. 
--- python/pyspark/serializers.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py index 9338df69ff..811fa6f018 100644 --- a/python/pyspark/serializers.py +++ b/python/pyspark/serializers.py @@ -179,8 +179,6 @@ class BatchedSerializer(Serializer): yield items def dump_stream(self, iterator, stream): - if isinstance(iterator, basestring): - iterator = [iterator] self.serializer.dump_stream(self._batched(iterator), stream) def load_stream(self, stream): -- cgit v1.2.3
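A note on the OpenHashSet patch earlier in this series: its commit message says the java.util.HashMap-style supplemental hash "had trouble with ranges of consecutive integers." The standalone sketch below illustrates one plausible reading of that problem, namely clustering: consecutive keys land in one unbroken run of table slots, which stretches the probe sequences in addWithoutResize. The constants are the standard Murmur3 fmix32 values; assuming they match fastutil's HashCommon.murmurHash3 is an inference, not something shown in the patch.

    object HashMixSketch {
      // Supplemental hash in the style of java.util.HashMap (the version the patch removes).
      def javaStyleMix(h: Int): Int = {
        val r = h ^ (h >>> 20) ^ (h >>> 12)
        r ^ (r >>> 7) ^ (r >>> 4)
      }

      // Murmur3 32-bit finalization step (fmix32).
      def murmur3Finalize(hash: Int): Int = {
        var h = hash
        h ^= h >>> 16
        h *= 0x85ebca6b
        h ^= h >>> 13
        h *= 0xc2b2ae35
        h ^= h >>> 16
        h
      }

      // Longest run of consecutive occupied slots when 64 consecutive integer keys are
      // masked into a 256-slot table, mirroring the hashcode(...) & _mask step in the set.
      def longestRun(mix: Int => Int): Int = {
        val mask = 255
        val occupied = (0 until 64).map(k => mix(k) & mask).toSet
        var best = 0
        var current = 0
        for (slot <- 0 to mask) {
          if (occupied.contains(slot)) {
            current += 1
            if (current > best) best = current
          } else {
            current = 0
          }
        }
        best
      }

      def main(args: Array[String]) {
        println("longest occupied run, java-style mix: " + longestRun(javaStyleMix))
        println("longest occupied run, murmur3 mix:    " + longestRun(murmur3Finalize))
      }
    }

With 64 consecutive keys masked into a 256-slot table, the old mix leaves a single run of 64 occupied slots, while the Murmur3 finalizer scatters the keys into short runs.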
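Along the same lines, the scheduler patch above warns when a serialized task grows past 100 KB. A common way to get there is closing over a large local data structure; broadcasting it instead keeps each task small. The sketch below is illustrative only (the object name, sizes, and local master are made up, not taken from the patches):

    import org.apache.spark.SparkContext
    import org.apache.spark.SparkContext._

    object TaskSizeSketch {
      def main(args: Array[String]) {
        val sc = new SparkContext("local", "task-size-sketch")
        // Roughly 1.6 MB of doubles; captured in a closure it is shipped with every task,
        // which is the kind of job the new 100 KB warning is meant to flag.
        val bigLookup = Array.fill(200000)(1.0)
        val captured = sc.parallelize(1 to 100).map(i => bigLookup(i % bigLookup.length)).sum()

        // Broadcasting the array once keeps each serialized task small.
        val lookup = sc.broadcast(bigLookup)
        val broadcasted = sc.parallelize(1 to 100).map(i => lookup.value(i % lookup.value.length)).sum()

        println(captured + " " + broadcasted)
        sc.stop()
      }
    }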