aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLian, Cheng <rhythm.mail@gmail.com>2013-12-25 22:45:57 +0800
committerLian, Cheng <rhythm.mail@gmail.com>2013-12-25 22:45:57 +0800
commitc0337c5bbfd5126c64964a9fdefd2bef11727d87 (patch)
tree1095e73a4558f3988235608fa984c5292635184f
parent3bb714eaa3bdb7b7c33f6e5263c683f4c4beeddc (diff)
downloadspark-c0337c5bbfd5126c64964a9fdefd2bef11727d87.tar.gz
spark-c0337c5bbfd5126c64964a9fdefd2bef11727d87.tar.bz2
spark-c0337c5bbfd5126c64964a9fdefd2bef11727d87.zip
Let reduceByKey take care of local combine
Also refactored some heavy FP code to improve readability and reduce memory footprint.
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala43
1 file changed, 16 insertions, 27 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index edea5ed3e6..4c96b241eb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -17,8 +17,6 @@
package org.apache.spark.mllib.classification
-import scala.collection.mutable
-
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
@@ -63,39 +61,30 @@ class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter
* @param data RDD of (label, array of features) pairs.
*/
def run(C: Int, D: Int, data: RDD[LabeledPoint]) = {
- val locallyReduced = data.mapPartitions { iterator =>
- val localLabelCounts = mutable.Map.empty[Int, Int].withDefaultValue(0)
- val localSummedObservations =
- mutable.Map.empty[Int, Array[Double]].withDefaultValue(Array.fill(D)(0.0))
-
- for (LabeledPoint(label, features) <- iterator; i = label.toInt) {
- localLabelCounts(i) += 1
- localSummedObservations(i) = vectorAdd(localSummedObservations(i), features)
- }
-
- for ((label, count) <- localLabelCounts.toIterator) yield {
- label -> (count, localSummedObservations(label))
- }
- }
-
- val reduced = locallyReduced.reduceByKey { (lhs, rhs) =>
+ val countsAndSummedFeatures = data.map { case LabeledPoint(label, features) =>
+ label.toInt ->(1, features)
+ }.reduceByKey { (lhs, rhs) =>
(lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2))
}
- val collected = reduced.mapValues { case (count, summed) =>
+ val collected = countsAndSummedFeatures.mapValues { case (count, summedFeatureVector) =>
val labelWeight = math.log(count + lambda)
- val logDenom = math.log(summed.sum + D * lambda)
- val weights = summed.map(w => math.log(w + lambda) - logDenom)
+ val logDenom = math.log(summedFeatureVector.sum + D * lambda)
+ val weights = summedFeatureVector.map(w => math.log(w + lambda) - logDenom)
(count, labelWeight, weights)
}.collectAsMap()
- val weightPerLabel = {
- val N = collected.values.map(_._1).sum
- val logDenom = math.log(N + C * lambda)
- collected.mapValues(_._2 - logDenom).toArray.sortBy(_._1).map(_._2)
- }
+ // We can simply call `data.count` to get `N`, but that triggers another RDD action, which is
+ // considerably expensive.
+ val N = collected.values.map(_._1).sum
+ val logDenom = math.log(N + C * lambda)
+ val weightPerLabel = Array.fill[Double](C)(0)
+ val weightMatrix = Array.fill[Array[Double]](C)(null)
- val weightMatrix = collected.mapValues(_._3).toArray.sortBy(_._1).map(_._2)
+ for ((label, (_, labelWeight, weights)) <- collected) {
+ weightPerLabel(label) = labelWeight - logDenom
+ weightMatrix(label) = weights
+ }
new NaiveBayesModel(weightPerLabel, weightMatrix)
}