From 2573add94cf920a88f74d80d8ea94218d812704d Mon Sep 17 00:00:00 2001 From: Prashant Sharma Date: Tue, 24 Dec 2013 18:30:31 +0530 Subject: spark-544, introducing SparkConf and related configuration overhaul. --- .../scala/org/apache/spark/mllib/recommendation/ALS.scala | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 36853acab5..2f2d106f86 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -578,14 +578,13 @@ object ALS { val implicitPrefs = if (args.length >= 7) args(6).toBoolean else false val alpha = if (args.length >= 8) args(7).toDouble else 1 val blocks = if (args.length == 9) args(8).toInt else -1 - - System.setProperty("spark.serializer", "org.apache.spark.serializer.KryoSerializer") - System.setProperty("spark.kryo.registrator", classOf[ALSRegistrator].getName) - System.setProperty("spark.kryo.referenceTracking", "false") - System.setProperty("spark.kryoserializer.buffer.mb", "8") - System.setProperty("spark.locality.wait", "10000") - val sc = new SparkContext(master, "ALS") + sc.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + sc.conf.set("spark.kryo.registrator", classOf[ALSRegistrator].getName) + sc.conf.set("spark.kryo.referenceTracking", "false") + sc.conf.set("spark.kryoserializer.buffer.mb", "8") + sc.conf.set("spark.locality.wait", "10000") + val ratings = sc.textFile(ratingsFile).map { line => val fields = line.split(',') Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) -- cgit v1.2.3 From 3dc655aa19f678219e5d999fe97ab769567ffb1c Mon Sep 17 00:00:00 2001 From: Frank Dai Date: Wed, 25 Dec 2013 16:50:42 +0800 Subject: standard Naive Bayes classifier --- .../spark/mllib/classification/NaiveBayes.scala | 103 +++++++++++++++++++++ .../mllib/classification/NaiveBayesSuite.scala | 92 ++++++++++++++++++ 2 files changed, 195 insertions(+) create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala create mode 100644 mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala new file mode 100644 index 0000000000..f1b0e6ee6a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ +import org.jblas.DoubleMatrix + +/** + * Model for Naive Bayes Classifiers. + * + * @param weightPerLabel Weights computed for every label, which's dimension is C. + * @param weightMatrix Weights computed for every label and feature, which's dimension is CXD + */ +class NaiveBayesModel(val weightPerLabel: Array[Double], + val weightMatrix: Array[Array[Double]]) + extends ClassificationModel with Serializable { + + // Create a column vector that can be used for predictions + private val _weightPerLabel = new DoubleMatrix(weightPerLabel.length, 1, weightPerLabel:_*) + private val _weightMatrix = new DoubleMatrix(weightMatrix) + + def predict(testData: RDD[Array[Double]]): RDD[Double] = testData.map(predict) + + def predict(testData: Array[Double]): Double = { + val dataMatrix = new DoubleMatrix(testData.length, 1, testData: _*) + val result = _weightPerLabel.add(_weightMatrix.mmul(dataMatrix)) + result.argmax() + } +} + + + +class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter + extends Serializable with Logging { + + /** + * Run the algorithm with the configured parameters on an input + * RDD of LabeledPoint entries. + * + * @param C kind of labels, labels are continuous integers and the maximal label is C-1 + * @param D dimension of feature vectors + * @param data RDD of (label, array of features) pairs. + */ + def run(C: Int, D: Int, data: RDD[LabeledPoint]): NaiveBayesModel = { + val groupedData = data.map(p => p.label.toInt -> p.features).groupByKey() + + val countPerLabel = groupedData.mapValues(_.size) + val logDenominator = math.log(data.count() + C * lambda) + val weightPerLabel = countPerLabel.mapValues { + count => math.log(count + lambda) - logDenominator + } + + val summedObservations = groupedData.mapValues(_.reduce { + (lhs, rhs) => lhs.zip(rhs).map(pair => pair._1 + pair._2) + }) + + val weightsMatrix = summedObservations.mapValues { weights => + val sum = weights.sum + val logDenom = math.log(sum + D * lambda) + weights.map(w => math.log(w + lambda) - logDenom) + } + + val labelWeights = weightPerLabel.collect().sorted.map(_._2) + val weightsMat = weightsMatrix.collect().sortBy(_._1).map(_._2) + + new NaiveBayesModel(labelWeights, weightsMat) + } +} + +object NaiveBayes { + /** + * Train a naive bayes model given an RDD of (label, features) pairs. + * + * @param C kind of labels, the maximal label is C-1 + * @param D dimension of feature vectors + * @param input RDD of (label, array of features) pairs. + * @param lambda smooth parameter + */ + def train(C: Int, D: Int, input: RDD[LabeledPoint], + lambda: Double = 1.0): NaiveBayesModel = { + new NaiveBayes(lambda).run(C, D, input) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala new file mode 100644 index 0000000000..d871ed3672 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -0,0 +1,92 @@ +package org.apache.spark.mllib.classification + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.SparkContext + +object NaiveBayesSuite { + + private def calcLabel(p: Double, weightPerLabel: Array[Double]): Int = { + var sum = 0.0 + for (j <- 0 until weightPerLabel.length) { + sum += weightPerLabel(j) + if (p < sum) return j + } + -1 + } + + // Generate input of the form Y = (weightMatrix*x).argmax() + def generateNaiveBayesInput( + weightPerLabel: Array[Double], // 1XC + weightsMatrix: Array[Array[Double]], // CXD + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val D = weightsMatrix(0).length + val rnd = new Random(seed) + + val _weightPerLabel = weightPerLabel.map(math.pow(math.E, _)) + val _weightMatrix = weightsMatrix.map(row => row.map(math.pow(math.E, _))) + + for (i <- 0 until nPoints) yield { + val y = calcLabel(rnd.nextDouble(), _weightPerLabel) + val xi = Array.tabulate[Double](D) { j => + if (rnd.nextDouble() < _weightMatrix(y)(j)) 1 else 0 + } + + LabeledPoint(y, xi) + } + } +} + +class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).count { + case (prediction, expected) => + prediction != expected.label + } + // At least 80% of the predictions should be on. + assert(numOffPredictions < input.length / 5) + } + + test("Naive Bayes") { + val nPoints = 10000 + + val weightPerLabel = Array(math.log(0.5), math.log(0.3), math.log(0.2)) + val weightsMatrix = Array( + Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0 + Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1 + Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2 + ) + + val testData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42) + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val model = NaiveBayes.train(3, 4, testRDD) + + val validationData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } +} -- cgit v1.2.3 From 3bb714eaa3bdb7b7c33f6e5263c683f4c4beeddc Mon Sep 17 00:00:00 2001 From: "Lian, Cheng" Date: Wed, 25 Dec 2013 17:15:38 +0800 Subject: Refactored NaiveBayes * Minimized shuffle output with mapPartitions. * Reduced RDD actions from 3 to 1. --- .../spark/mllib/classification/NaiveBayes.scala | 60 +++++++++++++--------- .../mllib/classification/NaiveBayesSuite.scala | 9 ++-- 2 files changed, 41 insertions(+), 28 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index f1b0e6ee6a..edea5ed3e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -48,11 +48,12 @@ class NaiveBayesModel(val weightPerLabel: Array[Double], } } - - class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter extends Serializable with Logging { + private[this] def vectorAdd(v1: Array[Double], v2: Array[Double]) = + v1.zip(v2).map(pair => pair._1 + pair._2) + /** * Run the algorithm with the configured parameters on an input * RDD of LabeledPoint entries. @@ -61,29 +62,42 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter * @param D dimension of feature vectors * @param data RDD of (label, array of features) pairs. */ - def run(C: Int, D: Int, data: RDD[LabeledPoint]): NaiveBayesModel = { - val groupedData = data.map(p => p.label.toInt -> p.features).groupByKey() - - val countPerLabel = groupedData.mapValues(_.size) - val logDenominator = math.log(data.count() + C * lambda) - val weightPerLabel = countPerLabel.mapValues { - count => math.log(count + lambda) - logDenominator + def run(C: Int, D: Int, data: RDD[LabeledPoint]) = { + val locallyReduced = data.mapPartitions { iterator => + val localLabelCounts = mutable.Map.empty[Int, Int].withDefaultValue(0) + val localSummedObservations = + mutable.Map.empty[Int, Array[Double]].withDefaultValue(Array.fill(D)(0.0)) + + for (LabeledPoint(label, features) <- iterator; i = label.toInt) { + localLabelCounts(i) += 1 + localSummedObservations(i) = vectorAdd(localSummedObservations(i), features) + } + + for ((label, count) <- localLabelCounts.toIterator) yield { + label -> (count, localSummedObservations(label)) + } + } + + val reduced = locallyReduced.reduceByKey { (lhs, rhs) => + (lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2)) } - - val summedObservations = groupedData.mapValues(_.reduce { - (lhs, rhs) => lhs.zip(rhs).map(pair => pair._1 + pair._2) - }) - - val weightsMatrix = summedObservations.mapValues { weights => - val sum = weights.sum - val logDenom = math.log(sum + D * lambda) - weights.map(w => math.log(w + lambda) - logDenom) + + val collected = reduced.mapValues { case (count, summed) => + val labelWeight = math.log(count + lambda) + val logDenom = math.log(summed.sum + D * lambda) + val weights = summed.map(w => math.log(w + lambda) - logDenom) + (count, labelWeight, weights) + }.collectAsMap() + + val weightPerLabel = { + val N = collected.values.map(_._1).sum + val logDenom = math.log(N + C * lambda) + collected.mapValues(_._2 - logDenom).toArray.sortBy(_._1).map(_._2) } - - val labelWeights = weightPerLabel.collect().sorted.map(_._2) - val weightsMat = weightsMatrix.collect().sortBy(_._1).map(_._2) - - new NaiveBayesModel(labelWeights, weightsMat) + + val weightMatrix = collected.mapValues(_._3).toArray.sortBy(_._1).map(_._2) + + new NaiveBayesModel(weightPerLabel, weightMatrix) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index d871ed3672..cc8d48a42b 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -1,6 +1,5 @@ package org.apache.spark.mllib.classification -import scala.collection.JavaConversions._ import scala.util.Random import org.scalatest.BeforeAndAfterAll @@ -56,12 +55,12 @@ class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll { } def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { - val numOffPredictions = predictions.zip(input).count { + val numOfPredictions = predictions.zip(input).count { case (prediction, expected) => prediction != expected.label } // At least 80% of the predictions should be on. - assert(numOffPredictions < input.length / 5) + assert(numOfPredictions < input.length / 5) } test("Naive Bayes") { @@ -71,8 +70,8 @@ class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll { val weightsMatrix = Array( Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0 Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1 - Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2 - ) + Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2 + ) val testData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42) val testRDD = sc.parallelize(testData, 2) -- cgit v1.2.3 From c0337c5bbfd5126c64964a9fdefd2bef11727d87 Mon Sep 17 00:00:00 2001 From: "Lian, Cheng" Date: Wed, 25 Dec 2013 22:45:57 +0800 Subject: Let reduceByKey to take care of local combine Also refactored some heavy FP code to improve readability and reduce memory footprint. --- .../spark/mllib/classification/NaiveBayes.scala | 43 ++++++++-------------- 1 file changed, 16 insertions(+), 27 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index edea5ed3e6..4c96b241eb 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.classification -import scala.collection.mutable - import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD @@ -63,39 +61,30 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter * @param data RDD of (label, array of features) pairs. */ def run(C: Int, D: Int, data: RDD[LabeledPoint]) = { - val locallyReduced = data.mapPartitions { iterator => - val localLabelCounts = mutable.Map.empty[Int, Int].withDefaultValue(0) - val localSummedObservations = - mutable.Map.empty[Int, Array[Double]].withDefaultValue(Array.fill(D)(0.0)) - - for (LabeledPoint(label, features) <- iterator; i = label.toInt) { - localLabelCounts(i) += 1 - localSummedObservations(i) = vectorAdd(localSummedObservations(i), features) - } - - for ((label, count) <- localLabelCounts.toIterator) yield { - label -> (count, localSummedObservations(label)) - } - } - - val reduced = locallyReduced.reduceByKey { (lhs, rhs) => + val countsAndSummedFeatures = data.map { case LabeledPoint(label, features) => + label.toInt ->(1, features) + }.reduceByKey { (lhs, rhs) => (lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2)) } - val collected = reduced.mapValues { case (count, summed) => + val collected = countsAndSummedFeatures.mapValues { case (count, summedFeatureVector) => val labelWeight = math.log(count + lambda) - val logDenom = math.log(summed.sum + D * lambda) - val weights = summed.map(w => math.log(w + lambda) - logDenom) + val logDenom = math.log(summedFeatureVector.sum + D * lambda) + val weights = summedFeatureVector.map(w => math.log(w + lambda) - logDenom) (count, labelWeight, weights) }.collectAsMap() - val weightPerLabel = { - val N = collected.values.map(_._1).sum - val logDenom = math.log(N + C * lambda) - collected.mapValues(_._2 - logDenom).toArray.sortBy(_._1).map(_._2) - } + // We can simply call `data.count` to get `N`, but that triggers another RDD action, which is + // considerably expensive. + val N = collected.values.map(_._1).sum + val logDenom = math.log(N + C * lambda) + val weightPerLabel = Array.fill[Double](C)(0) + val weightMatrix = Array.fill[Array[Double]](C)(null) - val weightMatrix = collected.mapValues(_._3).toArray.sortBy(_._1).map(_._2) + for ((label, (_, labelWeight, weights)) <- collected) { + weightPerLabel(label) = labelWeight - logDenom + weightMatrix(label) = weights + } new NaiveBayesModel(weightPerLabel, weightMatrix) } -- cgit v1.2.3 From 654f42174aa912fec7355d779e4e02731c535c94 Mon Sep 17 00:00:00 2001 From: "Lian, Cheng" Date: Fri, 27 Dec 2013 04:45:04 +0800 Subject: Reformatted some lines commented by Matei --- .../scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 4c96b241eb..2bc4c5afc0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -49,8 +49,9 @@ class NaiveBayesModel(val weightPerLabel: Array[Double], class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter extends Serializable with Logging { - private[this] def vectorAdd(v1: Array[Double], v2: Array[Double]) = + private def vectorAdd(v1: Array[Double], v2: Array[Double]) = { v1.zip(v2).map(pair => pair._1 + pair._2) + } /** * Run the algorithm with the configured parameters on an input @@ -62,7 +63,7 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter */ def run(C: Int, D: Int, data: RDD[LabeledPoint]) = { val countsAndSummedFeatures = data.map { case LabeledPoint(label, features) => - label.toInt ->(1, features) + label.toInt -> (1, features) }.reduceByKey { (lhs, rhs) => (lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2)) } -- cgit v1.2.3 From d7086dc28a856ec8856278be108310ec8264a115 Mon Sep 17 00:00:00 2001 From: "Lian, Cheng" Date: Fri, 27 Dec 2013 08:20:41 +0800 Subject: Added Apache license header to NaiveBayesSuite --- .../spark/mllib/classification/NaiveBayesSuite.scala | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) (limited to 'mllib') diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index cc8d48a42b..a2821347a7 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -1,3 +1,20 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + package org.apache.spark.mllib.classification import scala.util.Random -- cgit v1.2.3 From 642029e7f43322f84abe4f7f36bb0b1b95d8101d Mon Sep 17 00:00:00 2001 From: Matei Zaharia Date: Sat, 28 Dec 2013 17:13:15 -0500 Subject: Various fixes to configuration code - Got rid of global SparkContext.globalConf - Pass SparkConf to serializers and compression codecs - Made SparkConf public instead of private[spark] - Improved API of SparkContext and SparkConf - Switched executor environment vars to be passed through SparkConf - Fixed some places that were still using system properties - Fixed some tests, though others are still failing This still fails several tests in core, repl and streaming, likely due to properties not being set or cleared correctly (some of the tests run fine in isolation). --- .../main/scala/org/apache/spark/Accumulators.scala | 8 +- .../scala/org/apache/spark/MapOutputTracker.scala | 4 +- .../main/scala/org/apache/spark/Partitioner.scala | 6 +- .../main/scala/org/apache/spark/SparkConf.scala | 158 ++++++++++++++++----- .../main/scala/org/apache/spark/SparkContext.scala | 138 ++++++++++-------- .../src/main/scala/org/apache/spark/SparkEnv.scala | 11 +- .../apache/spark/api/java/JavaSparkContext.scala | 15 +- .../org/apache/spark/api/python/PythonRDD.scala | 6 +- .../org/apache/spark/broadcast/HttpBroadcast.scala | 33 +++-- .../apache/spark/broadcast/TorrentBroadcast.scala | 38 ++--- .../spark/deploy/ApplicationDescription.scala | 2 +- .../apache/spark/deploy/LocalSparkCluster.scala | 7 +- .../org/apache/spark/deploy/SparkHadoopUtil.scala | 14 +- .../apache/spark/deploy/client/TestClient.scala | 9 +- .../org/apache/spark/deploy/master/Master.scala | 36 ++--- .../deploy/master/SparkZooKeeperSession.scala | 2 +- .../master/ZooKeeperLeaderElectionAgent.scala | 2 +- .../deploy/master/ZooKeeperPersistenceEngine.scala | 2 +- .../org/apache/spark/deploy/worker/Worker.scala | 14 +- .../executor/CoarseGrainedExecutorBackend.scala | 4 +- .../scala/org/apache/spark/executor/Executor.scala | 17 +-- .../org/apache/spark/io/CompressionCodec.scala | 13 +- .../apache/spark/network/ConnectionManager.scala | 4 +- .../org/apache/spark/network/ReceiverTest.scala | 12 +- .../org/apache/spark/network/SenderTest.scala | 16 +-- .../apache/spark/network/netty/ShuffleCopier.scala | 6 +- .../scala/org/apache/spark/rdd/CheckpointRDD.scala | 7 +- .../scala/org/apache/spark/rdd/CoGroupedRDD.scala | 2 +- .../scala/org/apache/spark/rdd/ShuffledRDD.scala | 2 +- .../scala/org/apache/spark/rdd/SubtractedRDD.scala | 2 +- .../org/apache/spark/scheduler/DAGScheduler.scala | 5 +- .../apache/spark/scheduler/InputFormatInfo.scala | 14 +- .../org/apache/spark/scheduler/JobLogger.scala | 2 +- .../org/apache/spark/scheduler/ResultTask.scala | 4 +- .../spark/scheduler/SchedulableBuilder.scala | 2 +- .../apache/spark/scheduler/ShuffleMapTask.scala | 6 +- .../spark/scheduler/cluster/ClusterScheduler.scala | 8 +- .../scheduler/cluster/ClusterTaskSetManager.scala | 12 +- .../cluster/CoarseGrainedSchedulerBackend.scala | 9 +- .../spark/scheduler/cluster/SchedulerBackend.scala | 3 - .../scheduler/cluster/SimrSchedulerBackend.scala | 2 +- .../cluster/SparkDeploySchedulerBackend.scala | 2 +- .../spark/scheduler/cluster/TaskResultGetter.scala | 2 +- .../mesos/CoarseMesosSchedulerBackend.scala | 6 +- .../cluster/mesos/MesosSchedulerBackend.scala | 6 +- .../spark/scheduler/local/LocalScheduler.scala | 2 +- .../apache/spark/serializer/JavaSerializer.scala | 3 +- .../apache/spark/serializer/KryoSerializer.scala | 13 +- .../spark/serializer/SerializerManager.scala | 12 +- .../spark/storage/BlockFetcherIterator.scala | 2 +- .../org/apache/spark/storage/BlockManager.scala | 46 +++--- .../apache/spark/storage/BlockManagerMaster.scala | 4 +- .../spark/storage/BlockManagerMasterActor.scala | 4 +- .../apache/spark/storage/DiskBlockManager.scala | 2 +- .../apache/spark/storage/ShuffleBlockManager.scala | 9 +- .../apache/spark/storage/StoragePerfTester.scala | 2 +- .../org/apache/spark/storage/ThreadingTest.scala | 6 +- .../org/apache/spark/ui/UIWorkloadGenerator.scala | 17 ++- .../org/apache/spark/ui/env/EnvironmentUI.scala | 2 +- .../apache/spark/ui/jobs/JobProgressListener.scala | 4 +- .../scala/org/apache/spark/util/AkkaUtils.scala | 18 +-- .../org/apache/spark/util/MetadataCleaner.scala | 33 +++-- .../org/apache/spark/util/SizeEstimator.scala | 17 +-- .../main/scala/org/apache/spark/util/Utils.scala | 14 +- .../apache/spark/io/CompressionCodecSuite.scala | 8 +- .../cluster/ClusterTaskSetManagerSuite.scala | 2 +- .../spark/serializer/KryoSerializerSuite.scala | 14 +- .../apache/spark/storage/BlockManagerSuite.scala | 8 +- .../org/apache/spark/util/SizeEstimatorSuite.scala | 2 - .../spark/examples/bagel/WikipediaPageRank.scala | 4 +- .../bagel/WikipediaPageRankStandalone.scala | 4 +- .../apache/spark/mllib/recommendation/ALS.scala | 10 +- .../spark/deploy/yarn/ApplicationMaster.scala | 44 +++--- .../org/apache/spark/deploy/yarn/Client.scala | 38 ++--- .../apache/spark/deploy/yarn/ClientArguments.scala | 2 +- .../scala/org/apache/spark/repl/SparkILoop.scala | 16 ++- .../scala/org/apache/spark/repl/SparkIMain.scala | 4 +- .../org/apache/spark/streaming/Checkpoint.scala | 22 +-- .../scala/org/apache/spark/streaming/DStream.scala | 2 +- .../org/apache/spark/streaming/Scheduler.scala | 10 +- .../apache/spark/streaming/StreamingContext.scala | 25 ++-- .../streaming/dstream/NetworkInputDStream.scala | 6 +- .../spark/streaming/util/RawTextSender.scala | 4 +- .../apache/spark/streaming/InputStreamsSuite.scala | 6 +- .../org/apache/spark/streaming/TestSuiteBase.scala | 6 +- .../spark/deploy/yarn/ApplicationMaster.scala | 56 ++++---- .../org/apache/spark/deploy/yarn/Client.scala | 50 +++---- .../apache/spark/deploy/yarn/ClientArguments.scala | 2 +- 88 files changed, 692 insertions(+), 536 deletions(-) (limited to 'mllib') diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala index 6e922a612a..5f73d234aa 100644 --- a/core/src/main/scala/org/apache/spark/Accumulators.scala +++ b/core/src/main/scala/org/apache/spark/Accumulators.scala @@ -41,7 +41,7 @@ class Accumulable[R, T] ( @transient initialValue: R, param: AccumulableParam[R, T]) extends Serializable { - + val id = Accumulators.newId @transient private var value_ = initialValue // Current value on master val zero = param.zero(initialValue) // Zero value to be passed to workers @@ -113,7 +113,7 @@ class Accumulable[R, T] ( def setValue(newValue: R) { this.value = newValue } - + // Called by Java when deserializing an object private def readObject(in: ObjectInputStream) { in.defaultReadObject() @@ -177,7 +177,7 @@ class GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Ser def zero(initialValue: R): R = { // We need to clone initialValue, but it's hard to specify that R should also be Cloneable. // Instead we'll serialize it to a buffer and load it back. - val ser = new JavaSerializer().newInstance() + val ser = new JavaSerializer(new SparkConf(false)).newInstance() val copy = ser.deserialize[R](ser.serialize(initialValue)) copy.clear() // In case it contained stuff copy @@ -215,7 +215,7 @@ private object Accumulators { val originals = Map[Long, Accumulable[_, _]]() val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]() var lastId: Long = 0 - + def newId: Long = synchronized { lastId += 1 return lastId diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala index 4520edb10d..cdae167aef 100644 --- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala +++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala @@ -65,7 +65,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { protected val epochLock = new java.lang.Object private val metadataCleaner = - new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup) + new MetadataCleaner(MetadataCleanerType.MAP_OUTPUT_TRACKER, this.cleanup, conf) // Send a message to the trackerActor and get its result within a default timeout, or // throw a SparkException if this fails. @@ -129,7 +129,7 @@ private[spark] class MapOutputTracker(conf: SparkConf) extends Logging { if (fetchedStatuses == null) { // We won the race to fetch the output locs; do so logInfo("Doing the fetch; tracker actor = " + trackerActor) - val hostPort = Utils.localHostPort() + val hostPort = Utils.localHostPort(conf) // This try-finally prevents hangs due to timeouts: try { val fetchedBytes = diff --git a/core/src/main/scala/org/apache/spark/Partitioner.scala b/core/src/main/scala/org/apache/spark/Partitioner.scala index 04c1eedfeb..7cb545a6be 100644 --- a/core/src/main/scala/org/apache/spark/Partitioner.scala +++ b/core/src/main/scala/org/apache/spark/Partitioner.scala @@ -32,8 +32,6 @@ abstract class Partitioner extends Serializable { } object Partitioner { - - import SparkContext.{globalConf => conf} /** * Choose a partitioner to use for a cogroup-like operation between a number of RDDs. * @@ -54,7 +52,7 @@ object Partitioner { for (r <- bySize if r.partitioner != None) { return r.partitioner.get } - if (conf.getOrElse("spark.default.parallelism", null) != null) { + if (rdd.context.conf.getOrElse("spark.default.parallelism", null) != null) { return new HashPartitioner(rdd.context.defaultParallelism) } else { return new HashPartitioner(bySize.head.partitions.size) @@ -92,7 +90,7 @@ class HashPartitioner(partitions: Int) extends Partitioner { class RangePartitioner[K <% Ordered[K]: ClassTag, V]( partitions: Int, @transient rdd: RDD[_ <: Product2[K,V]], - private val ascending: Boolean = true) + private val ascending: Boolean = true) extends Partitioner { // An array of upper bounds for the first (partitions - 1) partitions diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala index 9a4eefad2e..185ddb1fe5 100644 --- a/core/src/main/scala/org/apache/spark/SparkConf.scala +++ b/core/src/main/scala/org/apache/spark/SparkConf.scala @@ -1,71 +1,159 @@ package org.apache.spark -import scala.collection.JavaConversions._ -import scala.collection.concurrent.TrieMap +import scala.collection.JavaConverters._ +import scala.collection.mutable.HashMap import com.typesafe.config.ConfigFactory -private[spark] class SparkConf(loadClasspathRes: Boolean = true) extends Serializable { - @transient lazy val config = ConfigFactory.systemProperties() - .withFallback(ConfigFactory.parseResources("spark.conf")) - // TODO this should actually be synchronized - private val configMap = TrieMap[String, String]() +/** + * Configuration for a Spark application. Used to set various Spark parameters as key-value pairs. + * + * Most of the time, you would create a SparkConf object with `new SparkConf()`, which will load + * values from both the `spark.*` Java system properties and any `spark.conf` on your application's + * classpath (if it has one). In this case, system properties take priority over `spark.conf`, and + * any parameters you set directly on the `SparkConf` object take priority over both of those. + * + * For unit tests, you can also call `new SparkConf(false)` to skip loading external settings and + * get the same configuration no matter what is on the classpath. + * + * @param loadDefaults whether to load values from the system properties and classpath + */ +class SparkConf(loadDefaults: Boolean) extends Serializable with Cloneable { - if (loadClasspathRes && !config.entrySet().isEmpty) { - for (e <- config.entrySet()) { - configMap += ((e.getKey, e.getValue.unwrapped().toString)) + /** Create a SparkConf that loads defaults from system properties and the classpath */ + def this() = this(true) + + private val settings = new HashMap[String, String]() + + if (loadDefaults) { + val typesafeConfig = ConfigFactory.systemProperties() + .withFallback(ConfigFactory.parseResources("spark.conf")) + for (e <- typesafeConfig.entrySet().asScala) { + settings(e.getKey) = e.getValue.unwrapped.toString } } - def setMasterUrl(master: String) = { - if (master != null) - configMap += (("spark.master", master)) + /** Set a configuration variable. */ + def set(key: String, value: String): SparkConf = { + settings(key) = value + this + } + + /** + * The master URL to connect to, such as "local" to run locally with one thread, "local[4]" to + * run locally with 4 cores, or "spark://master:7077" to run on a Spark standalone cluster. + */ + def setMaster(master: String): SparkConf = { + if (master != null) { + settings("spark.master") = master + } this } - def setAppName(name: String) = { - if (name != null) - configMap += (("spark.appName", name)) + /** Set a name for your application. Shown in the Spark web UI. */ + def setAppName(name: String): SparkConf = { + if (name != null) { + settings("spark.appName") = name + } this } - def setJars(jars: Seq[String]) = { - if (!jars.isEmpty) - configMap += (("spark.jars", jars.mkString(","))) + /** Set JAR files to distribute to the cluster. */ + def setJars(jars: Seq[String]): SparkConf = { + if (!jars.isEmpty) { + settings("spark.jars") = jars.mkString(",") + } this } - def set(k: String, value: String) = { - configMap += ((k, value)) + /** Set JAR files to distribute to the cluster. (Java-friendly version.) */ + def setJars(jars: Array[String]): SparkConf = { + if (!jars.isEmpty) { + settings("spark.jars") = jars.mkString(",") + } this } - def setSparkHome(home: String) = { - if (home != null) - configMap += (("spark.home", home)) + /** Set an environment variable to be used when launching executors for this application. */ + def setExecutorEnv(variable: String, value: String): SparkConf = { + settings("spark.executorEnv." + variable) = value this } - def set(map: Seq[(String, String)]) = { - if (map != null && !map.isEmpty) - configMap ++= map + /** Set multiple environment variables to be used when launching executors. */ + def setExecutorEnv(variables: Seq[(String, String)]): SparkConf = { + for ((k, v) <- variables) { + setExecutorEnv(k, v) + } this } - def get(k: String): String = { - configMap(k) + /** + * Set multiple environment variables to be used when launching executors. + * (Java-friendly version.) + */ + def setExecutorEnv(variables: Array[(String, String)]): SparkConf = { + for ((k, v) <- variables) { + setExecutorEnv(k, v) + } + this } - def getAllConfiguration = configMap.clone.entrySet().iterator + /** + * Set the location where Spark is installed on worker nodes. This is only needed on Mesos if + * you are not using `spark.executor.uri` to disseminate the Spark binary distribution. + */ + def setSparkHome(home: String): SparkConf = { + if (home != null) { + settings("spark.home") = home + } + this + } + /** Set multiple parameters together */ + def setAll(settings: Traversable[(String, String)]) = { + this.settings ++= settings + this + } + + /** Set a parameter if it isn't already configured */ + def setIfMissing(key: String, value: String): SparkConf = { + if (!settings.contains(key)) { + settings(key) = value + } + this + } + + /** Get a parameter; throws an exception if it's not set */ + def get(key: String): String = { + settings(key) + } + + /** Get a parameter as an Option */ + def getOption(key: String): Option[String] = { + settings.get(key) + } + + /** Get all parameters as a list of pairs */ + def getAll: Seq[(String, String)] = settings.clone().toSeq + + /** Get a parameter, falling back to a default if not set */ def getOrElse(k: String, defaultValue: String): String = { - configMap.getOrElse(k, defaultValue) + settings.getOrElse(k, defaultValue) } - override def clone: SparkConf = { - val conf = new SparkConf(false) - conf.set(configMap.toSeq) - conf + /** Get all executor environment variables set on this SparkConf */ + def getExecutorEnv: Seq[(String, String)] = { + val prefix = "spark.executorEnv." + getAll.filter(pair => pair._1.startsWith(prefix)) + .map(pair => (pair._1.substring(prefix.length), pair._2)) } + /** Does the configuration contain a given parameter? */ + def contains(key: String): Boolean = settings.contains(key) + + /** Copy this object */ + override def clone: SparkConf = { + new SparkConf(false).setAll(settings) + } } diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala index 4300b07bdb..0567f7f437 100644 --- a/core/src/main/scala/org/apache/spark/SparkContext.scala +++ b/core/src/main/scala/org/apache/spark/SparkContext.scala @@ -22,8 +22,7 @@ import java.net.URI import java.util.Properties import java.util.concurrent.atomic.AtomicInteger -import scala.collection.{Map, immutable} -import scala.collection.JavaConversions._ +import scala.collection.{Map, Set, immutable} import scala.collection.generic.Growable import scala.collection.mutable.{ArrayBuffer, HashMap} @@ -57,23 +56,32 @@ import org.apache.spark.util._ * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. * - * @param conf a Spark Config object describing the context configuration. Any settings in this - * config overrides the default configs as well as system properties. - * - * @param environment Environment variables to set on worker nodes. + * @param conf_ a Spark Config object describing the application configuration. Any settings in + * this config overrides the default configs as well as system properties. + * @param preferredNodeLocationData used in YARN mode to select nodes to launch containers on. Can + * be generated using [[org.apache.spark.scheduler.InputFormatInfo.computePreferredLocations]] + * from a list of input files or InputFormats for the application. */ class SparkContext( - val conf: SparkConf, - val environment: Map[String, String] = Map(), + conf_ : SparkConf, // This is used only by YARN for now, but should be relevant to other cluster types (Mesos, etc) - // too. This is typically generated from InputFormatInfo.computePreferredLocations .. host, set - // of data-local splits on host - val preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = immutable.Map()) + // too. This is typically generated from InputFormatInfo.computePreferredLocations. It contains + // a map from hostname to a list of input format splits on the host. + val preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) extends Logging { /** - * Main entry point for Spark functionality. A SparkContext represents the connection to a Spark - * cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster. + * Alternative constructor that allows setting common Spark properties directly + * + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI + * @param conf a [[org.apache.spark.SparkConf]] object specifying other Spark parameters + */ + def this(master: String, appName: String, conf: SparkConf) = + this(conf.setMaster(master).setAppName(appName)) + + /** + * Alternative constructor that allows setting common Spark properties directly * * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). * @param appName A name for your application, to display on the cluster web UI. @@ -82,24 +90,42 @@ class SparkContext( * system or HDFS, HTTP, HTTPS, or FTP URLs. * @param environment Environment variables to set on worker nodes. */ - def this(master: String, appName: String, sparkHome: String = null, - jars: Seq[String] = Nil, environment: Map[String, String] = Map(), - preferredNodeLocationData: scala.collection.Map[String, scala.collection.Set[SplitInfo]] = - immutable.Map()) = - this(new SparkConf(false).setAppName(appName).setMasterUrl(master) - .setJars(jars).set(environment.toSeq).setSparkHome(sparkHome), - environment, preferredNodeLocationData) + def this( + master: String, + appName: String, + sparkHome: String = null, + jars: Seq[String] = Nil, + environment: Map[String, String] = Map(), + preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()) = + { + this( + new SparkConf() + .setMaster(master) + .setAppName(appName) + .setJars(jars) + .setExecutorEnv(environment.toSeq) + .setSparkHome(sparkHome), + preferredNodeLocationData) + } - // Set Spark driver host and port system properties - Try(conf.get("spark.driver.host")) - .getOrElse(conf.set("spark.driver.host", Utils.localHostName())) + val conf = conf_.clone() + + if (!conf.contains("spark.master")) { + throw new SparkException("A master URL must be set in your configuration") + } + if (!conf.contains("spark.appName")) { + throw new SparkException("An application must be set in your configuration") + } - Try(conf.get("spark.driver.port")) - .getOrElse(conf.set("spark.driver.port", "0")) + // Set Spark driver host and port system properties + conf.setIfMissing("spark.driver.host", Utils.localHostName()) + conf.setIfMissing("spark.driver.port", "0") - val jars: Seq[String] = if (conf.getOrElse("spark.jars", null) != null) { - conf.get("spark.jars").split(",") - } else null + val jars: Seq[String] = if (conf.contains("spark.jars")) { + conf.get("spark.jars").split(",").filter(_.size != 0) + } else { + null + } val master = conf.get("spark.master") val appName = conf.get("spark.appName") @@ -115,8 +141,8 @@ class SparkContext( conf.get("spark.driver.host"), conf.get("spark.driver.port").toInt, conf, - true, - isLocal) + isDriver = true, + isLocal = isLocal) SparkEnv.set(env) // Used to store a URL for each static file/jar together with the file's local timestamp @@ -125,7 +151,8 @@ class SparkContext( // Keeps track of all persisted RDDs private[spark] val persistentRdds = new TimeStampedHashMap[Int, RDD[_]] - private[spark] val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup) + private[spark] val metadataCleaner = + new MetadataCleaner(MetadataCleanerType.SPARK_CONTEXT, this.cleanup, conf) // Initialize the Spark UI private[spark] val ui = new SparkUI(this) @@ -135,9 +162,14 @@ class SparkContext( // Add each JAR given through the constructor if (jars != null) { - jars.foreach { addJar(_) } + jars.foreach(addJar) } + private[spark] val executorMemory = conf.getOption("spark.executor.memory") + .orElse(Option(System.getenv("SPARK_MEM"))) + .map(Utils.memoryStringToMb) + .getOrElse(512) + // Environment variables to pass to our executors private[spark] val executorEnvs = HashMap[String, String]() // Note: SPARK_MEM is included for Mesos, but overwritten for standalone mode in ExecutorRunner @@ -148,10 +180,8 @@ class SparkContext( } } // Since memory can be set with a system property too, use that - executorEnvs("SPARK_MEM") = SparkContext.executorMemoryRequested + "m" - if (environment != null) { - executorEnvs ++= environment - } + executorEnvs("SPARK_MEM") = executorMemory + "m" + executorEnvs ++= conf.getExecutorEnv // Set SPARK_USER for user who is running SparkContext. val sparkUser = Option { @@ -183,12 +213,12 @@ class SparkContext( hadoopConf.set("fs.s3n.awsSecretAccessKey", System.getenv("AWS_SECRET_ACCESS_KEY")) } // Copy any "spark.hadoop.foo=bar" system properties into conf as "foo=bar" - Utils.getSystemProperties.foreach { case (key, value) => + conf.getAll.foreach { case (key, value) => if (key.startsWith("spark.hadoop.")) { hadoopConf.set(key.substring("spark.hadoop.".length), value) } } - val bufferSize = conf.getOrElse("spark.buffer.size", "65536") + val bufferSize = conf.getOrElse("spark.buffer.size", "65536") hadoopConf.set("io.file.buffer.size", bufferSize) hadoopConf } @@ -200,7 +230,7 @@ class SparkContext( override protected def childValue(parent: Properties): Properties = new Properties(parent) } - private[spark] def getLocalProperties(): Properties = localProperties.get() + private[spark] def getLocalProperties: Properties = localProperties.get() private[spark] def setLocalProperties(props: Properties) { localProperties.set(props) @@ -533,7 +563,7 @@ class SparkContext( // Fetch the file locally in case a job is executed locally. // Jobs that run through LocalScheduler will already fetch the required dependencies, // but jobs run in DAGScheduler.runLocally() will not so we must fetch the files here. - Utils.fetchFile(path, new File(SparkFiles.getRootDirectory)) + Utils.fetchFile(path, new File(SparkFiles.getRootDirectory), conf) logInfo("Added file " + path + " at " + key + " with timestamp " + addedFiles(key)) } @@ -915,14 +945,6 @@ object SparkContext { private[spark] val SPARK_UNKNOWN_USER = "" - private lazy val conf = new SparkConf() - - private[spark] def globalConf = { - if (SparkEnv.get != null) { - SparkEnv.get.conf - } else conf - } - implicit object DoubleAccumulatorParam extends AccumulatorParam[Double] { def addInPlace(t1: Double, t2: Double): Double = t1 + t2 def zero(initialValue: Double) = 0.0 @@ -1031,18 +1053,10 @@ object SparkContext { /** Find the JAR that contains the class of a particular object */ def jarOfObject(obj: AnyRef): Seq[String] = jarOfClass(obj.getClass) - /** Get the amount of memory per executor requested through system properties or SPARK_MEM */ - private[spark] val executorMemoryRequested = { - // TODO: Might need to add some extra memory for the non-heap parts of the JVM - Try(globalConf.get("spark.executor.memory")).toOption - .orElse(Option(System.getenv("SPARK_MEM"))) - .map(Utils.memoryStringToMb) - .getOrElse(512) - } - // Creates a task scheduler based on a given master URL. Extracted for testing. - private - def createTaskScheduler(sc: SparkContext, master: String, appName: String): TaskScheduler = { + private def createTaskScheduler(sc: SparkContext, master: String, appName: String) + : TaskScheduler = + { // Regular expression used for local[N] master format val LOCAL_N_REGEX = """local\[([0-9]+)\]""".r // Regular expression for local[N, maxRetries], used in tests with failing tasks @@ -1076,10 +1090,10 @@ object SparkContext { case LOCAL_CLUSTER_REGEX(numSlaves, coresPerSlave, memoryPerSlave) => // Check to make sure memory requested <= memoryPerSlave. Otherwise Spark will just hang. val memoryPerSlaveInt = memoryPerSlave.toInt - if (SparkContext.executorMemoryRequested > memoryPerSlaveInt) { + if (sc.executorMemory > memoryPerSlaveInt) { throw new SparkException( "Asked to launch cluster with %d MB RAM / worker but requested %d MB/worker".format( - memoryPerSlaveInt, SparkContext.executorMemoryRequested)) + memoryPerSlaveInt, sc.executorMemory)) } val scheduler = new ClusterScheduler(sc) @@ -1137,7 +1151,7 @@ object SparkContext { case mesosUrl @ MESOS_REGEX(_) => MesosNativeLibrary.load() val scheduler = new ClusterScheduler(sc) - val coarseGrained = globalConf.getOrElse("spark.mesos.coarse", "false").toBoolean + val coarseGrained = sc.conf.getOrElse("spark.mesos.coarse", "false").toBoolean val url = mesosUrl.stripPrefix("mesos://") // strip scheme from raw Mesos URLs val backend = if (coarseGrained) { new CoarseMesosSchedulerBackend(scheduler, sc, url, appName) diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala index 78e4ae27b2..34fad3e763 100644 --- a/core/src/main/scala/org/apache/spark/SparkEnv.scala +++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala @@ -107,7 +107,7 @@ object SparkEnv extends Logging { /** * Returns the ThreadLocal SparkEnv. */ - def getThreadLocal : SparkEnv = { + def getThreadLocal: SparkEnv = { env.get() } @@ -150,18 +150,19 @@ object SparkEnv extends Logging { val serializerManager = new SerializerManager val serializer = serializerManager.setDefault( - conf.getOrElse("spark.serializer", "org.apache.spark.serializer.JavaSerializer")) + conf.getOrElse("spark.serializer", "org.apache.spark.serializer.JavaSerializer"), conf) val closureSerializer = serializerManager.get( - conf.getOrElse("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer")) + conf.getOrElse("spark.closure.serializer", "org.apache.spark.serializer.JavaSerializer"), + conf) def registerOrLookup(name: String, newActor: => Actor): Either[ActorRef, ActorSelection] = { if (isDriver) { logInfo("Registering " + name) Left(actorSystem.actorOf(Props(newActor), name = name)) } else { - val driverHost: String = conf.getOrElse("spark.driver.host", "localhost") - val driverPort: Int = conf.getOrElse("spark.driver.port", "7077").toInt + val driverHost: String = conf.getOrElse("spark.driver.host", "localhost") + val driverPort: Int = conf.getOrElse("spark.driver.port", "7077").toInt Utils.checkHost(driverHost, "Expected hostname") val url = "akka.tcp://spark@%s:%s/user/%s".format(driverHost, driverPort, name) logInfo("Connecting to " + name + ": " + url) diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala index acf328aa6a..e03cf9d13a 100644 --- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala +++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala @@ -29,17 +29,22 @@ import org.apache.hadoop.mapred.JobConf import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat} import com.google.common.base.Optional -import org.apache.spark.{Accumulable, AccumulableParam, Accumulator, AccumulatorParam, SparkContext} +import org.apache.spark._ import org.apache.spark.SparkContext.IntAccumulatorParam import org.apache.spark.SparkContext.DoubleAccumulatorParam import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD +import scala.Tuple2 /** * A Java-friendly version of [[org.apache.spark.SparkContext]] that returns [[org.apache.spark.api.java.JavaRDD]]s and * works with Java collections instead of Scala ones. */ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWorkaround { + /** + * @param conf a [[org.apache.spark.SparkConf]] object specifying Spark parameters + */ + def this(conf: SparkConf) = this(new SparkContext(conf)) /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). @@ -47,6 +52,14 @@ class JavaSparkContext(val sc: SparkContext) extends JavaSparkContextVarargsWork */ def this(master: String, appName: String) = this(new SparkContext(master, appName)) + /** + * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). + * @param appName A name for your application, to display on the cluster web UI + * @param conf a [[org.apache.spark.SparkConf]] object specifying other Spark parameters + */ + def this(master: String, appName: String, conf: SparkConf) = + this(conf.setMaster(master).setAppName(appName)) + /** * @param master Cluster URL to connect to (e.g. mesos://host:port, spark://host:port, local[4]). * @param appName A name for your application, to display on the cluster web UI diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala index d6eacfe23e..05fd824254 100644 --- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala +++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala @@ -41,7 +41,7 @@ private[spark] class PythonRDD[T: ClassTag]( accumulator: Accumulator[JList[Array[Byte]]]) extends RDD[Array[Byte]](parent) { - val bufferSize = conf.getOrElse("spark.buffer.size", "65536").toInt + val bufferSize = conf.getOrElse("spark.buffer.size", "65536").toInt override def getPartitions = parent.partitions @@ -247,10 +247,10 @@ private class BytesToString extends org.apache.spark.api.java.function.Function[ */ private class PythonAccumulatorParam(@transient serverHost: String, serverPort: Int) extends AccumulatorParam[JList[Array[Byte]]] { - import SparkContext.{globalConf => conf} + Utils.checkHost(serverHost, "Expected hostname") - val bufferSize = conf.getOrElse("spark.buffer.size", "65536").toInt + val bufferSize = SparkEnv.get.conf.getOrElse("spark.buffer.size", "65536").toInt override def zero(value: JList[Array[Byte]]): JList[Array[Byte]] = new JArrayList diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala index cecb8c228b..47528bcee8 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala @@ -31,7 +31,7 @@ import org.apache.spark.util.{MetadataCleaner, MetadataCleanerType, TimeStampedH private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolean, id: Long) extends Broadcast[T](id) with Logging with Serializable { - + def value = value_ def blockId = BroadcastBlockId(id) @@ -40,7 +40,7 @@ private[spark] class HttpBroadcast[T](@transient var value_ : T, isLocal: Boolea SparkEnv.get.blockManager.putSingle(blockId, value_, StorageLevel.MEMORY_AND_DISK, false) } - if (!isLocal) { + if (!isLocal) { HttpBroadcast.write(id, value_) } @@ -81,41 +81,48 @@ private object HttpBroadcast extends Logging { private var serverUri: String = null private var server: HttpServer = null + // TODO: This shouldn't be a global variable so that multiple SparkContexts can coexist private val files = new TimeStampedHashSet[String] - private val cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup) + private var cleaner: MetadataCleaner = null - private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5,TimeUnit.MINUTES).toInt + private val httpReadTimeout = TimeUnit.MILLISECONDS.convert(5, TimeUnit.MINUTES).toInt - private lazy val compressionCodec = CompressionCodec.createCodec() + private var compressionCodec: CompressionCodec = null def initialize(isDriver: Boolean, conf: SparkConf) { synchronized { if (!initialized) { - bufferSize = conf.getOrElse("spark.buffer.size", "65536").toInt - compress = conf.getOrElse("spark.broadcast.compress", "true").toBoolean + bufferSize = conf.getOrElse("spark.buffer.size", "65536").toInt + compress = conf.getOrElse("spark.broadcast.compress", "true").toBoolean if (isDriver) { - createServer() + createServer(conf) conf.set("spark.httpBroadcast.uri", serverUri) } serverUri = conf.get("spark.httpBroadcast.uri") + cleaner = new MetadataCleaner(MetadataCleanerType.HTTP_BROADCAST, cleanup, conf) + compressionCodec = CompressionCodec.createCodec(conf) initialized = true } } } - + def stop() { synchronized { if (server != null) { server.stop() server = null } + if (cleaner != null) { + cleaner.cancel() + cleaner = null + } + compressionCodec = null initialized = false - cleaner.cancel() } } - private def createServer() { - broadcastDir = Utils.createTempDir(Utils.getLocalDir) + private def createServer(conf: SparkConf) { + broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf)) server = new HttpServer(broadcastDir) server.start() serverUri = server.uri @@ -143,7 +150,7 @@ private object HttpBroadcast extends Logging { val in = { val httpConnection = new URL(url).openConnection() httpConnection.setReadTimeout(httpReadTimeout) - val inputStream = httpConnection.getInputStream() + val inputStream = httpConnection.getInputStream if (compress) { compressionCodec.compressedInputStream(inputStream) } else { diff --git a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala index 4a3801dc48..00ec3b971b 100644 --- a/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala +++ b/core/src/main/scala/org/apache/spark/broadcast/TorrentBroadcast.scala @@ -83,13 +83,13 @@ extends Broadcast[T](id) with Logging with Serializable { case None => val start = System.nanoTime logInfo("Started reading broadcast variable " + id) - + // Initialize @transient variables that will receive garbage values from the master. resetWorkerVariables() if (receiveBroadcast(id)) { value_ = TorrentBroadcast.unBlockifyObject[T](arrayOfBlocks, totalBytes, totalBlocks) - + // Store the merged copy in cache so that the next worker doesn't need to rebuild it. // This creates a tradeoff between memory usage and latency. // Storing copy doubles the memory footprint; not storing doubles deserialization cost. @@ -122,14 +122,14 @@ extends Broadcast[T](id) with Logging with Serializable { while (attemptId > 0 && totalBlocks == -1) { TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(metaId) match { - case Some(x) => + case Some(x) => val tInfo = x.asInstanceOf[TorrentInfo] totalBlocks = tInfo.totalBlocks totalBytes = tInfo.totalBytes arrayOfBlocks = new Array[TorrentBlock](totalBlocks) hasBlocks = 0 - - case None => + + case None => Thread.sleep(500) } } @@ -145,13 +145,13 @@ extends Broadcast[T](id) with Logging with Serializable { val pieceId = BroadcastHelperBlockId(broadcastId, "piece" + pid) TorrentBroadcast.synchronized { SparkEnv.get.blockManager.getSingle(pieceId) match { - case Some(x) => + case Some(x) => arrayOfBlocks(pid) = x.asInstanceOf[TorrentBlock] hasBlocks += 1 SparkEnv.get.blockManager.putSingle( pieceId, arrayOfBlocks(pid), StorageLevel.MEMORY_AND_DISK, true) - - case None => + + case None => throw new SparkException("Failed to get " + pieceId + " of " + broadcastId) } } @@ -175,13 +175,13 @@ extends Logging { } } } - + def stop() { initialized = false } - lazy val BLOCK_SIZE = conf.getOrElse("spark.broadcast.blockSize", "4096").toInt * 1024 - + lazy val BLOCK_SIZE = conf.getOrElse("spark.broadcast.blockSize", "4096").toInt * 1024 + def blockifyObject[T](obj: T): TorrentInfo = { val byteArray = Utils.serialize[T](obj) val bais = new ByteArrayInputStream(byteArray) @@ -210,7 +210,7 @@ extends Logging { } def unBlockifyObject[T](arrayOfBlocks: Array[TorrentBlock], - totalBytes: Int, + totalBytes: Int, totalBlocks: Int): T = { var retByteArray = new Array[Byte](totalBytes) for (i <- 0 until totalBlocks) { @@ -223,22 +223,22 @@ extends Logging { } private[spark] case class TorrentBlock( - blockID: Int, - byteArray: Array[Byte]) + blockID: Int, + byteArray: Array[Byte]) extends Serializable private[spark] case class TorrentInfo( @transient arrayOfBlocks : Array[TorrentBlock], - totalBlocks: Int, - totalBytes: Int) + totalBlocks: Int, + totalBytes: Int) extends Serializable { - - @transient var hasBlocks = 0 + + @transient var hasBlocks = 0 } private[spark] class TorrentBroadcastFactory extends BroadcastFactory { - + def initialize(isDriver: Boolean, conf: SparkConf) { TorrentBroadcast.initialize(isDriver, conf) } def newBroadcast[T](value_ : T, isLocal: Boolean, id: Long) = diff --git a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala index dda43dc018..19d393a0db 100644 --- a/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala +++ b/core/src/main/scala/org/apache/spark/deploy/ApplicationDescription.scala @@ -26,7 +26,7 @@ private[spark] class ApplicationDescription( val appUiUrl: String) extends Serializable { - val user = System.getProperty("user.name", "") + val user = System.getProperty("user.name", "") override def toString: String = "ApplicationDescription(" + name + ")" } diff --git a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala index 59d12a3e6f..ffc0cb0903 100644 --- a/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala +++ b/core/src/main/scala/org/apache/spark/deploy/LocalSparkCluster.scala @@ -22,7 +22,7 @@ import akka.actor.ActorSystem import org.apache.spark.deploy.worker.Worker import org.apache.spark.deploy.master.Master import org.apache.spark.util.Utils -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import scala.collection.mutable.ArrayBuffer @@ -43,7 +43,8 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I logInfo("Starting a local Spark cluster with " + numWorkers + " workers.") /* Start the Master */ - val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0) + val conf = new SparkConf(false) + val (masterSystem, masterPort, _) = Master.startSystemAndActor(localHostname, 0, 0, conf) masterActorSystems += masterSystem val masterUrl = "spark://" + localHostname + ":" + masterPort val masters = Array(masterUrl) @@ -55,7 +56,7 @@ class LocalSparkCluster(numWorkers: Int, coresPerWorker: Int, memoryPerWorker: I workerActorSystems += workerSystem } - return masters + masters } def stop() { diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala index 1c979ac3e0..4f402c1121 100644 --- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala +++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala @@ -34,10 +34,10 @@ class SparkHadoopUtil { UserGroupInformation.setConfiguration(conf) def runAsUser(user: String)(func: () => Unit) { - // if we are already running as the user intended there is no reason to do the doAs. It + // if we are already running as the user intended there is no reason to do the doAs. It // will actually break secure HDFS access as it doesn't fill in the credentials. Also if - // the user is UNKNOWN then we shouldn't be creating a remote unknown user - // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only + // the user is UNKNOWN then we shouldn't be creating a remote unknown user + // (this is actually the path spark on yarn takes) since SPARK_USER is initialized only // in SparkContext. val currentUser = Option(System.getProperty("user.name")). getOrElse(SparkContext.SPARK_UNKNOWN_USER) @@ -67,12 +67,14 @@ class SparkHadoopUtil { } object SparkHadoopUtil { - import SparkContext.{globalConf => conf} + private val hadoop = { - val yarnMode = java.lang.Boolean.valueOf(conf.getOrElse("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE"))) + val yarnMode = java.lang.Boolean.valueOf(System.getenv("SPARK_YARN_MODE")) if (yarnMode) { try { - Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil").newInstance.asInstanceOf[SparkHadoopUtil] + Class.forName("org.apache.spark.deploy.yarn.YarnSparkHadoopUtil") + .newInstance() + .asInstanceOf[SparkHadoopUtil] } catch { case th: Throwable => throw new SparkException("Unable to load YARN support", th) } diff --git a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala index 426cf524ae..ef649fd80c 100644 --- a/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala +++ b/core/src/main/scala/org/apache/spark/deploy/client/TestClient.scala @@ -18,7 +18,7 @@ package org.apache.spark.deploy.client import org.apache.spark.util.{Utils, AkkaUtils} -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.{SparkConf, SparkContext, Logging} import org.apache.spark.deploy.{Command, ApplicationDescription} private[spark] object TestClient { @@ -46,11 +46,12 @@ private[spark] object TestClient { def main(args: Array[String]) { val url = args(0) val (actorSystem, port) = AkkaUtils.createActorSystem("spark", Utils.localIpAddress, 0, - conf = SparkContext.globalConf) + conf = new SparkConf) val desc = new ApplicationDescription( - "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), "dummy-spark-home", "ignored") + "TestClient", 1, 512, Command("spark.deploy.client.TestExecutor", Seq(), Map()), + "dummy-spark-home", "ignored") val listener = new TestListener - val client = new Client(actorSystem, Array(url), desc, listener, SparkContext.globalConf) + val client = new Client(actorSystem, Array(url), desc, listener, new SparkConf) client.start() actorSystem.awaitTermination() } diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala index 2c162c4fa2..9c89e36b14 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala @@ -29,7 +29,7 @@ import akka.pattern.ask import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent} import akka.serialization.SerializationExtension -import org.apache.spark.{SparkContext, Logging, SparkException} +import org.apache.spark.{SparkConf, SparkContext, Logging, SparkException} import org.apache.spark.deploy.{ApplicationDescription, ExecutorState} import org.apache.spark.deploy.DeployMessages._ import org.apache.spark.deploy.master.MasterMessages._ @@ -38,14 +38,16 @@ import org.apache.spark.metrics.MetricsSystem import org.apache.spark.util.{AkkaUtils, Utils} private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Actor with Logging { - import context.dispatcher - val conf = SparkContext.globalConf + import context.dispatcher // to use Akka's scheduler.schedule() + + val conf = new SparkConf + val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For application IDs - val WORKER_TIMEOUT = conf.getOrElse("spark.worker.timeout", "60").toLong * 1000 - val RETAINED_APPLICATIONS = conf.getOrElse("spark.deploy.retainedApplications", "200").toInt - val REAPER_ITERATIONS = conf.getOrElse("spark.dead.worker.persistence", "15").toInt - val RECOVERY_DIR = conf.getOrElse("spark.deploy.recoveryDirectory", "") - val RECOVERY_MODE = conf.getOrElse("spark.deploy.recoveryMode", "NONE") + val WORKER_TIMEOUT = conf.getOrElse("spark.worker.timeout", "60").toLong * 1000 + val RETAINED_APPLICATIONS = conf.getOrElse("spark.deploy.retainedApplications", "200").toInt + val REAPER_ITERATIONS = conf.getOrElse("spark.dead.worker.persistence", "15").toInt + val RECOVERY_DIR = conf.getOrElse("spark.deploy.recoveryDirectory", "") + val RECOVERY_MODE = conf.getOrElse("spark.deploy.recoveryMode", "NONE") var nextAppNumber = 0 val workers = new HashSet[WorkerInfo] @@ -86,7 +88,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act // As a temporary workaround before better ways of configuring memory, we allow users to set // a flag that will perform round-robin scheduling across the nodes (spreading out each app // among all the nodes) instead of trying to consolidate each app onto a small # of nodes. - val spreadOutApps = conf.getOrElse("spark.deploy.spreadOut", "true").toBoolean + val spreadOutApps = conf.getOrElse("spark.deploy.spreadOut", "true").toBoolean override def preStart() { logInfo("Starting Spark master at " + masterUrl) @@ -495,7 +497,7 @@ private[spark] class Master(host: String, port: Int, webUiPort: Int) extends Act removeWorker(worker) } else { if (worker.lastHeartbeat < currentTime - ((REAPER_ITERATIONS + 1) * WORKER_TIMEOUT)) - workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it + workers -= worker // we've seen this DEAD worker in the UI, etc. for long enough; cull it } } } @@ -507,8 +509,9 @@ private[spark] object Master { val sparkUrlRegex = "spark://([^:]+):([0-9]+)".r def main(argStrings: Array[String]) { - val args = new MasterArguments(argStrings, SparkContext.globalConf) - val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort) + val conf = new SparkConf + val args = new MasterArguments(argStrings, conf) + val (actorSystem, _, _) = startSystemAndActor(args.host, args.port, args.webUiPort, conf) actorSystem.awaitTermination() } @@ -522,11 +525,12 @@ private[spark] object Master { } } - def startSystemAndActor(host: String, port: Int, webUiPort: Int): (ActorSystem, Int, Int) = { - val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = SparkContext.globalConf) + def startSystemAndActor(host: String, port: Int, webUiPort: Int, conf: SparkConf) + : (ActorSystem, Int, Int) = + { + val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, conf = conf) val actor = actorSystem.actorOf(Props(classOf[Master], host, boundPort, webUiPort), actorName) - val timeout = AkkaUtils.askTimeout(SparkContext.globalConf) + val timeout = AkkaUtils.askTimeout(conf) val respFuture = actor.ask(RequestWebUIPort)(timeout) val resp = Await.result(respFuture, timeout).asInstanceOf[WebUIPortResponse] (actorSystem, boundPort, resp.webUIBoundPort) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala index 79d95b1a83..60c7a7c2d6 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/SparkZooKeeperSession.scala @@ -37,7 +37,7 @@ import org.apache.spark.{SparkConf, Logging} */ private[spark] class SparkZooKeeperSession(zkWatcher: SparkZooKeeperWatcher, conf: SparkConf) extends Logging { - val ZK_URL = conf.getOrElse("spark.deploy.zookeeper.url", "") + val ZK_URL = conf.getOrElse("spark.deploy.zookeeper.url", "") val ZK_ACL = ZooDefs.Ids.OPEN_ACL_UNSAFE val ZK_TIMEOUT_MILLIS = 30000 diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala index df5bb368a2..a61597bbdf 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperLeaderElectionAgent.scala @@ -28,7 +28,7 @@ private[spark] class ZooKeeperLeaderElectionAgent(val masterActor: ActorRef, masterUrl: String, conf: SparkConf) extends LeaderElectionAgent with SparkZooKeeperWatcher with Logging { - val WORKING_DIR = conf.getOrElse("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" + val WORKING_DIR = conf.getOrElse("spark.deploy.zookeeper.dir", "/spark") + "/leader_election" private val watcher = new ZooKeeperWatcher() private val zk = new SparkZooKeeperSession(this, conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala index c55b720422..245a558a59 100644 --- a/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala +++ b/core/src/main/scala/org/apache/spark/deploy/master/ZooKeeperPersistenceEngine.scala @@ -27,7 +27,7 @@ class ZooKeeperPersistenceEngine(serialization: Serialization, conf: SparkConf) with SparkZooKeeperWatcher with Logging { - val WORKING_DIR = conf.getOrElse("spark.deploy.zookeeper.dir", "/spark") + "/master_status" + val WORKING_DIR = conf.getOrElse("spark.deploy.zookeeper.dir", "/spark") + "/master_status" val zk = new SparkZooKeeperSession(this, conf) diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala index 75a6e75c78..f844fcbbfc 100644 --- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala +++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala @@ -55,7 +55,7 @@ private[spark] class Worker( val DATE_FORMAT = new SimpleDateFormat("yyyyMMddHHmmss") // For worker and executor IDs // Send a heartbeat every (heartbeat timeout) / 4 milliseconds - val HEARTBEAT_MILLIS = conf.getOrElse("spark.worker.timeout", "60").toLong * 1000 / 4 + val HEARTBEAT_MILLIS = conf.getOrElse("spark.worker.timeout", "60").toLong * 1000 / 4 val REGISTRATION_TIMEOUT = 20.seconds val REGISTRATION_RETRIES = 3 @@ -267,7 +267,7 @@ private[spark] class Worker( } private[spark] object Worker { - import org.apache.spark.SparkContext.globalConf + def main(argStrings: Array[String]) { val args = new WorkerArguments(argStrings) val (actorSystem, _) = startSystemAndActor(args.host, args.port, args.webUiPort, args.cores, @@ -276,14 +276,16 @@ private[spark] object Worker { } def startSystemAndActor(host: String, port: Int, webUiPort: Int, cores: Int, memory: Int, - masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None) - : (ActorSystem, Int) = { + masterUrls: Array[String], workDir: String, workerNumber: Option[Int] = None) + : (ActorSystem, Int) = + { // The LocalSparkCluster runs multiple local sparkWorkerX actor systems + val conf = new SparkConf val systemName = "sparkWorker" + workerNumber.map(_.toString).getOrElse("") val (actorSystem, boundPort) = AkkaUtils.createActorSystem(systemName, host, port, - conf = globalConf) + conf = conf) actorSystem.actorOf(Props(classOf[Worker], host, boundPort, webUiPort, cores, memory, - masterUrls, workDir, globalConf), name = "Worker") + masterUrls, workDir, conf), name = "Worker") (actorSystem, boundPort) } diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala index c8319f6f6e..53a2b94a52 100644 --- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala +++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala @@ -22,7 +22,7 @@ import java.nio.ByteBuffer import akka.actor._ import akka.remote._ -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.{SparkConf, SparkContext, Logging} import org.apache.spark.TaskState.TaskState import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._ import org.apache.spark.util.{Utils, AkkaUtils} @@ -98,7 +98,7 @@ private[spark] object CoarseGrainedExecutorBackend { // Create a new ActorSystem to run the backend, because we can't create a SparkEnv / Executor // before getting started with all our system properties, etc val (actorSystem, boundPort) = AkkaUtils.createActorSystem("sparkExecutor", hostname, 0, - indestructible = true, conf = SparkContext.globalConf) + indestructible = true, conf = new SparkConf) // set it val sparkHostPort = hostname + ":" + boundPort // conf.set("spark.hostPort", sparkHostPort) diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala index 70fc30e993..a6eabc462b 100644 --- a/core/src/main/scala/org/apache/spark/executor/Executor.scala +++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala @@ -57,17 +57,18 @@ private[spark] class Executor( // Make sure the local hostname we report matches the cluster scheduler's name for this host Utils.setCustomHostname(slaveHostname) + + // Set spark.* properties from executor arg val conf = new SparkConf(false) - // Set spark.* system properties from executor arg - for ((key, value) <- properties) { - conf.set(key, value) - } + conf.setAll(properties) // If we are in yarn mode, systems can have different disk layouts so we must set it // to what Yarn on this system said was available. This will be used later when SparkEnv // created. - if (java.lang.Boolean.valueOf(System.getenv("SPARK_YARN_MODE"))) { - conf.set("spark.local.dir", getYarnLocalDirs()) + if (java.lang.Boolean.valueOf( + System.getProperty("SPARK_YARN_MODE", System.getenv("SPARK_YARN_MODE")))) + { + conf.set("spark.local.dir", getYarnLocalDirs()) } // Create our ClassLoader and set it on this thread @@ -331,12 +332,12 @@ private[spark] class Executor( // Fetch missing dependencies for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf) currentFiles(name) = timestamp } for ((name, timestamp) <- newJars if currentJars.getOrElse(name, -1L) < timestamp) { logInfo("Fetching " + name + " with timestamp " + timestamp) - Utils.fetchFile(name, new File(SparkFiles.getRootDirectory)) + Utils.fetchFile(name, new File(SparkFiles.getRootDirectory), conf) currentJars(name) = timestamp // Add it to our class loader val localName = name.split("/").last diff --git a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala index 8ef5019b6c..20402686a8 100644 --- a/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala +++ b/core/src/main/scala/org/apache/spark/io/CompressionCodec.scala @@ -22,7 +22,7 @@ import java.io.{InputStream, OutputStream} import com.ning.compress.lzf.{LZFInputStream, LZFOutputStream} import org.xerial.snappy.{SnappyInputStream, SnappyOutputStream} -import org.apache.spark.SparkConf +import org.apache.spark.{SparkEnv, SparkConf} /** @@ -38,16 +38,15 @@ trait CompressionCodec { private[spark] object CompressionCodec { - import org.apache.spark.SparkContext.globalConf - def createCodec(): CompressionCodec = { - createCodec(System.getProperty( + def createCodec(conf: SparkConf): CompressionCodec = { + createCodec(conf, conf.getOrElse( "spark.io.compression.codec", classOf[LZFCompressionCodec].getName)) } - def createCodec(codecName: String): CompressionCodec = { + def createCodec(conf: SparkConf, codecName: String): CompressionCodec = { val ctor = Class.forName(codecName, true, Thread.currentThread.getContextClassLoader) .getConstructor(classOf[SparkConf]) - ctor.newInstance(globalConf).asInstanceOf[CompressionCodec] + ctor.newInstance(conf).asInstanceOf[CompressionCodec] } } @@ -72,7 +71,7 @@ class LZFCompressionCodec(conf: SparkConf) extends CompressionCodec { class SnappyCompressionCodec(conf: SparkConf) extends CompressionCodec { override def compressedOutputStream(s: OutputStream): OutputStream = { - val blockSize = conf.getOrElse("spark.io.compression.snappy.block.size", "32768").toInt + val blockSize = conf.getOrElse("spark.io.compression.snappy.block.size", "32768").toInt new SnappyOutputStream(s, blockSize) } diff --git a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala index 3e902f8ac5..697096fa76 100644 --- a/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala +++ b/core/src/main/scala/org/apache/spark/network/ConnectionManager.scala @@ -593,10 +593,8 @@ private[spark] class ConnectionManager(port: Int, conf: SparkConf) extends Loggi private[spark] object ConnectionManager { - import SparkContext.globalConf - def main(args: Array[String]) { - val manager = new ConnectionManager(9999, globalConf) + val manager = new ConnectionManager(9999, new SparkConf) manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None diff --git a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala index 4ca3cd390b..1c9d6030d6 100644 --- a/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala +++ b/core/src/main/scala/org/apache/spark/network/ReceiverTest.scala @@ -19,19 +19,19 @@ package org.apache.spark.network import java.nio.ByteBuffer import java.net.InetAddress +import org.apache.spark.SparkConf private[spark] object ReceiverTest { - import org.apache.spark.SparkContext.globalConf def main(args: Array[String]) { - val manager = new ConnectionManager(9999, globalConf) + val manager = new ConnectionManager(9999, new SparkConf) println("Started connection manager with id = " + manager.id) - - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { /*println("Received [" + msg + "] from [" + id + "] at " + System.currentTimeMillis)*/ - val buffer = ByteBuffer.wrap("response".getBytes()) + val buffer = ByteBuffer.wrap("response".getBytes) Some(Message.createBufferMessage(buffer, msg.id)) }) - Thread.currentThread.join() + Thread.currentThread.join() } } diff --git a/core/src/main/scala/org/apache/spark/network/SenderTest.scala b/core/src/main/scala/org/apache/spark/network/SenderTest.scala index 11c21fc1d5..dcbd183c88 100644 --- a/core/src/main/scala/org/apache/spark/network/SenderTest.scala +++ b/core/src/main/scala/org/apache/spark/network/SenderTest.scala @@ -19,29 +19,29 @@ package org.apache.spark.network import java.nio.ByteBuffer import java.net.InetAddress +import org.apache.spark.SparkConf private[spark] object SenderTest { - import org.apache.spark.SparkContext.globalConf def main(args: Array[String]) { - + if (args.length < 2) { println("Usage: SenderTest ") System.exit(1) } - + val targetHost = args(0) val targetPort = args(1).toInt val targetConnectionManagerId = new ConnectionManagerId(targetHost, targetPort) - val manager = new ConnectionManager(0, globalConf) + val manager = new ConnectionManager(0, new SparkConf) println("Started connection manager with id = " + manager.id) - manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { + manager.onReceiveMessage((msg: Message, id: ConnectionManagerId) => { println("Received [" + msg + "] from [" + id + "]") None }) - - val size = 100 * 1024 * 1024 + + val size = 100 * 1024 * 1024 val buffer = ByteBuffer.allocate(size).put(Array.tabulate[Byte](size)(x => x.toByte)) buffer.flip @@ -50,7 +50,7 @@ private[spark] object SenderTest { val count = 100 (0 until count).foreach(i => { val dataMessage = Message.createBufferMessage(buffer.duplicate) - val startTime = System.currentTimeMillis + val startTime = System.currentTimeMillis /*println("Started timer at " + startTime)*/ val responseStr = manager.sendMessageReliablySync(targetConnectionManagerId, dataMessage) match { case Some(response) => diff --git a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala index 81b3104afd..db28ddf9ac 100644 --- a/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala +++ b/core/src/main/scala/org/apache/spark/network/netty/ShuffleCopier.scala @@ -36,7 +36,7 @@ private[spark] class ShuffleCopier(conf: SparkConf) extends Logging { resultCollectCallback: (BlockId, Long, ByteBuf) => Unit) { val handler = new ShuffleCopier.ShuffleClientHandler(resultCollectCallback) - val connectTimeout = conf.getOrElse("spark.shuffle.netty.connect.timeout", "60000").toInt + val connectTimeout = conf.getOrElse("spark.shuffle.netty.connect.timeout", "60000").toInt val fc = new FileClient(handler, connectTimeout) try { @@ -104,10 +104,10 @@ private[spark] object ShuffleCopier extends Logging { val threads = if (args.length > 3) args(3).toInt else 10 val copiers = Executors.newFixedThreadPool(80) - val tasks = (for (i <- Range(0, threads)) yield { + val tasks = (for (i <- Range(0, threads)) yield { Executors.callable(new Runnable() { def run() { - val copier = new ShuffleCopier(SparkContext.globalConf) + val copier = new ShuffleCopier(new SparkConf) copier.getBlock(host, port, blockId, echoResultCollectCallBack) } }) diff --git a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala index 9fbe002748..2897c4b841 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CheckpointRDD.scala @@ -74,9 +74,6 @@ class CheckpointRDD[T: ClassTag](sc: SparkContext, val checkpointPath: String) } private[spark] object CheckpointRDD extends Logging { - - import SparkContext.{globalConf => conf} - def splitIdToFile(splitId: Int): String = { "part-%05d".format(splitId) } @@ -94,7 +91,7 @@ private[spark] object CheckpointRDD extends Logging { throw new IOException("Checkpoint failed: temporary path " + tempOutputPath + " already exists") } - val bufferSize = conf.getOrElse("spark.buffer.size", "65536").toInt + val bufferSize = env.conf.getOrElse("spark.buffer.size", "65536").toInt val fileOutputStream = if (blockSize < 0) { fs.create(tempOutputPath, false, bufferSize) @@ -124,7 +121,7 @@ private[spark] object CheckpointRDD extends Logging { def readFromFile[T](path: Path, context: TaskContext): Iterator[T] = { val env = SparkEnv.get val fs = path.getFileSystem(SparkHadoopUtil.get.newConfiguration()) - val bufferSize = conf.getOrElse("spark.buffer.size", "65536").toInt + val bufferSize = env.conf.getOrElse("spark.buffer.size", "65536").toInt val fileInputStream = fs.open(path, bufferSize) val serializer = env.serializer.newInstance() val deserializeStream = serializer.deserializeStream(fileInputStream) diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala index 911a002884..4ba4696fef 100644 --- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala @@ -114,7 +114,7 @@ class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: map.changeValue(k, update) } - val ser = SparkEnv.get.serializerManager.get(serializerClass) + val ser = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf) for ((dep, depNum) <- split.deps.zipWithIndex) dep match { case NarrowCoGroupSplitDep(rdd, _, itsSplit) => { // Read them from the parent diff --git a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala index 3682c84598..0ccb309d0d 100644 --- a/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/ShuffledRDD.scala @@ -59,7 +59,7 @@ class ShuffledRDD[K, V, P <: Product2[K, V] : ClassTag]( override def compute(split: Partition, context: TaskContext): Iterator[P] = { val shuffledId = dependencies.head.asInstanceOf[ShuffleDependency[K, V]].shuffleId SparkEnv.get.shuffleFetcher.fetch[P](shuffledId, split.index, context, - SparkEnv.get.serializerManager.get(serializerClass)) + SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf)) } override def clearDependencies() { diff --git a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala index aab30b1bb4..4f90c7d3d6 100644 --- a/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala +++ b/core/src/main/scala/org/apache/spark/rdd/SubtractedRDD.scala @@ -93,7 +93,7 @@ private[spark] class SubtractedRDD[K: ClassTag, V: ClassTag, W: ClassTag]( override def compute(p: Partition, context: TaskContext): Iterator[(K, V)] = { val partition = p.asInstanceOf[CoGroupPartition] - val serializer = SparkEnv.get.serializerManager.get(serializerClass) + val serializer = SparkEnv.get.serializerManager.get(serializerClass, SparkEnv.get.conf) val map = new JHashMap[K, ArrayBuffer[V]] def getSeq(k: K): ArrayBuffer[V] = { val seq = map.get(k) diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala index 963d15b76d..77aa24e6b6 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala @@ -158,7 +158,8 @@ class DAGScheduler( val activeJobs = new HashSet[ActiveJob] val resultStageToJob = new HashMap[Stage, ActiveJob] - val metadataCleaner = new MetadataCleaner(MetadataCleanerType.DAG_SCHEDULER, this.cleanup) + val metadataCleaner = new MetadataCleaner( + MetadataCleanerType.DAG_SCHEDULER, this.cleanup, env.conf) /** * Starts the event processing actor. The actor has two responsibilities: @@ -529,7 +530,7 @@ class DAGScheduler( case JobSubmitted(jobId, rdd, func, partitions, allowLocal, callSite, listener, properties) => var finalStage: Stage = null try { - // New stage creation at times and if its not protected, the scheduler thread is killed. + // New stage creation at times and if its not protected, the scheduler thread is killed. // e.g. it can fail when jobs are run on HadoopRDD whose underlying hdfs files have been deleted finalStage = newStage(rdd, partitions.size, None, jobId, Some(callSite)) } catch { diff --git a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala index 1791ee660d..90eb8a747f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/InputFormatInfo.scala @@ -32,7 +32,7 @@ import scala.collection.JavaConversions._ /** * Parses and holds information about inputFormat (and files) specified as a parameter. */ -class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], +class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Class[_], val path: String) extends Logging { var mapreduceInputFormat: Boolean = false @@ -40,7 +40,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl validate() - override def toString(): String = { + override def toString: String = { "InputFormatInfo " + super.toString + " .. inputFormatClazz " + inputFormatClazz + ", path : " + path } @@ -125,7 +125,7 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl } private def findPreferredLocations(): Set[SplitInfo] = { - logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + + logDebug("mapreduceInputFormat : " + mapreduceInputFormat + ", mapredInputFormat : " + mapredInputFormat + ", inputFormatClazz : " + inputFormatClazz) if (mapreduceInputFormat) { return prefLocsFromMapreduceInputFormat() @@ -143,14 +143,14 @@ class InputFormatInfo(val configuration: Configuration, val inputFormatClazz: Cl object InputFormatInfo { /** Computes the preferred locations based on input(s) and returned a location to block map. - Typical use of this method for allocation would follow some algo like this - (which is what we currently do in YARN branch) : + Typical use of this method for allocation would follow some algo like this: + a) For each host, count number of splits hosted on that host. b) Decrement the currently allocated containers on that host. c) Compute rack info for each host and update rack -> count map based on (b). d) Allocate nodes based on (c) - e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node - (even if data locality on that is very high) : this is to prevent fragility of job if a single + e) On the allocation result, ensure that we dont allocate "too many" jobs on a single node + (even if data locality on that is very high) : this is to prevent fragility of job if a single (or small set of) hosts go down. go to (a) until required nodes are allocated. diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala index 3f55cd5642..60927831a1 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala @@ -41,7 +41,7 @@ import org.apache.spark.storage.StorageLevel class JobLogger(val user: String, val logDirName: String) extends SparkListener with Logging { - def this() = this(System.getProperty("user.name", ""), + def this() = this(System.getProperty("user.name", ""), String.valueOf(System.currentTimeMillis())) private val logDir = diff --git a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala index 310ec62ca8..28f3ba53b8 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ResultTask.scala @@ -32,7 +32,9 @@ private[spark] object ResultTask { // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner(MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues) + // TODO: This object shouldn't have global variables + val metadataCleaner = new MetadataCleaner( + MetadataCleanerType.RESULT_TASK, serializedInfoCache.clearOldValues, new SparkConf) def serializeInfo(stageId: Int, rdd: RDD[_], func: (TaskContext, Iterator[_]) => _): Array[Byte] = { synchronized { diff --git a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala index 9002d33cda..3cf995ea74 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/SchedulableBuilder.scala @@ -52,7 +52,7 @@ private[spark] class FIFOSchedulableBuilder(val rootPool: Pool) private[spark] class FairSchedulableBuilder(val rootPool: Pool, conf: SparkConf) extends SchedulableBuilder with Logging { - val schedulerAllocFile = Option(conf.get("spark.scheduler.allocation.file")) + val schedulerAllocFile = conf.getOption("spark.scheduler.allocation.file") val DEFAULT_SCHEDULER_FILE = "fairscheduler.xml" val FAIR_SCHEDULER_PROPERTIES = "spark.scheduler.pool" val DEFAULT_POOL_NAME = "default" diff --git a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala index 0f2deb4bcb..a37ead5632 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/ShuffleMapTask.scala @@ -37,7 +37,9 @@ private[spark] object ShuffleMapTask { // expensive on the master node if it needs to launch thousands of tasks. val serializedInfoCache = new TimeStampedHashMap[Int, Array[Byte]] - val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues) + // TODO: This object shouldn't have global variables + val metadataCleaner = new MetadataCleaner( + MetadataCleanerType.SHUFFLE_MAP_TASK, serializedInfoCache.clearOldValues, new SparkConf) def serializeInfo(stageId: Int, rdd: RDD[_], dep: ShuffleDependency[_,_]): Array[Byte] = { synchronized { @@ -152,7 +154,7 @@ private[spark] class ShuffleMapTask( try { // Obtain all the block writers for shuffle blocks. - val ser = SparkEnv.get.serializerManager.get(dep.serializerClass) + val ser = SparkEnv.get.serializerManager.get(dep.serializerClass, SparkEnv.get.conf) shuffle = shuffleBlockManager.forMapTask(dep.shuffleId, partitionId, numOutputSplits, ser) // Write the map output to its associated buckets. diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala index 7e231ec44c..2707740d44 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterScheduler.scala @@ -51,10 +51,10 @@ private[spark] class ClusterScheduler(val sc: SparkContext) { val conf = sc.conf // How often to check for speculative tasks - val SPECULATION_INTERVAL = conf.getOrElse("spark.speculation.interval", "100").toLong + val SPECULATION_INTERVAL = conf.getOrElse("spark.speculation.interval", "100").toLong // Threshold above which we warn user initial TaskSet may be starved - val STARVATION_TIMEOUT = conf.getOrElse("spark.starvation.timeout", "15000").toLong + val STARVATION_TIMEOUT = conf.getOrElse("spark.starvation.timeout", "15000").toLong // ClusterTaskSetManagers are not thread safe, so any access to one should be synchronized // on this class. @@ -91,7 +91,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) var rootPool: Pool = null // default scheduler is FIFO val schedulingMode: SchedulingMode = SchedulingMode.withName( - conf.getOrElse("spark.scheduler.mode", "FIFO")) + conf.getOrElse("spark.scheduler.mode", "FIFO")) // This is a var so that we can reset it for testing purposes. private[spark] var taskResultGetter = new TaskResultGetter(sc.env, this) @@ -120,7 +120,7 @@ private[spark] class ClusterScheduler(val sc: SparkContext) override def start() { backend.start() - if (conf.getOrElse("spark.speculation", "false").toBoolean) { + if (conf.getOrElse("spark.speculation", "false").toBoolean) { logInfo("Starting speculative execution thread") import sc.env.actorSystem.dispatcher sc.env.actorSystem.scheduler.schedule(SPECULATION_INTERVAL milliseconds, diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala index 398b0cefbf..a46b16b92f 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManager.scala @@ -52,14 +52,14 @@ private[spark] class ClusterTaskSetManager( { val conf = sched.sc.conf // CPUs to request per task - val CPUS_PER_TASK = conf.getOrElse("spark.task.cpus", "1").toInt + val CPUS_PER_TASK = conf.getOrElse("spark.task.cpus", "1").toInt // Maximum times a task is allowed to fail before failing the job - val MAX_TASK_FAILURES = conf.getOrElse("spark.task.maxFailures", "4").toInt + val MAX_TASK_FAILURES = conf.getOrElse("spark.task.maxFailures", "4").toInt // Quantile of tasks at which to start speculation - val SPECULATION_QUANTILE = conf.getOrElse("spark.speculation.quantile", "0.75").toDouble - val SPECULATION_MULTIPLIER = conf.getOrElse("spark.speculation.multiplier", "1.5").toDouble + val SPECULATION_QUANTILE = conf.getOrElse("spark.speculation.quantile", "0.75").toDouble + val SPECULATION_MULTIPLIER = conf.getOrElse("spark.speculation.multiplier", "1.5").toDouble // Serializer for closures and tasks. val env = SparkEnv.get @@ -118,7 +118,7 @@ private[spark] class ClusterTaskSetManager( // How frequently to reprint duplicate exceptions in full, in milliseconds val EXCEPTION_PRINT_INTERVAL = - conf.getOrElse("spark.logging.exceptionPrintInterval", "10000").toLong + conf.getOrElse("spark.logging.exceptionPrintInterval", "10000").toLong // Map of recent exceptions (identified by string representation and top stack frame) to // duplicate count (how many times the same exception has appeared) and time the full exception @@ -678,7 +678,7 @@ private[spark] class ClusterTaskSetManager( } private def getLocalityWait(level: TaskLocality.TaskLocality): Long = { - val defaultWait = conf.getOrElse("spark.locality.wait", "3000") + val defaultWait = conf.getOrElse("spark.locality.wait", "3000") level match { case TaskLocality.PROCESS_LOCAL => conf.getOrElse("spark.locality.wait.process", defaultWait).toLong diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala index 40555903ac..156b01b149 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala @@ -62,7 +62,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent]) // Periodically revive offers to allow delay scheduling to work - val reviveInterval = conf.getOrElse("spark.scheduler.revive.interval", "1000").toLong + val reviveInterval = conf.getOrElse("spark.scheduler.revive.interval", "1000").toLong import context.dispatcher context.system.scheduler.schedule(0.millis, reviveInterval.millis, self, ReviveOffers) } @@ -118,7 +118,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac removeExecutor(executorId, reason) sender ! true - case DisassociatedEvent(_, address, _) => + case DisassociatedEvent(_, address, _) => addressToExecutorId.get(address).foreach(removeExecutor(_, "remote Akka client disassociated")) } @@ -163,10 +163,7 @@ class CoarseGrainedSchedulerBackend(scheduler: ClusterScheduler, actorSystem: Ac override def start() { val properties = new ArrayBuffer[(String, String)] - val iterator = scheduler.sc.conf.getAllConfiguration - while (iterator.hasNext) { - val entry = iterator.next - val (key, value) = (entry.getKey.toString, entry.getValue.toString) + for ((key, value) <- scheduler.sc.conf.getAll) { if (key.startsWith("spark.") && !key.equals("spark.hostPort")) { properties += ((key, value)) } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala index 5367218faa..65d3fc8187 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SchedulerBackend.scala @@ -31,7 +31,4 @@ private[spark] trait SchedulerBackend { def defaultParallelism(): Int def killTask(taskId: Long, executorId: String): Unit = throw new UnsupportedOperationException - - // Memory used by each executor (in megabytes) - protected val executorMemory: Int = SparkContext.executorMemoryRequested } diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala index d01329b2b3..d74f000ebb 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SimrSchedulerBackend.scala @@ -31,7 +31,7 @@ private[spark] class SimrSchedulerBackend( val tmpPath = new Path(driverFilePath + "_tmp") val filePath = new Path(driverFilePath) - val maxCores = conf.getOrElse("spark.simr.executor.cores", "1").toInt + val maxCores = conf.getOrElse("spark.simr.executor.cores", "1").toInt override def start() { super.start() diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala index d6b8ac2d57..de69e3260d 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/SparkDeploySchedulerBackend.scala @@ -49,7 +49,7 @@ private[spark] class SparkDeploySchedulerBackend( val command = Command( "org.apache.spark.executor.CoarseGrainedExecutorBackend", args, sc.executorEnvs) val sparkHome = sc.getSparkHome().getOrElse(null) - val appDesc = new ApplicationDescription(appName, maxCores, executorMemory, command, sparkHome, + val appDesc = new ApplicationDescription(appName, maxCores, sc.executorMemory, command, sparkHome, "http://" + sc.ui.appUIAddress) client = new Client(sc.env.actorSystem, masters, appDesc, this, conf) diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala index ff6cc37f1d..319c91b933 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/TaskResultGetter.scala @@ -32,7 +32,7 @@ import org.apache.spark.util.Utils private[spark] class TaskResultGetter(sparkEnv: SparkEnv, scheduler: ClusterScheduler) extends Logging { - private val THREADS = sparkEnv.conf.getOrElse("spark.resultGetter.threads", "4").toInt + private val THREADS = sparkEnv.conf.getOrElse("spark.resultGetter.threads", "4").toInt private val getTaskResultExecutor = Utils.newDaemonFixedThreadPool( THREADS, "Result resolver thread") diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala index 2a3b0e15f7..1695374152 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala @@ -76,7 +76,7 @@ private[spark] class CoarseMesosSchedulerBackend( "Spark home is not set; set it through the spark.home system " + "property, the SPARK_HOME environment variable or the SparkContext constructor")) - val extraCoresPerSlave = conf.getOrElse("spark.mesos.extra.cores", "0").toInt + val extraCoresPerSlave = conf.getOrElse("spark.mesos.extra.cores", "0").toInt var nextMesosTaskId = 0 @@ -176,7 +176,7 @@ private[spark] class CoarseMesosSchedulerBackend( val slaveId = offer.getSlaveId.toString val mem = getResource(offer.getResourcesList, "mem") val cpus = getResource(offer.getResourcesList, "cpus").toInt - if (totalCoresAcquired < maxCores && mem >= executorMemory && cpus >= 1 && + if (totalCoresAcquired < maxCores && mem >= sc.executorMemory && cpus >= 1 && failuresBySlaveId.getOrElse(slaveId, 0) < MAX_SLAVE_FAILURES && !slaveIdsWithExecutors.contains(slaveId)) { // Launch an executor on the slave @@ -192,7 +192,7 @@ private[spark] class CoarseMesosSchedulerBackend( .setCommand(createCommand(offer, cpusToUse + extraCoresPerSlave)) .setName("Task " + taskId) .addResources(createResource("cpus", cpusToUse)) - .addResources(createResource("mem", executorMemory)) + .addResources(createResource("mem", sc.executorMemory)) .build() d.launchTasks(offer.getId, Collections.singletonList(task), filters) } else { diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala index 9bb92b4f01..8dfd4d5fb3 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala @@ -114,7 +114,7 @@ private[spark] class MesosSchedulerBackend( val memory = Resource.newBuilder() .setName("mem") .setType(Value.Type.SCALAR) - .setScalar(Value.Scalar.newBuilder().setValue(executorMemory).build()) + .setScalar(Value.Scalar.newBuilder().setValue(sc.executorMemory).build()) .build() ExecutorInfo.newBuilder() .setExecutorId(ExecutorID.newBuilder().setValue(execId).build()) @@ -199,7 +199,7 @@ private[spark] class MesosSchedulerBackend( def enoughMemory(o: Offer) = { val mem = getResource(o.getResourcesList, "mem") val slaveId = o.getSlaveId.getValue - mem >= executorMemory || slaveIdsWithExecutors.contains(slaveId) + mem >= sc.executorMemory || slaveIdsWithExecutors.contains(slaveId) } for ((offer, index) <- offers.zipWithIndex if enoughMemory(offer)) { @@ -341,5 +341,5 @@ private[spark] class MesosSchedulerBackend( } // TODO: query Mesos for number of cores - override def defaultParallelism() = sc.conf.getOrElse("spark.default.parallelism", "8").toInt + override def defaultParallelism() = sc.conf.getOrElse("spark.default.parallelism", "8").toInt } diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala index 6069c1db3a..8498cffd31 100644 --- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala +++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalScheduler.scala @@ -92,7 +92,7 @@ private[spark] class LocalScheduler(val threads: Int, val maxFailures: Int, val var schedulableBuilder: SchedulableBuilder = null var rootPool: Pool = null val schedulingMode: SchedulingMode = SchedulingMode.withName( - conf.getOrElse("spark.scheduler.mode", "FIFO")) + conf.getOrElse("spark.scheduler.mode", "FIFO")) val activeTaskSets = new HashMap[String, LocalTaskSetManager] val taskIdToTaskSetId = new HashMap[Long, String] val taskSetTaskIds = new HashMap[String, HashSet[Long]] diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala index 4de81617b1..5d3d43623d 100644 --- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala @@ -21,6 +21,7 @@ import java.io._ import java.nio.ByteBuffer import org.apache.spark.util.ByteBufferInputStream +import org.apache.spark.SparkConf private[spark] class JavaSerializationStream(out: OutputStream) extends SerializationStream { val objOut = new ObjectOutputStream(out) @@ -77,6 +78,6 @@ private[spark] class JavaSerializerInstance extends SerializerInstance { /** * A Spark serializer that uses Java's built-in serialization. */ -class JavaSerializer extends Serializer { +class JavaSerializer(conf: SparkConf) extends Serializer { def newInstance(): SerializerInstance = new JavaSerializerInstance } diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala index 17cec81038..2367f3f521 100644 --- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala +++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala @@ -25,20 +25,21 @@ import com.esotericsoftware.kryo.{KryoException, Kryo} import com.esotericsoftware.kryo.io.{Input => KryoInput, Output => KryoOutput} import com.twitter.chill.{EmptyScalaKryoInstantiator, AllScalaRegistrar} -import org.apache.spark.{SparkContext, SparkConf, SerializableWritable, Logging} +import org.apache.spark._ import org.apache.spark.broadcast.HttpBroadcast import org.apache.spark.scheduler.MapStatus import org.apache.spark.storage._ import scala.util.Try +import org.apache.spark.storage.PutBlock +import org.apache.spark.storage.GetBlock +import org.apache.spark.storage.GotBlock /** * A Spark serializer that uses the [[https://code.google.com/p/kryo/ Kryo serialization library]]. */ -class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging { - - private val conf = SparkContext.globalConf +class KryoSerializer(conf: SparkConf) extends org.apache.spark.serializer.Serializer with Logging { private val bufferSize = { - conf.getOrElse("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 + conf.getOrElse("spark.kryoserializer.buffer.mb", "2").toInt * 1024 * 1024 } def newKryoOutput() = new KryoOutput(bufferSize) @@ -50,7 +51,7 @@ class KryoSerializer extends org.apache.spark.serializer.Serializer with Logging // Allow disabling Kryo reference tracking if user knows their object graphs don't have loops. // Do this before we invoke the user registrator so the user registrator can override this. - kryo.setReferences(conf.getOrElse("spark.kryo.referenceTracking", "true").toBoolean) + kryo.setReferences(conf.getOrElse("spark.kryo.referenceTracking", "true").toBoolean) for (cls <- KryoSerializer.toRegister) kryo.register(cls) diff --git a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala index 2955986fec..22465272f3 100644 --- a/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala +++ b/core/src/main/scala/org/apache/spark/serializer/SerializerManager.scala @@ -18,6 +18,7 @@ package org.apache.spark.serializer import java.util.concurrent.ConcurrentHashMap +import org.apache.spark.SparkConf /** @@ -32,12 +33,12 @@ private[spark] class SerializerManager { def default = _default - def setDefault(clsName: String): Serializer = { - _default = get(clsName) + def setDefault(clsName: String, conf: SparkConf): Serializer = { + _default = get(clsName, conf) _default } - def get(clsName: String): Serializer = { + def get(clsName: String, conf: SparkConf): Serializer = { if (clsName == null) { default } else { @@ -51,8 +52,9 @@ private[spark] class SerializerManager { serializer = serializers.get(clsName) if (serializer == null) { val clsLoader = Thread.currentThread.getContextClassLoader - serializer = - Class.forName(clsName, true, clsLoader).newInstance().asInstanceOf[Serializer] + val cls = Class.forName(clsName, true, clsLoader) + val constructor = cls.getConstructor(classOf[SparkConf]) + serializer = constructor.newInstance(conf).asInstanceOf[Serializer] serializers.put(clsName, serializer) } serializer diff --git a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala index ee2ae471a9..3b25f68ca8 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockFetcherIterator.scala @@ -327,7 +327,7 @@ object BlockFetcherIterator { fetchRequestsSync.put(request) } - copiers = startCopiers(conf.getOrElse("spark.shuffle.copier.threads", "6").toInt) + copiers = startCopiers(conf.getOrElse("spark.shuffle.copier.threads", "6").toInt) logInfo("Started " + fetchRequestsSync.size + " remote gets in " + Utils.getUsedTimeMs(startTime)) diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala index ffd166e93a..16ee208617 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala @@ -58,8 +58,8 @@ private[spark] class BlockManager( // If we use Netty for shuffle, start a new Netty-based shuffle sender service. private val nettyPort: Int = { - val useNetty = conf.getOrElse("spark.shuffle.use.netty", "false").toBoolean - val nettyPortConfig = conf.getOrElse("spark.shuffle.sender.port", "0").toInt + val useNetty = conf.getOrElse("spark.shuffle.use.netty", "false").toBoolean + val nettyPortConfig = conf.getOrElse("spark.shuffle.sender.port", "0").toInt if (useNetty) diskBlockManager.startShuffleBlockSender(nettyPortConfig) else 0 } @@ -72,18 +72,18 @@ private[spark] class BlockManager( // Max megabytes of data to keep in flight per reducer (to avoid over-allocating memory // for receiving shuffle outputs) val maxBytesInFlight = - conf.getOrElse("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 + conf.getOrElse("spark.reducer.maxMbInFlight", "48").toLong * 1024 * 1024 // Whether to compress broadcast variables that are stored - val compressBroadcast = conf.getOrElse("spark.broadcast.compress", "true").toBoolean + val compressBroadcast = conf.getOrElse("spark.broadcast.compress", "true").toBoolean // Whether to compress shuffle output that are stored - val compressShuffle = conf.getOrElse("spark.shuffle.compress", "true").toBoolean + val compressShuffle = conf.getOrElse("spark.shuffle.compress", "true").toBoolean // Whether to compress RDD partitions that are stored serialized - val compressRdds = conf.getOrElse("spark.rdd.compress", "false").toBoolean + val compressRdds = conf.getOrElse("spark.rdd.compress", "false").toBoolean - val heartBeatFrequency = BlockManager.getHeartBeatFrequencyFromSystemProperties + val heartBeatFrequency = BlockManager.getHeartBeatFrequency(conf) - val hostPort = Utils.localHostPort() + val hostPort = Utils.localHostPort(conf) val slaveActor = actorSystem.actorOf(Props(new BlockManagerSlaveActor(this)), name = "BlockManagerActor" + BlockManager.ID_GENERATOR.next) @@ -101,8 +101,11 @@ private[spark] class BlockManager( var heartBeatTask: Cancellable = null - private val metadataCleaner = new MetadataCleaner(MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks) - private val broadcastCleaner = new MetadataCleaner(MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks) + private val metadataCleaner = new MetadataCleaner( + MetadataCleanerType.BLOCK_MANAGER, this.dropOldNonBroadcastBlocks, conf) + private val broadcastCleaner = new MetadataCleaner( + MetadataCleanerType.BROADCAST_VARS, this.dropOldBroadcastBlocks, conf) + initialize() // The compression codec to use. Note that the "lazy" val is necessary because we want to delay @@ -110,14 +113,14 @@ private[spark] class BlockManager( // program could be using a user-defined codec in a third party jar, which is loaded in // Executor.updateDependencies. When the BlockManager is initialized, user level jars hasn't been // loaded yet. - private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec() + private lazy val compressionCodec: CompressionCodec = CompressionCodec.createCodec(conf) /** * Construct a BlockManager with a memory limit set based on system properties. */ def this(execId: String, actorSystem: ActorSystem, master: BlockManagerMaster, serializer: Serializer, conf: SparkConf) = { - this(execId, actorSystem, master, serializer, BlockManager.getMaxMemoryFromSystemProperties, conf) + this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf), conf) } /** @@ -127,7 +130,7 @@ private[spark] class BlockManager( private def initialize() { master.registerBlockManager(blockManagerId, maxMemory, slaveActor) BlockManagerWorker.startBlockManagerWorker(this) - if (!BlockManager.getDisableHeartBeatsForTesting) { + if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { heartBeatTask = actorSystem.scheduler.schedule(0.seconds, heartBeatFrequency.milliseconds) { heartBeat() } @@ -440,7 +443,7 @@ private[spark] class BlockManager( : BlockFetcherIterator = { val iter = - if (conf.getOrElse("spark.shuffle.use.netty", "false").toBoolean) { + if (conf.getOrElse("spark.shuffle.use.netty", "false").toBoolean) { new BlockFetcherIterator.NettyBlockFetcherIterator(this, blocksByAddress, serializer) } else { new BlockFetcherIterator.BasicBlockFetcherIterator(this, blocksByAddress, serializer) @@ -466,7 +469,7 @@ private[spark] class BlockManager( def getDiskWriter(blockId: BlockId, file: File, serializer: Serializer, bufferSize: Int) : BlockObjectWriter = { val compressStream: OutputStream => OutputStream = wrapForCompression(blockId, _) - val syncWrites = conf.getOrElse("spark.shuffle.sync", "false").toBoolean + val syncWrites = conf.getOrElse("spark.shuffle.sync", "false").toBoolean new DiskBlockObjectWriter(blockId, file, serializer, bufferSize, compressStream, syncWrites) } @@ -858,19 +861,18 @@ private[spark] class BlockManager( private[spark] object BlockManager extends Logging { - import org.apache.spark.SparkContext.{globalConf => conf} val ID_GENERATOR = new IdGenerator - def getMaxMemoryFromSystemProperties: Long = { - val memoryFraction = conf.getOrElse("spark.storage.memoryFraction", "0.66").toDouble + def getMaxMemory(conf: SparkConf): Long = { + val memoryFraction = conf.getOrElse("spark.storage.memoryFraction", "0.66").toDouble (Runtime.getRuntime.maxMemory * memoryFraction).toLong } - def getHeartBeatFrequencyFromSystemProperties: Long = - conf.getOrElse("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 + def getHeartBeatFrequency(conf: SparkConf): Long = + conf.getOrElse("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong / 4 - def getDisableHeartBeatsForTesting: Boolean = - conf.getOrElse("spark.test.disableBlockManagerHeartBeat", "false").toBoolean + def getDisableHeartBeatsForTesting(conf: SparkConf): Boolean = + conf.getOrElse("spark.test.disableBlockManagerHeartBeat", "false").toBoolean /** * Attempt to clean up a ByteBuffer if it is memory-mapped. This uses an *unsafe* Sun API that diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala index fde7d63a68..8e4a88b20a 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala @@ -31,8 +31,8 @@ private[spark] class BlockManagerMaster(var driverActor : Either[ActorRef, ActorSelection], conf: SparkConf) extends Logging { - val AKKA_RETRY_ATTEMPTS: Int = conf.getOrElse("spark.akka.num.retries", "3").toInt - val AKKA_RETRY_INTERVAL_MS: Int = conf.getOrElse("spark.akka.retry.wait", "3000").toInt + val AKKA_RETRY_ATTEMPTS: Int = conf.getOrElse("spark.akka.num.retries", "3").toInt + val AKKA_RETRY_INTERVAL_MS: Int = conf.getOrElse("spark.akka.retry.wait", "3000").toInt val DRIVER_AKKA_ACTOR_NAME = "BlockManagerMaster" diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala index 05502e4451..73a1da2de6 100644 --- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala +++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala @@ -53,7 +53,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act initLogging() val slaveTimeout = conf.getOrElse("spark.storage.blockManagerSlaveTimeoutMs", - "" + (BlockManager.getHeartBeatFrequencyFromSystemProperties * 3)).toLong + "" + (BlockManager.getHeartBeatFrequency(conf) * 3)).toLong val checkTimeoutInterval = conf.getOrElse("spark.storage.blockManagerTimeoutIntervalMs", "60000").toLong @@ -61,7 +61,7 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf) extends Act var timeoutCheckingTask: Cancellable = null override def preStart() { - if (!BlockManager.getDisableHeartBeatsForTesting) { + if (!BlockManager.getDisableHeartBeatsForTesting(conf)) { import context.dispatcher timeoutCheckingTask = context.system.scheduler.schedule( 0.seconds, checkTimeoutInterval.milliseconds, self, ExpireDeadHosts) diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala index 8f528babd4..7697092e1b 100644 --- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala @@ -38,7 +38,7 @@ private[spark] class DiskBlockManager(shuffleManager: ShuffleBlockManager, rootD extends PathResolver with Logging { private val MAX_DIR_CREATION_ATTEMPTS: Int = 10 - private val subDirsPerLocalDir = shuffleManager.conf.getOrElse("spark.diskStore.subDirectories", "64").toInt + private val subDirsPerLocalDir = shuffleManager.conf.getOrElse("spark.diskStore.subDirectories", "64").toInt // Create one local directory for each path mentioned in spark.local.dir; then, inside this // directory, create multiple subdirectories that we will hash files into, in order to avoid diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala index 850d3178dd..f592df283a 100644 --- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala +++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockManager.scala @@ -62,12 +62,13 @@ private[spark] trait ShuffleWriterGroup { private[spark] class ShuffleBlockManager(blockManager: BlockManager) { def conf = blockManager.conf + // Turning off shuffle file consolidation causes all shuffle Blocks to get their own file. // TODO: Remove this once the shuffle file consolidation feature is stable. val consolidateShuffleFiles = - conf.getOrElse("spark.shuffle.consolidateFiles", "false").toBoolean + conf.getOrElse("spark.shuffle.consolidateFiles", "false").toBoolean - private val bufferSize = conf.getOrElse("spark.shuffle.file.buffer.kb", "100").toInt * 1024 + private val bufferSize = conf.getOrElse("spark.shuffle.file.buffer.kb", "100").toInt * 1024 /** * Contains all the state related to a particular shuffle. This includes a pool of unused @@ -82,8 +83,8 @@ class ShuffleBlockManager(blockManager: BlockManager) { type ShuffleId = Int private val shuffleStates = new TimeStampedHashMap[ShuffleId, ShuffleState] - private - val metadataCleaner = new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup) + private val metadataCleaner = + new MetadataCleaner(MetadataCleanerType.SHUFFLE_BLOCK_MANAGER, this.cleanup, conf) def forMapTask(shuffleId: Int, mapId: Int, numBuckets: Int, serializer: Serializer) = { new ShuffleWriterGroup { diff --git a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala index d52b3d8284..40734aab49 100644 --- a/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala +++ b/core/src/main/scala/org/apache/spark/storage/StoragePerfTester.scala @@ -56,7 +56,7 @@ object StoragePerfTester { def writeOutputBytes(mapId: Int, total: AtomicLong) = { val shuffle = blockManager.shuffleBlockManager.forMapTask(1, mapId, numOutputSplits, - new KryoSerializer()) + new KryoSerializer(sc.conf)) val writers = shuffle.writers for (i <- 1 to recordsPerMap) { writers(i % numOutputSplits).write(writeData) diff --git a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala index b3b3893393..dca98c6c05 100644 --- a/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala +++ b/core/src/main/scala/org/apache/spark/storage/ThreadingTest.scala @@ -22,7 +22,7 @@ import akka.actor._ import java.util.concurrent.ArrayBlockingQueue import util.Random import org.apache.spark.serializer.KryoSerializer -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} /** * This class tests the BlockManager and MemoryStore for thread safety and @@ -92,8 +92,8 @@ private[spark] object ThreadingTest { def main(args: Array[String]) { System.setProperty("spark.kryoserializer.buffer.mb", "1") val actorSystem = ActorSystem("test") - val conf = SparkContext.globalConf - val serializer = new KryoSerializer + val conf = new SparkConf() + val serializer = new KryoSerializer(conf) val blockManagerMaster = new BlockManagerMaster( Left(actorSystem.actorOf(Props(new BlockManagerMasterActor(true, conf)))), conf) val blockManager = new BlockManager( diff --git a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala index 14751e8e8e..58d47a201d 100644 --- a/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala +++ b/core/src/main/scala/org/apache/spark/ui/UIWorkloadGenerator.scala @@ -19,7 +19,7 @@ package org.apache.spark.ui import scala.util.Random -import org.apache.spark.SparkContext +import org.apache.spark.{SparkConf, SparkContext} import org.apache.spark.SparkContext._ import org.apache.spark.scheduler.SchedulingMode @@ -31,7 +31,6 @@ import org.apache.spark.scheduler.SchedulingMode */ private[spark] object UIWorkloadGenerator { - import SparkContext.{globalConf => conf} val NUM_PARTITIONS = 100 val INTER_JOB_WAIT_MS = 5000 @@ -40,14 +39,14 @@ private[spark] object UIWorkloadGenerator { println("usage: ./spark-class org.apache.spark.ui.UIWorkloadGenerator [master] [FIFO|FAIR]") System.exit(1) } - val master = args(0) - val schedulingMode = SchedulingMode.withName(args(1)) - val appName = "Spark UI Tester" + val conf = new SparkConf().setMaster(args(0)).setAppName("Spark UI tester") + + val schedulingMode = SchedulingMode.withName(args(1)) if (schedulingMode == SchedulingMode.FAIR) { - conf.set("spark.scheduler.mode", "FAIR") + conf.set("spark.scheduler.mode", "FAIR") } - val sc = new SparkContext(master, appName) + val sc = new SparkContext(conf) def setProperties(s: String) = { if(schedulingMode == SchedulingMode.FAIR) { @@ -57,11 +56,11 @@ private[spark] object UIWorkloadGenerator { } val baseData = sc.makeRDD(1 to NUM_PARTITIONS * 10, NUM_PARTITIONS) - def nextFloat() = (new Random()).nextFloat() + def nextFloat() = new Random().nextFloat() val jobs = Seq[(String, () => Long)]( ("Count", baseData.count), - ("Cache and Count", baseData.map(x => x).cache.count), + ("Cache and Count", baseData.map(x => x).cache().count), ("Single Shuffle", baseData.map(x => (x % 10, x)).reduceByKey(_ + _).count), ("Entirely failed phase", baseData.map(x => throw new Exception).count), ("Partially failed phase", { diff --git a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala index b637d37517..91fa00a66c 100644 --- a/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala +++ b/core/src/main/scala/org/apache/spark/ui/env/EnvironmentUI.scala @@ -63,7 +63,7 @@ private[spark] class EnvironmentUI(sc: SparkContext) { UIUtils.listingTable(propertyHeaders, propertyRow, otherProperties, fixedWidth = true) val classPathEntries = classPathProperty._2 - .split(sc.conf.getOrElse("path.separator", ":")) + .split(sc.conf.getOrElse("path.separator", ":")) .filterNot(e => e.isEmpty) .map(e => (e, "System Classpath")) val addedJars = sc.addedJars.iterator.toSeq.map{case (path, time) => (path, "Added By User")} diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala index f01a1380b9..6ff8e9fb14 100644 --- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala +++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala @@ -33,7 +33,7 @@ import org.apache.spark.scheduler._ */ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkListener { // How many stages to remember - val RETAINED_STAGES = sc.conf.getOrElse("spark.ui.retained_stages", "1000").toInt + val RETAINED_STAGES = sc.conf.getOrElse("spark.ui.retained_stages", "1000").toInt val DEFAULT_POOL_NAME = "default" val stageIdToPool = new HashMap[Int, String]() @@ -105,7 +105,7 @@ private[spark] class JobProgressListener(val sc: SparkContext) extends SparkList val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashSet[StageInfo]()) stages += stage } - + override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized { val sid = taskStart.task.stageId val tasksActive = stageIdToTasksActive.getOrElseUpdate(sid, new HashSet[TaskInfo]()) diff --git a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala index 76febd5702..58b26f7f12 100644 --- a/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala +++ b/core/src/main/scala/org/apache/spark/util/AkkaUtils.scala @@ -41,19 +41,19 @@ private[spark] object AkkaUtils { def createActorSystem(name: String, host: String, port: Int, indestructible: Boolean = false, conf: SparkConf): (ActorSystem, Int) = { - val akkaThreads = conf.getOrElse("spark.akka.threads", "4").toInt - val akkaBatchSize = conf.getOrElse("spark.akka.batchSize", "15").toInt + val akkaThreads = conf.getOrElse("spark.akka.threads", "4").toInt + val akkaBatchSize = conf.getOrElse("spark.akka.batchSize", "15").toInt - val akkaTimeout = conf.getOrElse("spark.akka.timeout", "100").toInt + val akkaTimeout = conf.getOrElse("spark.akka.timeout", "100").toInt - val akkaFrameSize = conf.getOrElse("spark.akka.frameSize", "10").toInt + val akkaFrameSize = conf.getOrElse("spark.akka.frameSize", "10").toInt val lifecycleEvents = - if (conf.getOrElse("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" + if (conf.getOrElse("spark.akka.logLifecycleEvents", "false").toBoolean) "on" else "off" - val akkaHeartBeatPauses = conf.getOrElse("spark.akka.heartbeat.pauses", "600").toInt + val akkaHeartBeatPauses = conf.getOrElse("spark.akka.heartbeat.pauses", "600").toInt val akkaFailureDetector = - conf.getOrElse("spark.akka.failure-detector.threshold", "300.0").toDouble - val akkaHeartBeatInterval = conf.getOrElse("spark.akka.heartbeat.interval", "1000").toInt + conf.getOrElse("spark.akka.failure-detector.threshold", "300.0").toDouble + val akkaHeartBeatInterval = conf.getOrElse("spark.akka.heartbeat.interval", "1000").toInt val akkaConf = ConfigFactory.parseString( s""" @@ -89,6 +89,6 @@ private[spark] object AkkaUtils { /** Returns the default Spark timeout to use for Akka ask operations. */ def askTimeout(conf: SparkConf): FiniteDuration = { - Duration.create(conf.getOrElse("spark.akka.askTimeout", "30").toLong, "seconds") + Duration.create(conf.getOrElse("spark.akka.askTimeout", "30").toLong, "seconds") } } diff --git a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala index bf71d17a21..431d88838f 100644 --- a/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala +++ b/core/src/main/scala/org/apache/spark/util/MetadataCleaner.scala @@ -18,16 +18,21 @@ package org.apache.spark.util import java.util.{TimerTask, Timer} -import org.apache.spark.{SparkContext, Logging} +import org.apache.spark.{SparkConf, SparkContext, Logging} /** * Runs a timer task to periodically clean up metadata (e.g. old files or hashtable entries) */ -class MetadataCleaner(cleanerType: MetadataCleanerType.MetadataCleanerType, cleanupFunc: (Long) => Unit) extends Logging { +class MetadataCleaner( + cleanerType: MetadataCleanerType.MetadataCleanerType, + cleanupFunc: (Long) => Unit, + conf: SparkConf) + extends Logging +{ val name = cleanerType.toString - private val delaySeconds = MetadataCleaner.getDelaySeconds + private val delaySeconds = MetadataCleaner.getDelaySeconds(conf) private val periodSeconds = math.max(10, delaySeconds / 10) private val timer = new Timer(name + " cleanup timer", true) @@ -65,22 +70,28 @@ object MetadataCleanerType extends Enumeration { def systemProperty(which: MetadataCleanerType.MetadataCleanerType) = "spark.cleaner.ttl." + which.toString } +// TODO: This mutates a Conf to set properties right now, which is kind of ugly when used in the +// initialization of StreamingContext. It's okay for users trying to configure stuff themselves. object MetadataCleaner { - private val conf = SparkContext.globalConf - // using only sys props for now : so that workers can also get to it while preserving earlier behavior. - def getDelaySeconds = conf.getOrElse("spark.cleaner.ttl", "3500").toInt //TODO: this is to fix tests for time being + def getDelaySeconds(conf: SparkConf) = { + conf.getOrElse("spark.cleaner.ttl", "3500").toInt + } - def getDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType): Int = { - conf.getOrElse(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds.toString).toInt + def getDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType): Int = + { + conf.getOrElse(MetadataCleanerType.systemProperty(cleanerType), getDelaySeconds(conf).toString) + .toInt } - def setDelaySeconds(cleanerType: MetadataCleanerType.MetadataCleanerType, delay: Int) { + def setDelaySeconds(conf: SparkConf, cleanerType: MetadataCleanerType.MetadataCleanerType, + delay: Int) + { conf.set(MetadataCleanerType.systemProperty(cleanerType), delay.toString) } - def setDelaySeconds(delay: Int, resetAll: Boolean = true) { + def setDelaySeconds(conf: SparkConf, delay: Int, resetAll: Boolean = true) { // override for all ? - conf.set("spark.cleaner.ttl", delay.toString) + conf.set("spark.cleaner.ttl", delay.toString) if (resetAll) { for (cleanerType <- MetadataCleanerType.values) { System.clearProperty(MetadataCleanerType.systemProperty(cleanerType)) diff --git a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala index 1407c39bfb..bddb3bb735 100644 --- a/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala +++ b/core/src/main/scala/org/apache/spark/util/SizeEstimator.scala @@ -30,10 +30,10 @@ import java.lang.management.ManagementFactory import scala.collection.mutable.ArrayBuffer import it.unimi.dsi.fastutil.ints.IntOpenHashSet -import org.apache.spark.{SparkConf, SparkContext, Logging} +import org.apache.spark.{SparkEnv, SparkConf, SparkContext, Logging} /** - * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in + * Estimates the sizes of Java objects (number of bytes of memory they occupy), for use in * memory-aware caches. * * Based on the following JavaWorld article: @@ -41,7 +41,6 @@ import org.apache.spark.{SparkConf, SparkContext, Logging} */ private[spark] object SizeEstimator extends Logging { - private def conf = SparkContext.globalConf // Sizes of primitive types private val BYTE_SIZE = 1 private val BOOLEAN_SIZE = 1 @@ -90,9 +89,11 @@ private[spark] object SizeEstimator extends Logging { classInfos.put(classOf[Object], new ClassInfo(objectSize, Nil)) } - private def getIsCompressedOops : Boolean = { - if (conf.getOrElse("spark.test.useCompressedOops", null) != null) { - return conf.get("spark.test.useCompressedOops").toBoolean + private def getIsCompressedOops: Boolean = { + // This is only used by tests to override the detection of compressed oops. The test + // actually uses a system property instead of a SparkConf, so we'll stick with that. + if (System.getProperty("spark.test.useCompressedOops") != null) { + return System.getProperty("spark.test.useCompressedOops").toBoolean } try { @@ -104,7 +105,7 @@ private[spark] object SizeEstimator extends Logging { val getVMMethod = hotSpotMBeanClass.getDeclaredMethod("getVMOption", Class.forName("java.lang.String")) - val bean = ManagementFactory.newPlatformMXBeanProxy(server, + val bean = ManagementFactory.newPlatformMXBeanProxy(server, hotSpotMBeanName, hotSpotMBeanClass) // TODO: We could use reflection on the VMOption returned ? return getVMMethod.invoke(bean, "UseCompressedOops").toString.contains("true") @@ -252,7 +253,7 @@ private[spark] object SizeEstimator extends Logging { if (info != null) { return info } - + val parent = getClassInfo(cls.getSuperclass) var shellSize = parent.shellSize var pointerFields = parent.pointerFields diff --git a/core/src/main/scala/org/apache/spark/util/Utils.scala b/core/src/main/scala/org/apache/spark/util/Utils.scala index fd5888e525..b6b89cc7bb 100644 --- a/core/src/main/scala/org/apache/spark/util/Utils.scala +++ b/core/src/main/scala/org/apache/spark/util/Utils.scala @@ -36,15 +36,13 @@ import org.apache.hadoop.fs.{Path, FileSystem, FileUtil} import org.apache.spark.serializer.{DeserializationStream, SerializationStream, SerializerInstance} import org.apache.spark.deploy.SparkHadoopUtil import java.nio.ByteBuffer -import org.apache.spark.{SparkContext, SparkException, Logging} +import org.apache.spark.{SparkConf, SparkContext, SparkException, Logging} /** * Various utility methods used by Spark. */ private[spark] object Utils extends Logging { - - private lazy val conf = SparkContext.globalConf /** Serialize an object using Java serialization */ def serialize[T](o: T): Array[Byte] = { val bos = new ByteArrayOutputStream() @@ -240,9 +238,9 @@ private[spark] object Utils extends Logging { * Throws SparkException if the target file already exists and has different contents than * the requested file. */ - def fetchFile(url: String, targetDir: File) { + def fetchFile(url: String, targetDir: File, conf: SparkConf) { val filename = url.split("/").last - val tempDir = getLocalDir + val tempDir = getLocalDir(conf) val tempFile = File.createTempFile("fetchFileTemp", null, new File(tempDir)) val targetFile = new File(targetDir, filename) val uri = new URI(url) @@ -312,7 +310,7 @@ private[spark] object Utils extends Logging { * return a single directory, even though the spark.local.dir property might be a list of * multiple paths. */ - def getLocalDir: String = { + def getLocalDir(conf: SparkConf): String = { conf.getOrElse("spark.local.dir", System.getProperty("java.io.tmpdir")).split(',')(0) } @@ -398,7 +396,7 @@ private[spark] object Utils extends Logging { InetAddress.getByName(address).getHostName } - def localHostPort(): String = { + def localHostPort(conf: SparkConf): String = { val retval = conf.getOrElse("spark.hostPort", null) if (retval == null) { logErrorWithStack("spark.hostPort not set but invoking localHostPort") @@ -838,7 +836,7 @@ private[spark] object Utils extends Logging { } } - /** + /** * Timing method based on iterations that permit JVM JIT optimization. * @param numIters number of iterations * @param f function to be executed diff --git a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala index ab81bfbe55..8d7546085f 100644 --- a/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala +++ b/core/src/test/scala/org/apache/spark/io/CompressionCodecSuite.scala @@ -20,9 +20,11 @@ package org.apache.spark.io import java.io.{ByteArrayInputStream, ByteArrayOutputStream} import org.scalatest.FunSuite +import org.apache.spark.SparkConf class CompressionCodecSuite extends FunSuite { + val conf = new SparkConf(false) def testCodec(codec: CompressionCodec) { // Write 1000 integers to the output stream, compressed. @@ -43,19 +45,19 @@ class CompressionCodecSuite extends FunSuite { } test("default compression codec") { - val codec = CompressionCodec.createCodec() + val codec = CompressionCodec.createCodec(conf) assert(codec.getClass === classOf[LZFCompressionCodec]) testCodec(codec) } test("lzf compression codec") { - val codec = CompressionCodec.createCodec(classOf[LZFCompressionCodec].getName) + val codec = CompressionCodec.createCodec(conf, classOf[LZFCompressionCodec].getName) assert(codec.getClass === classOf[LZFCompressionCodec]) testCodec(codec) } test("snappy compression codec") { - val codec = CompressionCodec.createCodec(classOf[SnappyCompressionCodec].getName) + val codec = CompressionCodec.createCodec(conf, classOf[SnappyCompressionCodec].getName) assert(codec.getClass === classOf[SnappyCompressionCodec]) testCodec(codec) } diff --git a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala index 2bb827c022..3711382f2e 100644 --- a/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/scheduler/cluster/ClusterTaskSetManagerSuite.scala @@ -82,7 +82,7 @@ class FakeClusterScheduler(sc: SparkContext, liveExecutors: (String, String)* /* class ClusterTaskSetManagerSuite extends FunSuite with LocalSparkContext with Logging { import TaskLocality.{ANY, PROCESS_LOCAL, NODE_LOCAL, RACK_LOCAL} private val conf = new SparkConf - val LOCALITY_WAIT = conf.getOrElse("spark.locality.wait", "3000").toLong + val LOCALITY_WAIT = conf.getOrElse("spark.locality.wait", "3000").toLong test("TaskSet with no preferences") { sc = new SparkContext("local", "test") diff --git a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala index c016c51171..33b0148896 100644 --- a/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala +++ b/core/src/test/scala/org/apache/spark/serializer/KryoSerializerSuite.scala @@ -22,12 +22,14 @@ import scala.collection.mutable import com.esotericsoftware.kryo.Kryo import org.scalatest.FunSuite -import org.apache.spark.SharedSparkContext +import org.apache.spark.{SparkConf, SharedSparkContext} import org.apache.spark.serializer.KryoTest._ class KryoSerializerSuite extends FunSuite with SharedSparkContext { + val conf = new SparkConf(false) + test("basic types") { - val ser = (new KryoSerializer).newInstance() + val ser = new KryoSerializer(conf).newInstance() def check[T](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } @@ -57,7 +59,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } test("pairs") { - val ser = (new KryoSerializer).newInstance() + val ser = new KryoSerializer(conf).newInstance() def check[T](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } @@ -81,7 +83,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } test("Scala data structures") { - val ser = (new KryoSerializer).newInstance() + val ser = new KryoSerializer(conf).newInstance() def check[T](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } @@ -104,7 +106,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { } test("ranges") { - val ser = (new KryoSerializer).newInstance() + val ser = new KryoSerializer(conf).newInstance() def check[T](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) // Check that very long ranges don't get written one element at a time @@ -127,7 +129,7 @@ class KryoSerializerSuite extends FunSuite with SharedSparkContext { test("custom registrator") { System.setProperty("spark.kryo.registrator", classOf[MyRegistrator].getName) - val ser = (new KryoSerializer).newInstance() + val ser = new KryoSerializer(conf).newInstance() def check[T](t: T) { assert(ser.deserialize[T](ser.serialize(t)) === t) } diff --git a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala index 4ef5538951..a0fc3445be 100644 --- a/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala +++ b/core/src/test/scala/org/apache/spark/storage/BlockManagerSuite.scala @@ -34,7 +34,7 @@ import org.apache.spark.serializer.{JavaSerializer, KryoSerializer} import org.apache.spark.{SparkConf, SparkContext} class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodTester { - private val conf = new SparkConf + private val conf = new SparkConf(false) var store: BlockManager = null var store2: BlockManager = null var actorSystem: ActorSystem = null @@ -45,7 +45,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT // Reuse a serializer across tests to avoid creating a new thread-local buffer on each test conf.set("spark.kryoserializer.buffer.mb", "1") - val serializer = new KryoSerializer + val serializer = new KryoSerializer(conf) // Implicitly convert strings to BlockIds for test clarity. implicit def StringToBlockId(value: String): BlockId = new TestBlockId(value) @@ -167,7 +167,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("master + 2 managers interaction") { store = new BlockManager("exec1", actorSystem, master, serializer, 2000, conf) - store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer, 2000, conf) + store2 = new BlockManager("exec2", actorSystem, master, new KryoSerializer(conf), 2000, conf) val peers = master.getPeers(store.blockManagerId, 1) assert(peers.size === 1, "master did not return the other manager as a peer") @@ -654,7 +654,7 @@ class BlockManagerSuite extends FunSuite with BeforeAndAfter with PrivateMethodT test("block store put failure") { // Use Java serializer so we can create an unserializable error. - store = new BlockManager("", actorSystem, master, new JavaSerializer, 1200, conf) + store = new BlockManager("", actorSystem, master, new JavaSerializer(conf), 1200, conf) // The put should fail since a1 is not serializable. class UnserializableClass diff --git a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala index a5facd5bbd..11ebdc352b 100644 --- a/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala +++ b/core/src/test/scala/org/apache/spark/util/SizeEstimatorSuite.scala @@ -140,8 +140,6 @@ class SizeEstimatorSuite test("64-bit arch with no compressed oops") { val arch = System.setProperty("os.arch", "amd64") val oops = System.setProperty("spark.test.useCompressedOops", "false") - SparkContext.globalConf.set("os.arch", "amd64") - SparkContext.globalConf.set("spark.test.useCompressedOops", "false") val initialize = PrivateMethod[Unit]('initialize) SizeEstimator invokePrivate initialize() diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala index 12c430be27..4c0de46964 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRank.scala @@ -37,7 +37,7 @@ object WikipediaPageRank { System.exit(-1) } val sparkConf = new SparkConf() - sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + sparkConf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") sparkConf.set("spark.kryo.registrator", classOf[PRKryoRegistrator].getName) val inputFile = args(0) @@ -46,7 +46,7 @@ object WikipediaPageRank { val host = args(3) val usePartitioner = args(4).toBoolean - sparkConf.setMasterUrl(host).setAppName("WikipediaPageRank") + sparkConf.setMaster(host).setAppName("WikipediaPageRank") val sc = new SparkContext(sparkConf) // Parse the Wikipedia page data into a graph diff --git a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala index 5bf0b7a24a..2cf273a702 100644 --- a/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala +++ b/examples/src/main/scala/org/apache/spark/examples/bagel/WikipediaPageRankStandalone.scala @@ -35,7 +35,7 @@ object WikipediaPageRankStandalone { System.exit(-1) } val sparkConf = new SparkConf() - sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer") + sparkConf.set("spark.serializer", "spark.bagel.examples.WPRSerializer") val inputFile = args(0) @@ -44,7 +44,7 @@ object WikipediaPageRankStandalone { val host = args(3) val usePartitioner = args(4).toBoolean - sparkConf.setMasterUrl(host).setAppName("WikipediaPageRankStandalone") + sparkConf.setMaster(host).setAppName("WikipediaPageRankStandalone") val sc = new SparkContext(sparkConf) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala index 2f2d106f86..8b27ecf82c 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/recommendation/ALS.scala @@ -579,12 +579,12 @@ object ALS { val alpha = if (args.length >= 8) args(7).toDouble else 1 val blocks = if (args.length == 9) args(8).toInt else -1 val sc = new SparkContext(master, "ALS") - sc.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") + sc.conf.set("spark.serializer", "org.apache.spark.serializer.KryoSerializer") sc.conf.set("spark.kryo.registrator", classOf[ALSRegistrator].getName) - sc.conf.set("spark.kryo.referenceTracking", "false") - sc.conf.set("spark.kryoserializer.buffer.mb", "8") - sc.conf.set("spark.locality.wait", "10000") - + sc.conf.set("spark.kryo.referenceTracking", "false") + sc.conf.set("spark.kryoserializer.buffer.mb", "8") + sc.conf.set("spark.locality.wait", "10000") + val ratings = sc.textFile(ratingsFile).map { line => val fields = line.split(',') Rating(fields(0).toInt, fields(1).toInt, fields(2).toDouble) diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 433268a1dd..91e35e2d34 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -45,7 +45,7 @@ import org.apache.spark.util.Utils class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging { def this(args: ApplicationMasterArguments) = this(args, new Configuration()) - + private var rpc: YarnRPC = YarnRPC.create(conf) private val yarnConf: YarnConfiguration = new YarnConfiguration(conf) private var appAttemptId: ApplicationAttemptId = _ @@ -81,12 +81,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // Workaround until hadoop moves to something which has // https://issues.apache.org/jira/browse/HADOOP-8406 - fixed in (2.0.2-alpha but no 0.23 line) // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) - + ApplicationMaster.register(this) // Start the user's JAR userThread = startUserClass() - + // This a bit hacky, but we need to wait until the spark.driver.port property has // been set by the Thread executing the user class. waitForSparkMaster() @@ -99,7 +99,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // Allocate all containers allocateWorkers() - // Wait for the user class to Finish + // Wait for the user class to Finish userThread.join() System.exit(0) @@ -119,7 +119,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e } localDirs } - + private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() val containerIdString = envs.get(ApplicationConstants.Environment.CONTAINER_ID.name()) @@ -128,17 +128,17 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e logInfo("ApplicationAttemptId: " + appAttemptId) appAttemptId } - + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { logInfo("Registering the ApplicationMaster") amClient.registerApplicationMaster(Utils.localHostName(), 0, uiAddress) } - + private def waitForSparkMaster() { logInfo("Waiting for Spark driver to be reachable.") var driverUp = false var tries = 0 - val numTries = conf.getOrElse("spark.yarn.applicationMaster.waitTries", "10").toInt + val numTries = conf.getOrElse("spark.yarn.applicationMaster.waitTries", "10").toInt while (!driverUp && tries < numTries) { val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") @@ -199,7 +199,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e ApplicationMaster.sparkContextRef.synchronized { var numTries = 0 val waitTime = 10000L - val maxNumTries = conf.getOrElse("spark.yarn.ApplicationMaster.waitTries", "10").toInt + val maxNumTries = conf.getOrElse("spark.yarn.ApplicationMaster.waitTries", "10").toInt while (ApplicationMaster.sparkContextRef.get() == null && numTries < maxNumTries) { logInfo("Waiting for Spark context initialization ... " + numTries) numTries = numTries + 1 @@ -214,7 +214,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e yarnConf, amClient, appAttemptId, - args, + args, sparkContext.preferredNodeLocationData) } else { logWarning("Unable to retrieve SparkContext inspite of waiting for %d, maxNumTries = %d". @@ -265,7 +265,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // we want to be reasonably responsive without causing too many requests to RM. val schedulerInterval = - conf.getOrElse("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + conf.getOrElse("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong // must be <= timeoutInterval / 2. val interval = math.min(timeoutInterval / 2, schedulerInterval) @@ -314,11 +314,11 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e for (container <- containers) { logInfo("Launching shell command on a new container." + ", containerId=" + container.getId() - + ", containerNode=" + container.getNodeId().getHost() + + ", containerNode=" + container.getNodeId().getHost() + ":" + container.getNodeId().getPort() + ", containerNodeURI=" + container.getNodeHttpAddress() + ", containerState" + container.getState() - + ", containerResourceMemory" + + ", containerResourceMemory" + container.getResource().getMemory()) } } @@ -338,12 +338,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e } /** - * Clean up the staging directory. + * Clean up the staging directory. */ - private def cleanupStagingDir() { + private def cleanupStagingDir() { var stagingDirPath: Path = null try { - val preserveFiles = conf.getOrElse("spark.yarn.preserve.staging.files", "false").toBoolean + val preserveFiles = conf.getOrElse("spark.yarn.preserve.staging.files", "false").toBoolean if (!preserveFiles) { stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) if (stagingDirPath == null) { @@ -359,7 +359,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e } } - // The shutdown hook that runs when a signal is received AND during normal close of the JVM. + // The shutdown hook that runs when a signal is received AND during normal close of the JVM. class AppMasterShutdownHook(appMaster: ApplicationMaster) extends Runnable { def run() { @@ -415,18 +415,18 @@ object ApplicationMaster { // Note that this will unfortunately not properly clean up the staging files because it gets // called too late, after the filesystem is already shutdown. if (modified) { - Runtime.getRuntime().addShutdownHook(new Thread with Logging { + Runtime.getRuntime().addShutdownHook(new Thread with Logging { // This is not only logs, but also ensures that log system is initialized for this instance // when we are actually 'run'-ing. logInfo("Adding shutdown hook for context " + sc) - override def run() { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() + override def run() { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() // Best case ... for (master <- applicationMasters) { master.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) } - } + } } ) } diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index a322f60864..963b5b88be 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -40,7 +40,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, Records} -import org.apache.spark.Logging +import org.apache.spark.Logging import org.apache.spark.util.Utils import org.apache.spark.deploy.SparkHadoopUtil @@ -150,7 +150,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl queueInfo.getChildQueues.size)) } - def verifyClusterResources(app: GetNewApplicationResponse) = { + def verifyClusterResources(app: GetNewApplicationResponse) = { val maxMem = app.getMaximumResourceCapability().getMemory() logInfo("Max mem capabililty of a single resource in this cluster " + maxMem) @@ -221,7 +221,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf) fs.setReplication(newPath, replication) if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION)) - } + } // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific // version shows the specific version in the distributed cache configuration val qualPath = fs.makeQualified(newPath) @@ -244,7 +244,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } } val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val replication = conf.getOrElse("spark.yarn.submit.file.replication", "3").toShort + val replication = conf.getOrElse("spark.yarn.submit.file.replication", "3").toShort if (UserGroupInformation.isSecurityEnabled()) { val dstFs = dst.getFileSystem(conf) @@ -269,7 +269,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } val setPermissions = if (destName.equals(Client.APP_JAR)) true else false val destPath = copyRemoteFile(dst, new Path(localURI), replication, setPermissions) - distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, destName, statCache) } } @@ -283,7 +283,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val destPath = copyRemoteFile(dst, localPath, replication) // Only add the resource to the Spark ApplicationMaster. val appMasterOnly = true - distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, linkname, statCache, appMasterOnly) } } @@ -295,7 +295,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val localPath = new Path(localURI) val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) val destPath = copyRemoteFile(dst, localPath, replication) - distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, linkname, statCache) } } @@ -307,7 +307,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val localPath = new Path(localURI) val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) val destPath = copyRemoteFile(dst, localPath, replication) - distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, linkname, statCache) } } @@ -317,7 +317,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } def setupLaunchEnv( - localResources: HashMap[String, LocalResource], + localResources: HashMap[String, LocalResource], stagingDir: String): HashMap[String, String] = { logInfo("Setting up the launch environment") val log4jConfLocalRes = localResources.getOrElse(Client.LOG4J_PROP, null) @@ -406,11 +406,11 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } val commands = List[String]( - javaCommand + + javaCommand + " -server " + JAVA_OPTS + " " + args.amClass + - " --class " + args.userClass + + " --class " + args.userClass + " --jar " + args.userJar + userArgsToString(args) + " --worker-memory " + args.workerMemory + @@ -436,7 +436,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl super.submitApplication(appContext) } - def monitorApplication(appId: ApplicationId): Boolean = { + def monitorApplication(appId: ApplicationId): Boolean = { while (true) { Thread.sleep(1000) val report = super.getApplicationReport(appId) @@ -458,7 +458,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val state = report.getYarnApplicationState() val dsStatus = report.getFinalApplicationStatus() - if (state == YarnApplicationState.FINISHED || + if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { return true @@ -495,25 +495,25 @@ object Client { Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$()) // If log4j present, ensure ours overrides all others if (addLog4j) { - Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + LOG4J_PROP) } // Normally the users app.jar is last in case conflicts with spark jars - val userClasspathFirst = conf.getOrElse("spark.yarn.user.classpath.first", "false") + val userClasspathFirst = conf.getOrElse("spark.yarn.user.classpath.first", "false") .toBoolean if (userClasspathFirst) { - Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + APP_JAR) } - Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + SPARK_JAR) Client.populateHadoopClasspath(conf, env) if (!userClasspathFirst) { - Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + APP_JAR) } - Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + "*") } } diff --git a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 41ac292249..1a9bb97b3e 100644 --- a/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/new-yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -35,7 +35,7 @@ class ClientArguments(val args: Array[String]) { var workerMemory = 1024 // MB var workerCores = 1 var numWorkers = 2 - var amQueue = conf.getOrElse("QUEUE", "default") + var amQueue = conf.getOrElse("QUEUE", "default") var amMemory: Int = 512 // MB var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster" var appName: String = "Spark" diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala index b2f499e637..f108c70f21 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala @@ -35,6 +35,7 @@ import java.lang.{Class => jClass} import scala.reflect.api.{Mirror, TypeCreator, Universe => ApiUniverse} import org.apache.spark.Logging +import org.apache.spark.SparkConf import org.apache.spark.SparkContext /** The Scala interactive shell. It provides a read-eval-print loop @@ -929,7 +930,7 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } def createSparkContext(): SparkContext = { - val uri = System.getenv("SPARK_EXECUTOR_URI") + val execUri = System.getenv("SPARK_EXECUTOR_URI") val master = this.master match { case Some(m) => m case None => { @@ -938,11 +939,16 @@ class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter, } } val jars = SparkILoop.getAddedJars.map(new java.io.File(_).getAbsolutePath) - sparkContext = new SparkContext(master, "Spark shell", System.getenv("SPARK_HOME"), jars) - if (uri != null) { - sparkContext.conf.set("spark.executor.uri", uri) + val conf = new SparkConf() + .setMaster(master) + .setAppName("Spark shell") + .setSparkHome(System.getenv("SPARK_HOME")) + .setJars(jars) + .set("spark.repl.class.uri", intp.classServer.uri) + if (execUri != null) { + conf.set("spark.executor.uri", execUri) } - sparkContext.conf.set("spark.repl.class.uri", intp.classServer.uri) + sparkContext = new SparkContext(conf) echo("Created spark context..") sparkContext } diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala index 0d412e4478..a993083164 100644 --- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala +++ b/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala @@ -34,7 +34,7 @@ import scala.tools.reflect.StdRuntimeTags._ import scala.util.control.ControlThrowable import util.stackTraceString -import org.apache.spark.{SparkContext, HttpServer, SparkEnv, Logging} +import org.apache.spark.{HttpServer, SparkConf, Logging} import org.apache.spark.util.Utils // /** directory to save .class files to */ @@ -89,7 +89,7 @@ import org.apache.spark.util.Utils /** Local directory to save .class files too */ val outputDir = { val tmp = System.getProperty("java.io.tmpdir") - val rootDir = SparkContext.globalConf.getOrElse("spark.repl.classdir", tmp) + val rootDir = new SparkConf().getOrElse("spark.repl.classdir", tmp) Utils.createTempDir(rootDir) } if (SPARK_DEBUG_REPL) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala index b8e1427a21..f106bba678 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Checkpoint.scala @@ -24,7 +24,7 @@ import java.util.concurrent.RejectedExecutionException import org.apache.hadoop.fs.Path import org.apache.hadoop.conf.Configuration -import org.apache.spark.Logging +import org.apache.spark.{SparkConf, Logging} import org.apache.spark.io.CompressionCodec import org.apache.spark.util.MetadataCleaner @@ -36,12 +36,11 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) val framework = ssc.sc.appName val sparkHome = ssc.sc.getSparkHome.getOrElse(null) val jars = ssc.sc.jars - val environment = ssc.sc.environment val graph = ssc.graph val checkpointDir = ssc.checkpointDir val checkpointDuration = ssc.checkpointDuration val pendingTimes = ssc.scheduler.jobManager.getPendingTimes() - val delaySeconds = MetadataCleaner.getDelaySeconds + val delaySeconds = MetadataCleaner.getDelaySeconds(ssc.conf) val sparkConf = ssc.sc.conf def validate() { @@ -58,7 +57,7 @@ class Checkpoint(@transient ssc: StreamingContext, val checkpointTime: Time) * Convenience class to speed up the writing of graph checkpoint to file */ private[streaming] -class CheckpointWriter(checkpointDir: String) extends Logging { +class CheckpointWriter(conf: SparkConf, checkpointDir: String) extends Logging { val file = new Path(checkpointDir, "graph") // The file to which we actually write - and then "move" to file. private val writeFile = new Path(file.getParent, file.getName + ".next") @@ -66,14 +65,14 @@ class CheckpointWriter(checkpointDir: String) extends Logging { private var stopped = false - val conf = new Configuration() - var fs = file.getFileSystem(conf) + val hadoopConf = new Configuration() + var fs = file.getFileSystem(hadoopConf) val maxAttempts = 3 val executor = Executors.newFixedThreadPool(1) - private val compressionCodec = CompressionCodec.createCodec() + private val compressionCodec = CompressionCodec.createCodec(conf) - // Removed code which validates whether there is only one CheckpointWriter per path 'file' since + // Removed code which validates whether there is only one CheckpointWriter per path 'file' since // I did not notice any errors - reintroduce it ? class CheckpointWriteHandler(checkpointTime: Time, bytes: Array[Byte]) extends Runnable { @@ -142,11 +141,12 @@ class CheckpointWriter(checkpointDir: String) extends Logging { private[streaming] object CheckpointReader extends Logging { - def read(path: String): Checkpoint = { + def read(conf: SparkConf, path: String): Checkpoint = { val fs = new Path(path).getFileSystem(new Configuration()) - val attempts = Seq(new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) + val attempts = Seq( + new Path(path, "graph"), new Path(path, "graph.bk"), new Path(path), new Path(path + ".bk")) - val compressionCodec = CompressionCodec.createCodec() + val compressionCodec = CompressionCodec.createCodec(conf) attempts.foreach(file => { if (fs.exists(file)) { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala index 329d2b5835..8005202500 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/DStream.scala @@ -213,7 +213,7 @@ abstract class DStream[T: ClassTag] ( checkpointDuration + "). Please set it to higher than " + checkpointDuration + "." ) - val metadataCleanerDelay = MetadataCleaner.getDelaySeconds + val metadataCleanerDelay = MetadataCleaner.getDelaySeconds(ssc.conf) logInfo("metadataCleanupDelay = " + metadataCleanerDelay) assert( metadataCleanerDelay < 0 || rememberDuration.milliseconds < metadataCleanerDelay * 1000, diff --git a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala b/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala index 1d23713c80..82ed6bed69 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/Scheduler.scala @@ -26,10 +26,10 @@ class Scheduler(ssc: StreamingContext) extends Logging { initLogging() - val concurrentJobs = ssc.sc.conf.getOrElse("spark.streaming.concurrentJobs", "1").toInt + val concurrentJobs = ssc.sc.conf.getOrElse("spark.streaming.concurrentJobs", "1").toInt val jobManager = new JobManager(ssc, concurrentJobs) val checkpointWriter = if (ssc.checkpointDuration != null && ssc.checkpointDir != null) { - new CheckpointWriter(ssc.checkpointDir) + new CheckpointWriter(ssc.conf, ssc.checkpointDir) } else { null } @@ -50,13 +50,13 @@ class Scheduler(ssc: StreamingContext) extends Logging { } logInfo("Scheduler started") } - + def stop() = synchronized { timer.stop() jobManager.stop() if (checkpointWriter != null) checkpointWriter.stop() ssc.graph.stop() - logInfo("Scheduler stopped") + logInfo("Scheduler stopped") } private def startFirstTime() { @@ -73,7 +73,7 @@ class Scheduler(ssc: StreamingContext) extends Logging { // or if the property is defined set it to that time if (clock.isInstanceOf[ManualClock]) { val lastTime = ssc.initialCheckpoint.checkpointTime.milliseconds - val jumpTime = ssc.sc.conf.getOrElse("spark.streaming.manualClock.jump", "0").toLong + val jumpTime = ssc.sc.conf.getOrElse("spark.streaming.manualClock.jump", "0").toLong clock.asInstanceOf[ManualClock].setTime(lastTime + jumpTime) } diff --git a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala index 76744223e1..079841ad9d 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/StreamingContext.scala @@ -87,13 +87,12 @@ class StreamingContext private ( null, batchDuration) } - /** * Re-create a StreamingContext from a checkpoint file. * @param path Path either to the directory that was specified as the checkpoint directory, or * to the checkpoint file 'graph' or 'graph.bk'. */ - def this(path: String) = this(null, CheckpointReader.read(path), null) + def this(path: String) = this(null, CheckpointReader.read(new SparkConf(), path), null) initLogging() @@ -102,11 +101,13 @@ class StreamingContext private ( "both SparkContext and checkpoint as null") } - if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds < 0) { - MetadataCleaner.setDelaySeconds(cp_.delaySeconds) + private val conf_ = Option(sc_).map(_.conf).getOrElse(cp_.sparkConf) + + if(cp_ != null && cp_.delaySeconds >= 0 && MetadataCleaner.getDelaySeconds(conf_) < 0) { + MetadataCleaner.setDelaySeconds(conf_, cp_.delaySeconds) } - if (MetadataCleaner.getDelaySeconds < 0) { + if (MetadataCleaner.getDelaySeconds(conf_) < 0) { throw new SparkException("Spark Streaming cannot be used without setting spark.cleaner.ttl; " + "set this property before creating a SparkContext (use SPARK_JAVA_OPTS for the shell)") } @@ -115,12 +116,14 @@ class StreamingContext private ( protected[streaming] val sc: SparkContext = { if (isCheckpointPresent) { - new SparkContext(cp_.sparkConf, cp_.environment) + new SparkContext(cp_.sparkConf) } else { sc_ } } + protected[streaming] val conf = sc.conf + protected[streaming] val env = SparkEnv.get protected[streaming] val graph: DStreamGraph = { @@ -579,13 +582,15 @@ object StreamingContext { appName: String, sparkHome: String, jars: Seq[String], - environment: Map[String, String]): SparkContext = { + environment: Map[String, String]): SparkContext = + { + val sc = new SparkContext(master, appName, sparkHome, jars, environment) // Set the default cleaner delay to an hour if not already set. // This should be sufficient for even 1 second interval. - if (MetadataCleaner.getDelaySeconds < 0) { - MetadataCleaner.setDelaySeconds(3600) + if (MetadataCleaner.getDelaySeconds(sc.conf) < 0) { + MetadataCleaner.setDelaySeconds(sc.conf, 3600) } - new SparkContext(master, appName, sparkHome, jars, environment) + sc } protected[streaming] def rddToFileName[T](prefix: String, suffix: String, time: Time): String = { diff --git a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala index 8bf761b8cb..bd607f9d18 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/dstream/NetworkInputDStream.scala @@ -175,8 +175,8 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging /** A helper actor that communicates with the NetworkInputTracker */ private class NetworkReceiverActor extends Actor { logInfo("Attempting to register with tracker") - val ip = env.conf.getOrElse("spark.driver.host", "localhost") - val port = env.conf.getOrElse("spark.driver.port", "7077").toInt + val ip = env.conf.getOrElse("spark.driver.host", "localhost") + val port = env.conf.getOrElse("spark.driver.port", "7077").toInt val url = "akka.tcp://spark@%s:%s/user/NetworkInputTracker".format(ip, port) val tracker = env.actorSystem.actorSelection(url) val timeout = 5.seconds @@ -213,7 +213,7 @@ abstract class NetworkReceiver[T: ClassTag]() extends Serializable with Logging case class Block(id: BlockId, buffer: ArrayBuffer[T], metadata: Any = null) val clock = new SystemClock() - val blockInterval = env.conf.getOrElse("spark.streaming.blockInterval", "200").toLong + val blockInterval = env.conf.getOrElse("spark.streaming.blockInterval", "200").toLong val blockIntervalTimer = new RecurringTimer(clock, blockInterval, updateCurrentBuffer) val blockStorageLevel = storageLevel val blocksForPushing = new ArrayBlockingQueue[Block](1000) diff --git a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala index fc8655a083..6585d494a6 100644 --- a/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala +++ b/streaming/src/main/scala/org/apache/spark/streaming/util/RawTextSender.scala @@ -20,7 +20,7 @@ package org.apache.spark.streaming.util import java.nio.ByteBuffer import org.apache.spark.util.{RateLimitedOutputStream, IntParam} import java.net.ServerSocket -import org.apache.spark.{Logging} +import org.apache.spark.{SparkConf, Logging} import it.unimi.dsi.fastutil.io.FastByteArrayOutputStream import scala.io.Source import java.io.IOException @@ -42,7 +42,7 @@ object RawTextSender extends Logging { // Repeat the input data multiple times to fill in a buffer val lines = Source.fromFile(file).getLines().toArray val bufferStream = new FastByteArrayOutputStream(blockSize + 1000) - val ser = new KryoSerializer().newInstance() + val ser = new KryoSerializer(new SparkConf()).newInstance() val serStream = ser.serializeStream(bufferStream) var i = 0 while (bufferStream.position < blockSize) { diff --git a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala index da8f135dd7..8c16daa21c 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/InputStreamsSuite.scala @@ -52,9 +52,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { override def checkpointDir = "checkpoint" - before { - conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") - } + conf.set("spark.streaming.clock", "org.apache.spark.streaming.util.ManualClock") after { // To avoid Akka rebinding to the same port, since it doesn't unbind immediately on shutdown @@ -70,7 +68,7 @@ class InputStreamsSuite extends TestSuiteBase with BeforeAndAfter { // Set up the streaming context and input streams val ssc = new StreamingContext(new SparkContext(conf), batchDuration) val networkStream = ssc.socketTextStream("localhost", testServer.port, StorageLevel.MEMORY_AND_DISK) - val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String ]] + val outputBuffer = new ArrayBuffer[Seq[String]] with SynchronizedBuffer[Seq[String]] val outputStream = new TestOutputStream(networkStream, outputBuffer) def output = outputBuffer.flatMap(x => x) ssc.registerOutputStream(outputStream) diff --git a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala index d1cab0c609..a265284bff 100644 --- a/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala +++ b/streaming/src/test/scala/org/apache/spark/streaming/TestSuiteBase.scala @@ -130,7 +130,11 @@ trait TestSuiteBase extends FunSuite with BeforeAndAfter with Logging { // Whether to actually wait in real time before changing manual clock def actuallyWait = false - def conf = new SparkConf().setMasterUrl(master).setAppName(framework).set("spark.cleaner.ttl", "3600") + val conf = new SparkConf() + .setMaster(master) + .setAppName(framework) + .set("spark.cleaner.ttl", "3600") + /** * Set up required DStreams to test the DStream operation using the two sequences * of input collections. diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala index 1dd38dd13e..dc9228180f 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ApplicationMaster.scala @@ -43,7 +43,7 @@ import org.apache.spark.util.Utils class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) extends Logging { def this(args: ApplicationMasterArguments) = this(args, new Configuration()) - + private var rpc: YarnRPC = YarnRPC.create(conf) private var resourceManager: AMRMProtocol = _ private var appAttemptId: ApplicationAttemptId = _ @@ -68,7 +68,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // Use priority 30 as its higher then HDFS. Its same priority as MapReduce is using. ShutdownHookManager.get().addShutdownHook(new AppMasterShutdownHook(this), 30) - + appAttemptId = getApplicationAttemptId() isLastAMRetry = appAttemptId.getAttemptId() >= maxAppAttempts resourceManager = registerWithResourceManager() @@ -92,11 +92,11 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // } //} // org.apache.hadoop.io.compress.CompressionCodecFactory.getCodecClasses(conf) - + ApplicationMaster.register(this) // Start the user's JAR userThread = startUserClass() - + // This a bit hacky, but we need to wait until the spark.driver.port property has // been set by the Thread executing the user class. waitForSparkMaster() @@ -105,11 +105,11 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // Do this after spark master is up and SparkContext is created so that we can register UI Url val appMasterResponse: RegisterApplicationMasterResponse = registerApplicationMaster() - + // Allocate all containers allocateWorkers() - - // Wait for the user class to Finish + + // Wait for the user class to Finish userThread.join() System.exit(0) @@ -129,7 +129,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e } localDirs } - + private def getApplicationAttemptId(): ApplicationAttemptId = { val envs = System.getenv() val containerIdString = envs.get(ApplicationConstants.AM_CONTAINER_ID_ENV) @@ -138,7 +138,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e logInfo("ApplicationAttemptId: " + appAttemptId) appAttemptId } - + private def registerWithResourceManager(): AMRMProtocol = { val rmAddress = NetUtils.createSocketAddr(yarnConf.get( YarnConfiguration.RM_SCHEDULER_ADDRESS, @@ -146,26 +146,26 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e logInfo("Connecting to ResourceManager at " + rmAddress) rpc.getProxy(classOf[AMRMProtocol], rmAddress, conf).asInstanceOf[AMRMProtocol] } - + private def registerApplicationMaster(): RegisterApplicationMasterResponse = { logInfo("Registering the ApplicationMaster") val appMasterRequest = Records.newRecord(classOf[RegisterApplicationMasterRequest]) .asInstanceOf[RegisterApplicationMasterRequest] appMasterRequest.setApplicationAttemptId(appAttemptId) // Setting this to master host,port - so that the ApplicationReport at client has some - // sensible info. + // sensible info. // Users can then monitor stderr/stdout on that node if required. appMasterRequest.setHost(Utils.localHostName()) appMasterRequest.setRpcPort(0) appMasterRequest.setTrackingUrl(uiAddress) resourceManager.registerApplicationMaster(appMasterRequest) } - + private def waitForSparkMaster() { logInfo("Waiting for spark driver to be reachable.") var driverUp = false var tries = 0 - val numTries = conf.getOrElse("spark.yarn.applicationMaster.waitTries", "10").toInt + val numTries = conf.getOrElse("spark.yarn.applicationMaster.waitTries", "10").toInt while(!driverUp && tries < numTries) { val driverHost = conf.get("spark.driver.host") val driverPort = conf.get("spark.driver.port") @@ -226,7 +226,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e ApplicationMaster.sparkContextRef.synchronized { var count = 0 val waitTime = 10000L - val numTries = conf.getOrElse("spark.yarn.ApplicationMaster.waitTries", "10").toInt + val numTries = conf.getOrElse("spark.yarn.ApplicationMaster.waitTries", "10").toInt while (ApplicationMaster.sparkContextRef.get() == null && count < numTries) { logInfo("Waiting for spark context initialization ... " + count) count = count + 1 @@ -241,8 +241,8 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e yarnConf, resourceManager, appAttemptId, - args, - sparkContext.preferredNodeLocationData) + args, + sparkContext.preferredNodeLocationData) } else { logWarning("Unable to retrieve sparkContext inspite of waiting for %d, numTries = %d". format(count * waitTime, numTries)) @@ -294,7 +294,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e // we want to be reasonably responsive without causing too many requests to RM. val schedulerInterval = - conf.getOrElse("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong + conf.getOrElse("spark.yarn.scheduler.heartbeat.interval-ms", "5000").toLong // must be <= timeoutInterval / 2. val interval = math.min(timeoutInterval / 2, schedulerInterval) @@ -342,11 +342,11 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e for (container <- containers) { logInfo("Launching shell command on a new container." + ", containerId=" + container.getId() - + ", containerNode=" + container.getNodeId().getHost() + + ", containerNode=" + container.getNodeId().getHost() + ":" + container.getNodeId().getPort() + ", containerNodeURI=" + container.getNodeHttpAddress() + ", containerState" + container.getState() - + ", containerResourceMemory" + + ", containerResourceMemory" + container.getResource().getMemory()) } } @@ -372,12 +372,12 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e } /** - * Clean up the staging directory. + * Clean up the staging directory. */ - private def cleanupStagingDir() { + private def cleanupStagingDir() { var stagingDirPath: Path = null try { - val preserveFiles = conf.getOrElse("spark.yarn.preserve.staging.files", "false").toBoolean + val preserveFiles = conf.getOrElse("spark.yarn.preserve.staging.files", "false").toBoolean if (!preserveFiles) { stagingDirPath = new Path(System.getenv("SPARK_YARN_STAGING_DIR")) if (stagingDirPath == null) { @@ -393,7 +393,7 @@ class ApplicationMaster(args: ApplicationMasterArguments, conf: Configuration) e } } - // The shutdown hook that runs when a signal is received AND during normal close of the JVM. + // The shutdown hook that runs when a signal is received AND during normal close of the JVM. class AppMasterShutdownHook(appMaster: ApplicationMaster) extends Runnable { def run() { @@ -446,18 +446,18 @@ object ApplicationMaster { // Note that this will unfortunately not properly clean up the staging files because it gets // called too late, after the filesystem is already shutdown. if (modified) { - Runtime.getRuntime().addShutdownHook(new Thread with Logging { + Runtime.getRuntime().addShutdownHook(new Thread with Logging { // This is not only logs, but also ensures that log system is initialized for this instance // when we are actually 'run'-ing. logInfo("Adding shutdown hook for context " + sc) - override def run() { - logInfo("Invoking sc stop from shutdown hook") - sc.stop() + override def run() { + logInfo("Invoking sc stop from shutdown hook") + sc.stop() // Best case ... for (master <- applicationMasters) { master.finishApplicationMaster(FinalApplicationStatus.SUCCEEDED) } - } + } } ) } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala index 29892e98e3..cc150888eb 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/Client.scala @@ -40,7 +40,7 @@ import org.apache.hadoop.yarn.conf.YarnConfiguration import org.apache.hadoop.yarn.ipc.YarnRPC import org.apache.hadoop.yarn.util.{Apps, Records} -import org.apache.spark.Logging +import org.apache.spark.Logging import org.apache.spark.util.Utils import org.apache.spark.deploy.SparkHadoopUtil @@ -59,7 +59,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val STAGING_DIR_PERMISSION: FsPermission = FsPermission.createImmutable(0700:Short) // App files are world-wide readable and owner writable -> rw-r--r-- - val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) + val APP_FILE_PERMISSION: FsPermission = FsPermission.createImmutable(0644:Short) // for client user who want to monitor app status by itself. def runApp() = { @@ -103,7 +103,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl "greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD), (args.workerMemory <= YarnAllocationHandler.MEMORY_OVERHEAD) -> ("Error: Worker memory size " + "must be greater than: " + YarnAllocationHandler.MEMORY_OVERHEAD) - ).foreach { case(cond, errStr) => + ).foreach { case(cond, errStr) => if (cond) { logError(errStr) args.printUsageAndExit(1) @@ -130,7 +130,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl queueInfo.getChildQueues.size)) } - def verifyClusterResources(app: GetNewApplicationResponse) = { + def verifyClusterResources(app: GetNewApplicationResponse) = { val maxMem = app.getMaximumResourceCapability().getMemory() logInfo("Max mem capabililty of a single resource in this cluster " + maxMem) @@ -146,7 +146,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } // We could add checks to make sure the entire cluster has enough resources but that involves - // getting all the node reports and computing ourselves + // getting all the node reports and computing ourselves } def createApplicationSubmissionContext(appId: ApplicationId): ApplicationSubmissionContext = { @@ -207,7 +207,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl FileUtil.copy(remoteFs, originalPath, fs, newPath, false, conf) fs.setReplication(newPath, replication) if (setPerms) fs.setPermission(newPath, new FsPermission(APP_FILE_PERMISSION)) - } + } // Resolve any symlinks in the URI path so using a "current" symlink to point to a specific // version shows the specific version in the distributed cache configuration val qualPath = fs.makeQualified(newPath) @@ -230,7 +230,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } } val dst = new Path(fs.getHomeDirectory(), appStagingDir) - val replication = conf.getOrElse("spark.yarn.submit.file.replication", "3").toShort + val replication = conf.getOrElse("spark.yarn.submit.file.replication", "3").toShort if (UserGroupInformation.isSecurityEnabled()) { val dstFs = dst.getFileSystem(conf) @@ -241,7 +241,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val statCache: Map[URI, FileStatus] = HashMap[URI, FileStatus]() - Map(Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar, + Map(Client.SPARK_JAR -> System.getenv("SPARK_JAR"), Client.APP_JAR -> args.userJar, Client.LOG4J_PROP -> System.getenv("SPARK_LOG4J_CONF")) .foreach { case(destName, _localPath) => val localPath: String = if (_localPath != null) _localPath.trim() else "" @@ -253,7 +253,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } val setPermissions = if (destName.equals(Client.APP_JAR)) true else false val destPath = copyRemoteFile(dst, new Path(localURI), replication, setPermissions) - distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, destName, statCache) } } @@ -265,7 +265,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val localPath = new Path(localURI) val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) val destPath = copyRemoteFile(dst, localPath, replication) - distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, linkname, statCache, true) } } @@ -277,7 +277,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val localPath = new Path(localURI) val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) val destPath = copyRemoteFile(dst, localPath, replication) - distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.FILE, linkname, statCache) } } @@ -289,7 +289,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val localPath = new Path(localURI) val linkname = Option(localURI.getFragment()).getOrElse(localPath.getName()) val destPath = copyRemoteFile(dst, localPath, replication) - distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, + distCacheMgr.addResource(fs, conf, destPath, localResources, LocalResourceType.ARCHIVE, linkname, statCache) } } @@ -299,7 +299,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl } def setupLaunchEnv( - localResources: HashMap[String, LocalResource], + localResources: HashMap[String, LocalResource], stagingDir: String): HashMap[String, String] = { logInfo("Setting up the launch environment") val log4jConfLocalRes = localResources.getOrElse(Client.LOG4J_PROP, null) @@ -354,7 +354,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl // Add Xmx for am memory JAVA_OPTS += "-Xmx" + amMemory + "m " - JAVA_OPTS += " -Djava.io.tmpdir=" + + JAVA_OPTS += " -Djava.io.tmpdir=" + new Path(Environment.PWD.$(), YarnConfiguration.DEFAULT_CONTAINER_TEMP_DIR) + " " // Commenting it out for now - so that people can refer to the properties if required. Remove @@ -387,11 +387,11 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl javaCommand = Environment.JAVA_HOME.$() + "/bin/java" } - val commands = List[String](javaCommand + + val commands = List[String](javaCommand + " -server " + JAVA_OPTS + " " + args.amClass + - " --class " + args.userClass + + " --class " + args.userClass + " --jar " + args.userJar + userArgsToString(args) + " --worker-memory " + args.workerMemory + @@ -421,7 +421,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl super.submitApplication(appContext) } - def monitorApplication(appId: ApplicationId): Boolean = { + def monitorApplication(appId: ApplicationId): Boolean = { while (true) { Thread.sleep(1000) val report = super.getApplicationReport(appId) @@ -443,7 +443,7 @@ class Client(conf: Configuration, args: ClientArguments) extends YarnClientImpl val state = report.getYarnApplicationState() val dsStatus = report.getFinalApplicationStatus() - if (state == YarnApplicationState.FINISHED || + if (state == YarnApplicationState.FINISHED || state == YarnApplicationState.FAILED || state == YarnApplicationState.KILLED) { return true @@ -461,7 +461,7 @@ object Client { def main(argStrings: Array[String]) { // Set an env variable indicating we are running in YARN mode. // Note that anything with SPARK prefix gets propagated to all (remote) processes - conf.set("SPARK_YARN_MODE", "true") + System.setProperty("SPARK_YARN_MODE", "true") val args = new ClientArguments(argStrings) @@ -479,25 +479,25 @@ object Client { Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$()) // If log4j present, ensure ours overrides all others if (addLog4j) { - Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + LOG4J_PROP) } // Normally the users app.jar is last in case conflicts with spark jars - val userClasspathFirst = conf.getOrElse("spark.yarn.user.classpath.first", "false") + val userClasspathFirst = conf.getOrElse("spark.yarn.user.classpath.first", "false") .toBoolean if (userClasspathFirst) { - Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + APP_JAR) } - Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + SPARK_JAR) Client.populateHadoopClasspath(conf, env) if (!userClasspathFirst) { - Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + APP_JAR) } - Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + + Apps.addToEnvironment(env, Environment.CLASSPATH.name, Environment.PWD.$() + Path.SEPARATOR + "*") } } diff --git a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala index 617289f568..e9e46a193b 100644 --- a/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala +++ b/yarn/src/main/scala/org/apache/spark/deploy/yarn/ClientArguments.scala @@ -33,7 +33,7 @@ class ClientArguments(val args: Array[String]) { var workerMemory = 1024 var workerCores = 1 var numWorkers = 2 - var amQueue = conf.getOrElse("QUEUE", "default") + var amQueue = conf.getOrElse("QUEUE", "default") var amMemory: Int = 512 var amClass: String = "org.apache.spark.deploy.yarn.ApplicationMaster" var appName: String = "Spark" -- cgit v1.2.3 From f150b6e76c56ed6f604e6dbda7bce6b6278929fb Mon Sep 17 00:00:00 2001 From: "Lian, Cheng" Date: Sun, 29 Dec 2013 17:13:01 +0800 Subject: Response to Reynold's comments --- .../spark/mllib/classification/NaiveBayes.scala | 26 +++++++++++++--------- 1 file changed, 16 insertions(+), 10 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 2bc4c5afc0..d0f3a368e8 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -17,20 +17,22 @@ package org.apache.spark.mllib.classification +import org.jblas.DoubleMatrix + import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD import org.apache.spark.SparkContext._ -import org.jblas.DoubleMatrix /** * Model for Naive Bayes Classifiers. * - * @param weightPerLabel Weights computed for every label, which's dimension is C. - * @param weightMatrix Weights computed for every label and feature, which's dimension is CXD + * @param weightPerLabel Weights computed for every label, whose dimension is C. + * @param weightMatrix Weights computed for every label and feature, whose dimension is CXD */ -class NaiveBayesModel(val weightPerLabel: Array[Double], - val weightMatrix: Array[Array[Double]]) +class NaiveBayesModel( + @transient val weightPerLabel: Array[Double], + @transient val weightMatrix: Array[Array[Double]]) extends ClassificationModel with Serializable { // Create a column vector that can be used for predictions @@ -50,7 +52,12 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter extends Serializable with Logging { private def vectorAdd(v1: Array[Double], v2: Array[Double]) = { - v1.zip(v2).map(pair => pair._1 + pair._2) + var i = 0 + while (i < v1.length) { + v1(i) += v2(i) + i += 1 + } + v1 } /** @@ -79,8 +86,8 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter // considerably expensive. val N = collected.values.map(_._1).sum val logDenom = math.log(N + C * lambda) - val weightPerLabel = Array.fill[Double](C)(0) - val weightMatrix = Array.fill[Array[Double]](C)(null) + val weightPerLabel = new Array[Double](C) + val weightMatrix = new Array[Array[Double]](C) for ((label, (_, labelWeight, weights)) <- collected) { weightPerLabel(label) = labelWeight - logDenom @@ -100,8 +107,7 @@ object NaiveBayes { * @param input RDD of (label, array of features) pairs. * @param lambda smooth parameter */ - def train(C: Int, D: Int, input: RDD[LabeledPoint], - lambda: Double = 1.0): NaiveBayesModel = { + def train(C: Int, D: Int, input: RDD[LabeledPoint], lambda: Double = 1.0): NaiveBayesModel = { new NaiveBayes(lambda).run(C, D, input) } } -- cgit v1.2.3 From 6d0e2e86dfbca88abc847d3babac2d1f82d61aaf Mon Sep 17 00:00:00 2001 From: "Lian, Cheng" Date: Mon, 30 Dec 2013 22:46:32 +0800 Subject: Response to comments from Reynold, Ameet and Evan * Arguments renamed according to Ameet's suggestion * Using DoubleMatrix instead of Array[Double] in computation * Removed arguments C (kinds of label) and D (dimension of feature vector) from NaiveBayes.train() * Replaced reduceByKey with foldByKey to avoid modifying original input data --- .../spark/mllib/classification/NaiveBayes.scala | 120 +++++++++++++-------- .../mllib/classification/NaiveBayesSuite.scala | 32 +++--- 2 files changed, 90 insertions(+), 62 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index d0f3a368e8..9fd1adddb0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -27,87 +27,115 @@ import org.apache.spark.SparkContext._ /** * Model for Naive Bayes Classifiers. * - * @param weightPerLabel Weights computed for every label, whose dimension is C. - * @param weightMatrix Weights computed for every label and feature, whose dimension is CXD + * @param pi Log of class priors, whose dimension is C. + * @param theta Log of class conditional probabilities, whose dimension is CXD. */ -class NaiveBayesModel( - @transient val weightPerLabel: Array[Double], - @transient val weightMatrix: Array[Array[Double]]) +class NaiveBayesModel(pi: Array[Double], theta: Array[Array[Double]]) extends ClassificationModel with Serializable { // Create a column vector that can be used for predictions - private val _weightPerLabel = new DoubleMatrix(weightPerLabel.length, 1, weightPerLabel:_*) - private val _weightMatrix = new DoubleMatrix(weightMatrix) + private val _pi = new DoubleMatrix(pi.length, 1, pi: _*) + private val _theta = new DoubleMatrix(theta) def predict(testData: RDD[Array[Double]]): RDD[Double] = testData.map(predict) def predict(testData: Array[Double]): Double = { val dataMatrix = new DoubleMatrix(testData.length, 1, testData: _*) - val result = _weightPerLabel.add(_weightMatrix.mmul(dataMatrix)) + val result = _pi.add(_theta.mmul(dataMatrix)) result.argmax() } } -class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter +/** + * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. + * + * @param lambda The smooth parameter + */ +class NaiveBayes private (val lambda: Double = 1.0) extends Serializable with Logging { - private def vectorAdd(v1: Array[Double], v2: Array[Double]) = { - var i = 0 - while (i < v1.length) { - v1(i) += v2(i) - i += 1 - } - v1 - } - /** - * Run the algorithm with the configured parameters on an input - * RDD of LabeledPoint entries. + * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries. * - * @param C kind of labels, labels are continuous integers and the maximal label is C-1 - * @param D dimension of feature vectors * @param data RDD of (label, array of features) pairs. */ - def run(C: Int, D: Int, data: RDD[LabeledPoint]) = { - val countsAndSummedFeatures = data.map { case LabeledPoint(label, features) => - label.toInt -> (1, features) - }.reduceByKey { (lhs, rhs) => - (lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2)) + def run(data: RDD[LabeledPoint]) = { + // Prepares input data, the shape of resulted RDD is: + // + // label: Int -> (count: Int, features: DoubleMatrix) + // + // The added count field is initialized to 1 to enable the following `foldByKey` transformation. + val mappedData = data.map { case LabeledPoint(label, features) => + label.toInt -> (1, new DoubleMatrix(features.length, 1, features: _*)) + } + + // Gets a map from labels to their corresponding sample point counts and summed feature vectors. + // Shape of resulted RDD is: + // + // label: Int -> (count: Int, summedFeatureVector: DoubleMatrix) + // + // Two tricky parts worth explaining: + // + // 1. Feature vectors are summed with the inplace jblas matrix addition operation, thus we + // chose `foldByKey` instead of `reduceByKey` to avoid modifying original input data. + // + // 2. The zero value passed to `foldByKey` contains a `null` rather than a zero vector because + // the dimension of the feature vector is unknown. Calling `data.first.length` to get the + // dimension is not preferable since it requires an expensive RDD action. + val countsAndSummedFeatures = mappedData.foldByKey((0, null)) { (lhs, rhs) => + if (lhs._1 == 0) { + (rhs._1, new DoubleMatrix().copy(rhs._2)) + } else { + (lhs._1 + rhs._1, lhs._2.addi(rhs._2)) + } } val collected = countsAndSummedFeatures.mapValues { case (count, summedFeatureVector) => - val labelWeight = math.log(count + lambda) - val logDenom = math.log(summedFeatureVector.sum + D * lambda) - val weights = summedFeatureVector.map(w => math.log(w + lambda) - logDenom) - (count, labelWeight, weights) + val p = math.log(count + lambda) + val logDenom = math.log(summedFeatureVector.sum + summedFeatureVector.length * lambda) + val t = summedFeatureVector + var i = 0 + while (i < t.length) { + t.put(i, math.log(t.get(i) + lambda) - logDenom) + i += 1 + } + (count, p, t) }.collectAsMap() - // We can simply call `data.count` to get `N`, but that triggers another RDD action, which is - // considerably expensive. + // Total sample count. Calling `data.count` to get `N` is not preferable since it triggers + // an expensive RDD action val N = collected.values.map(_._1).sum + + // Kinds of label. + val C = collected.size + val logDenom = math.log(N + C * lambda) - val weightPerLabel = new Array[Double](C) - val weightMatrix = new Array[Array[Double]](C) + val pi = new Array[Double](C) + val theta = new Array[Array[Double]](C) - for ((label, (_, labelWeight, weights)) <- collected) { - weightPerLabel(label) = labelWeight - logDenom - weightMatrix(label) = weights + for ((label, (_, p, t)) <- collected) { + pi(label) = p - logDenom + theta(label) = t.toArray } - new NaiveBayesModel(weightPerLabel, weightMatrix) + new NaiveBayesModel(pi, theta) } } object NaiveBayes { /** - * Train a naive bayes model given an RDD of (label, features) pairs. + * Trains a Naive Bayes model given an RDD of `(label, features)` pairs. + * + * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of + * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for + * document classification. By making every vector a 0-1 vector. it can also be used as + * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]). * - * @param C kind of labels, the maximal label is C-1 - * @param D dimension of feature vectors - * @param input RDD of (label, array of features) pairs. - * @param lambda smooth parameter + * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency + * vector or a count vector. + * @param lambda The smooth parameter */ - def train(C: Int, D: Int, input: RDD[LabeledPoint], lambda: Double = 1.0): NaiveBayesModel = { - new NaiveBayes(lambda).run(C, D, input) + def train(input: RDD[LabeledPoint], lambda: Double = 1.0): NaiveBayesModel = { + new NaiveBayes(lambda).run(input) } } diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index a2821347a7..18575f410c 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -38,20 +38,20 @@ object NaiveBayesSuite { // Generate input of the form Y = (weightMatrix*x).argmax() def generateNaiveBayesInput( - weightPerLabel: Array[Double], // 1XC - weightsMatrix: Array[Array[Double]], // CXD + pi: Array[Double], // 1XC + theta: Array[Array[Double]], // CXD nPoints: Int, seed: Int): Seq[LabeledPoint] = { - val D = weightsMatrix(0).length + val D = theta(0).length val rnd = new Random(seed) - val _weightPerLabel = weightPerLabel.map(math.pow(math.E, _)) - val _weightMatrix = weightsMatrix.map(row => row.map(math.pow(math.E, _))) + val _pi = pi.map(math.pow(math.E, _)) + val _theta = theta.map(row => row.map(math.pow(math.E, _))) for (i <- 0 until nPoints) yield { - val y = calcLabel(rnd.nextDouble(), _weightPerLabel) + val y = calcLabel(rnd.nextDouble(), _pi) val xi = Array.tabulate[Double](D) { j => - if (rnd.nextDouble() < _weightMatrix(y)(j)) 1 else 0 + if (rnd.nextDouble() < _theta(y)(j)) 1 else 0 } LabeledPoint(y, xi) @@ -83,20 +83,20 @@ class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll { test("Naive Bayes") { val nPoints = 10000 - val weightPerLabel = Array(math.log(0.5), math.log(0.3), math.log(0.2)) - val weightsMatrix = Array( - Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0 - Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1 - Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2 - ) + val pi = Array(0.5, 0.3, 0.2).map(math.log) + val theta = Array( + Array(0.91, 0.03, 0.03, 0.03), // label 0 + Array(0.03, 0.91, 0.03, 0.03), // label 1 + Array(0.03, 0.03, 0.91, 0.03) // label 2 + ).map(_.map(math.log)) - val testData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42) + val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42) val testRDD = sc.parallelize(testData, 2) testRDD.cache() - val model = NaiveBayes.train(3, 4, testRDD) + val model = NaiveBayes.train(testRDD) - val validationData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 17) + val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17) val validationRDD = sc.parallelize(validationData, 2) // Test prediction on RDD. -- cgit v1.2.3 From dd6033e6853e32e9de2c910797c7fbc0072e7491 Mon Sep 17 00:00:00 2001 From: "Lian, Cheng" Date: Thu, 2 Jan 2014 01:38:24 +0800 Subject: Aggregated all sample points to driver without any shuffle --- .../spark/mllib/classification/NaiveBayes.scala | 76 ++++++++-------------- .../mllib/classification/NaiveBayesSuite.scala | 8 +-- 2 files changed, 31 insertions(+), 53 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index 9fd1adddb0..524300d6ae 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -17,12 +17,13 @@ package org.apache.spark.mllib.classification +import scala.collection.mutable + import org.jblas.DoubleMatrix import org.apache.spark.Logging import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.rdd.RDD -import org.apache.spark.SparkContext._ /** * Model for Naive Bayes Classifiers. @@ -60,62 +61,39 @@ class NaiveBayes private (val lambda: Double = 1.0) * @param data RDD of (label, array of features) pairs. */ def run(data: RDD[LabeledPoint]) = { - // Prepares input data, the shape of resulted RDD is: - // - // label: Int -> (count: Int, features: DoubleMatrix) - // - // The added count field is initialized to 1 to enable the following `foldByKey` transformation. - val mappedData = data.map { case LabeledPoint(label, features) => - label.toInt -> (1, new DoubleMatrix(features.length, 1, features: _*)) - } - - // Gets a map from labels to their corresponding sample point counts and summed feature vectors. - // Shape of resulted RDD is: - // - // label: Int -> (count: Int, summedFeatureVector: DoubleMatrix) + // Aggregates all sample points to driver side to get sample count and summed feature vector + // for each label. The shape of `zeroCombiner` & `aggregated` is: // - // Two tricky parts worth explaining: - // - // 1. Feature vectors are summed with the inplace jblas matrix addition operation, thus we - // chose `foldByKey` instead of `reduceByKey` to avoid modifying original input data. - // - // 2. The zero value passed to `foldByKey` contains a `null` rather than a zero vector because - // the dimension of the feature vector is unknown. Calling `data.first.length` to get the - // dimension is not preferable since it requires an expensive RDD action. - val countsAndSummedFeatures = mappedData.foldByKey((0, null)) { (lhs, rhs) => - if (lhs._1 == 0) { - (rhs._1, new DoubleMatrix().copy(rhs._2)) - } else { - (lhs._1 + rhs._1, lhs._2.addi(rhs._2)) + // label: Int -> (count: Int, featuresSum: DoubleMatrix) + val zeroCombiner = mutable.Map.empty[Int, (Int, DoubleMatrix)] + val aggregated = data.aggregate(zeroCombiner)({ (combiner, point) => + point match { + case LabeledPoint(label, features) => + val (count, featuresSum) = combiner.getOrElse(label.toInt, (0, DoubleMatrix.zeros(1))) + val fs = new DoubleMatrix(features.length, 1, features: _*) + combiner += label.toInt -> (count + 1, featuresSum.addi(fs)) } - } - - val collected = countsAndSummedFeatures.mapValues { case (count, summedFeatureVector) => - val p = math.log(count + lambda) - val logDenom = math.log(summedFeatureVector.sum + summedFeatureVector.length * lambda) - val t = summedFeatureVector - var i = 0 - while (i < t.length) { - t.put(i, math.log(t.get(i) + lambda) - logDenom) - i += 1 + }, { (lhs, rhs) => + for ((label, (c, fs)) <- rhs) { + val (count, featuresSum) = lhs.getOrElse(label, (0, DoubleMatrix.zeros(1))) + lhs(label) = (count + c, featuresSum.addi(fs)) } - (count, p, t) - }.collectAsMap() - - // Total sample count. Calling `data.count` to get `N` is not preferable since it triggers - // an expensive RDD action - val N = collected.values.map(_._1).sum + lhs + }) - // Kinds of label. - val C = collected.size + // Kinds of label + val C = aggregated.size + // Total sample count + val N = aggregated.values.map(_._1).sum - val logDenom = math.log(N + C * lambda) val pi = new Array[Double](C) val theta = new Array[Array[Double]](C) + val piLogDenom = math.log(N + C * lambda) - for ((label, (_, p, t)) <- collected) { - pi(label) = p - logDenom - theta(label) = t.toArray + for ((label, (count, fs)) <- aggregated) { + val thetaLogDenom = math.log(fs.sum() + fs.length * lambda) + pi(label) = math.log(count + lambda) - piLogDenom + theta(label) = fs.toArray.map(f => math.log(f + lambda) - thetaLogDenom) } new NaiveBayesModel(pi, theta) diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala index 18575f410c..b615f76e66 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -27,16 +27,16 @@ import org.apache.spark.SparkContext object NaiveBayesSuite { - private def calcLabel(p: Double, weightPerLabel: Array[Double]): Int = { + private def calcLabel(p: Double, pi: Array[Double]): Int = { var sum = 0.0 - for (j <- 0 until weightPerLabel.length) { - sum += weightPerLabel(j) + for (j <- 0 until pi.length) { + sum += pi(j) if (p < sum) return j } -1 } - // Generate input of the form Y = (weightMatrix*x).argmax() + // Generate input of the form Y = (theta * x).argmax() def generateNaiveBayesInput( pi: Array[Double], // 1XC theta: Array[Array[Double]], // CXD -- cgit v1.2.3