From 3bb714eaa3bdb7b7c33f6e5263c683f4c4beeddc Mon Sep 17 00:00:00 2001
From: "Lian, Cheng"
Date: Wed, 25 Dec 2013 17:15:38 +0800
Subject: Refactored NaiveBayes

* Minimized shuffle output with mapPartitions.
* Reduced RDD actions from 3 to 1.
---
 .../spark/mllib/classification/NaiveBayes.scala    | 60 +++++++++++++---------
 .../mllib/classification/NaiveBayesSuite.scala     |  9 ++--
 2 files changed, 41 insertions(+), 28 deletions(-)

(limited to 'mllib')

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index f1b0e6ee6a..edea5ed3e6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -48,11 +48,12 @@ class NaiveBayesModel(val weightPerLabel: Array[Double],
   }
 }
 
-
-
 class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter
   extends Serializable with Logging {
 
+  private[this] def vectorAdd(v1: Array[Double], v2: Array[Double]) =
+    v1.zip(v2).map(pair => pair._1 + pair._2)
+
   /**
    * Run the algorithm with the configured parameters on an input
    * RDD of LabeledPoint entries.
@@ -61,29 +62,42 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter
    * @param D dimension of feature vectors
    * @param data RDD of (label, array of features) pairs.
    */
-  def run(C: Int, D: Int, data: RDD[LabeledPoint]): NaiveBayesModel = {
-    val groupedData = data.map(p => p.label.toInt -> p.features).groupByKey()
-
-    val countPerLabel = groupedData.mapValues(_.size)
-    val logDenominator = math.log(data.count() + C * lambda)
-    val weightPerLabel = countPerLabel.mapValues {
-      count => math.log(count + lambda) - logDenominator
+  def run(C: Int, D: Int, data: RDD[LabeledPoint]) = {
+    val locallyReduced = data.mapPartitions { iterator =>
+      val localLabelCounts = mutable.Map.empty[Int, Int].withDefaultValue(0)
+      val localSummedObservations =
+        mutable.Map.empty[Int, Array[Double]].withDefaultValue(Array.fill(D)(0.0))
+
+      for (LabeledPoint(label, features) <- iterator; i = label.toInt) {
+        localLabelCounts(i) += 1
+        localSummedObservations(i) = vectorAdd(localSummedObservations(i), features)
+      }
+
+      for ((label, count) <- localLabelCounts.toIterator) yield {
+        label -> (count, localSummedObservations(label))
+      }
+    }
+
+    val reduced = locallyReduced.reduceByKey { (lhs, rhs) =>
+      (lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2))
     }
-
-    val summedObservations = groupedData.mapValues(_.reduce {
-      (lhs, rhs) => lhs.zip(rhs).map(pair => pair._1 + pair._2)
-    })
-
-    val weightsMatrix = summedObservations.mapValues { weights =>
-      val sum = weights.sum
-      val logDenom = math.log(sum + D * lambda)
-      weights.map(w => math.log(w + lambda) - logDenom)
+
+    val collected = reduced.mapValues { case (count, summed) =>
+      val labelWeight = math.log(count + lambda)
+      val logDenom = math.log(summed.sum + D * lambda)
+      val weights = summed.map(w => math.log(w + lambda) - logDenom)
+      (count, labelWeight, weights)
+    }.collectAsMap()
+
+    val weightPerLabel = {
+      val N = collected.values.map(_._1).sum
+      val logDenom = math.log(N + C * lambda)
+      collected.mapValues(_._2 - logDenom).toArray.sortBy(_._1).map(_._2)
     }
-
-    val labelWeights = weightPerLabel.collect().sorted.map(_._2)
-    val weightsMat = weightsMatrix.collect().sortBy(_._1).map(_._2)
-
-    new NaiveBayesModel(labelWeights, weightsMat)
+
+    val weightMatrix = collected.mapValues(_._3).toArray.sortBy(_._1).map(_._2)
+
+    new NaiveBayesModel(weightPerLabel, weightMatrix)
   }
 }
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index d871ed3672..cc8d48a42b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -1,6 +1,5 @@
 package org.apache.spark.mllib.classification
 
-import scala.collection.JavaConversions._
 import scala.util.Random
 
 import org.scalatest.BeforeAndAfterAll
@@ -56,12 +55,12 @@ class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
   }
 
   def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
-    val numOffPredictions = predictions.zip(input).count {
+    val numOfPredictions = predictions.zip(input).count {
      case (prediction, expected) =>
        prediction != expected.label
     }
     // At least 80% of the predictions should be on.
-    assert(numOffPredictions < input.length / 5)
+    assert(numOfPredictions < input.length / 5)
   }
 
   test("Naive Bayes") {
@@ -71,8 +70,8 @@ class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
     val weightsMatrix = Array(
       Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0
       Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1
-      Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03))  // label 2
-    )
+      Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2
+    )
     val testData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42)
     val testRDD = sc.parallelize(testData, 2)
-- 
cgit v1.2.3
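
The pattern this patch introduces — pre-aggregating a (count, feature-sum) pair per label inside each partition with mapPartitions, then merging the partial results with a single reduceByKey — is what keeps shuffle output proportional to the number of labels rather than the number of observations, and gathering everything through one collectAsMap() is what cuts the RDD actions from 3 to 1. The sketch below isolates that aggregation pattern as a self-contained program. It is illustrative only: the object name, the local LabeledPoint case class, and the local-mode SparkContext setup are stand-ins for this example, not the MLlib code touched by the commit.

// Standalone sketch of the two-stage aggregation used in the patch above.
// All names here (NaiveBayesSketch, LabeledPoint, vectorAdd) are local
// stand-ins for illustration, not the MLlib classes changed by the commit.
import scala.collection.mutable

import org.apache.spark.SparkContext
import org.apache.spark.SparkContext._ // brings reduceByKey into scope for pair RDDs

object NaiveBayesSketch {
  case class LabeledPoint(label: Double, features: Array[Double])

  // Element-wise vector addition, same shape as the patch's vectorAdd helper.
  def vectorAdd(v1: Array[Double], v2: Array[Double]): Array[Double] =
    v1.zip(v2).map { case (a, b) => a + b }

  def main(args: Array[String]) {
    val sc = new SparkContext("local", "NaiveBayesSketch")
    val D = 2 // feature dimension
    val data = sc.parallelize(Seq(
      LabeledPoint(0, Array(1.0, 0.0)),
      LabeledPoint(1, Array(0.0, 1.0)),
      LabeledPoint(0, Array(2.0, 0.0))), 2)

    // Stage 1: aggregate within each partition, so at most one
    // (count, featureSum) record per label leaves a partition. With
    // groupByKey, every observation would cross the shuffle instead.
    val locallyReduced = data.mapPartitions { iterator =>
      val counts = mutable.Map.empty[Int, Int].withDefaultValue(0)
      val sums = mutable.Map.empty[Int, Array[Double]]
      for (LabeledPoint(label, features) <- iterator) {
        val i = label.toInt
        counts(i) += 1
        sums(i) = vectorAdd(sums.getOrElse(i, Array.fill(D)(0.0)), features)
      }
      counts.iterator.map { case (label, count) => (label, (count, sums(label))) }
    }

    // Stage 2: a single shuffle merges the per-partition partial results.
    val reduced = locallyReduced.reduceByKey { (lhs, rhs) =>
      (lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2))
    }

    // One action materializes everything the model estimates need.
    for ((label, (count, sum)) <- reduced.collect().sortBy(_._1)) {
      println("label " + label + ": count = " + count +
        ", feature sums = " + sum.mkString("[", ", ", "]"))
    }
    sc.stop()
  }
}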