aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLian, Cheng <rhythm.mail@gmail.com>2013-12-25 17:15:38 +0800
committerLian, Cheng <rhythm.mail@gmail.com>2013-12-25 17:15:38 +0800
commit3bb714eaa3bdb7b7c33f6e5263c683f4c4beeddc (patch)
tree046e7aaf349acdb7689cefcb307e0a246e33b69f
parent3dc655aa19f678219e5d999fe97ab769567ffb1c (diff)
downloadspark-3bb714eaa3bdb7b7c33f6e5263c683f4c4beeddc.tar.gz
spark-3bb714eaa3bdb7b7c33f6e5263c683f4c4beeddc.tar.bz2
spark-3bb714eaa3bdb7b7c33f6e5263c683f4c4beeddc.zip
Refactored NaiveBayes
* Minimized shuffle output with mapPartitions. * Reduced RDD actions from 3 to 1.
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala60
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala9
2 files changed, 41 insertions, 28 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index f1b0e6ee6a..edea5ed3e6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -48,11 +48,12 @@ class NaiveBayesModel(val weightPerLabel: Array[Double],
}
}
-
-
class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter
extends Serializable with Logging {
+ private[this] def vectorAdd(v1: Array[Double], v2: Array[Double]) =
+ v1.zip(v2).map(pair => pair._1 + pair._2)
+
/**
* Run the algorithm with the configured parameters on an input
* RDD of LabeledPoint entries.
@@ -61,29 +62,42 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter
* @param D dimension of feature vectors
* @param data RDD of (label, array of features) pairs.
*/
- def run(C: Int, D: Int, data: RDD[LabeledPoint]): NaiveBayesModel = {
- val groupedData = data.map(p => p.label.toInt -> p.features).groupByKey()
-
- val countPerLabel = groupedData.mapValues(_.size)
- val logDenominator = math.log(data.count() + C * lambda)
- val weightPerLabel = countPerLabel.mapValues {
- count => math.log(count + lambda) - logDenominator
+ def run(C: Int, D: Int, data: RDD[LabeledPoint]) = {
+ val locallyReduced = data.mapPartitions { iterator =>
+ val localLabelCounts = mutable.Map.empty[Int, Int].withDefaultValue(0)
+ val localSummedObservations =
+ mutable.Map.empty[Int, Array[Double]].withDefaultValue(Array.fill(D)(0.0))
+
+ for (LabeledPoint(label, features) <- iterator; i = label.toInt) {
+ localLabelCounts(i) += 1
+ localSummedObservations(i) = vectorAdd(localSummedObservations(i), features)
+ }
+
+ for ((label, count) <- localLabelCounts.toIterator) yield {
+ label -> (count, localSummedObservations(label))
+ }
+ }
+
+ val reduced = locallyReduced.reduceByKey { (lhs, rhs) =>
+ (lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2))
}
-
- val summedObservations = groupedData.mapValues(_.reduce {
- (lhs, rhs) => lhs.zip(rhs).map(pair => pair._1 + pair._2)
- })
-
- val weightsMatrix = summedObservations.mapValues { weights =>
- val sum = weights.sum
- val logDenom = math.log(sum + D * lambda)
- weights.map(w => math.log(w + lambda) - logDenom)
+
+ val collected = reduced.mapValues { case (count, summed) =>
+ val labelWeight = math.log(count + lambda)
+ val logDenom = math.log(summed.sum + D * lambda)
+ val weights = summed.map(w => math.log(w + lambda) - logDenom)
+ (count, labelWeight, weights)
+ }.collectAsMap()
+
+ val weightPerLabel = {
+ val N = collected.values.map(_._1).sum
+ val logDenom = math.log(N + C * lambda)
+ collected.mapValues(_._2 - logDenom).toArray.sortBy(_._1).map(_._2)
}
-
- val labelWeights = weightPerLabel.collect().sorted.map(_._2)
- val weightsMat = weightsMatrix.collect().sortBy(_._1).map(_._2)
-
- new NaiveBayesModel(labelWeights, weightsMat)
+
+ val weightMatrix = collected.mapValues(_._3).toArray.sortBy(_._1).map(_._2)
+
+ new NaiveBayesModel(weightPerLabel, weightMatrix)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index d871ed3672..cc8d48a42b 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -1,6 +1,5 @@
package org.apache.spark.mllib.classification
-import scala.collection.JavaConversions._
import scala.util.Random
import org.scalatest.BeforeAndAfterAll
@@ -56,12 +55,12 @@ class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
}
def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
- val numOffPredictions = predictions.zip(input).count {
+ val numOfPredictions = predictions.zip(input).count {
case (prediction, expected) =>
prediction != expected.label
}
// At least 80% of the predictions should be on.
- assert(numOffPredictions < input.length / 5)
+ assert(numOfPredictions < input.length / 5)
}
test("Naive Bayes") {
@@ -71,8 +70,8 @@ class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
val weightsMatrix = Array(
Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0
Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1
- Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2
- )
+ Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2
+ )
val testData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42)
val testRDD = sc.parallelize(testData, 2)