author     Lian, Cheng <rhythm.mail@gmail.com>  2014-01-02 01:38:24 +0800
committer  Lian, Cheng <rhythm.mail@gmail.com>  2014-01-02 01:38:24 +0800
commit     dd6033e6853e32e9de2c910797c7fbc0072e7491 (patch)
tree       c67660c36a5d2b1ac848b3da364522dca9149a37 /mllib
parent     6d0e2e86dfbca88abc847d3babac2d1f82d61aaf (diff)
Aggregated all sample points to the driver without any shuffle
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala       76
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala   8
2 files changed, 31 insertions, 53 deletions
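
The change below replaces a `foldByKey`/`collectAsMap` pipeline with a single `RDD.aggregate` action: `seqOp` folds each partition's points into a partition-local map, and `combOp` merges those maps on the driver, so no shuffle is triggered. What follows is a minimal standalone sketch of that pattern, not code from the commit; the object name `AggregateSketch`, the method `countsAndSums`, and the simplified `(Int, Double)` record type are illustrative assumptions.

    import scala.collection.mutable

    import org.apache.spark.rdd.RDD

    object AggregateSketch {
      // Per-key count and sum, computed with a single driver-side action.
      // Each task deserializes its own copy of the zero value, so the
      // in-place mutation below never races across partitions.
      def countsAndSums(data: RDD[(Int, Double)]): mutable.Map[Int, (Int, Double)] = {
        data.aggregate(mutable.Map.empty[Int, (Int, Double)])(
          { (combiner, kv) =>                // seqOp: runs inside each partition
            val (count, sum) = combiner.getOrElse(kv._1, (0, 0.0))
            combiner(kv._1) = (count + 1, sum + kv._2)
            combiner
          },
          { (lhs, rhs) =>                    // combOp: merges partition results on the driver
            for ((k, (c, s)) <- rhs) {
              val (count, sum) = lhs.getOrElse(k, (0, 0.0))
              lhs(k) = (count + c, sum + s)
            }
            lhs
          })
      }
    }

Compared with `foldByKey`, which shuffles (key, value) pairs across the cluster before reducing, this ships only one combined map per partition back to the driver; that trade-off makes sense here because the number of distinct labels is small.
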
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index 9fd1adddb0..524300d6ae 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -17,12 +17,13 @@
package org.apache.spark.mllib.classification
+import scala.collection.mutable
+
import org.jblas.DoubleMatrix
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.rdd.RDD
-import org.apache.spark.SparkContext._
/**
* Model for Naive Bayes Classifiers.
@@ -60,62 +61,39 @@ class NaiveBayes private (val lambda: Double = 1.0)
* @param data RDD of (label, array of features) pairs.
*/
def run(data: RDD[LabeledPoint]) = {
- // Prepares input data, the shape of resulted RDD is:
- //
- // label: Int -> (count: Int, features: DoubleMatrix)
- //
- // The added count field is initialized to 1 to enable the following `foldByKey` transformation.
- val mappedData = data.map { case LabeledPoint(label, features) =>
- label.toInt -> (1, new DoubleMatrix(features.length, 1, features: _*))
- }
-
- // Gets a map from labels to their corresponding sample point counts and summed feature vectors.
- // Shape of resulted RDD is:
- //
- // label: Int -> (count: Int, summedFeatureVector: DoubleMatrix)
+ // Aggregates all sample points to driver side to get sample count and summed feature vector
+ // for each label. The shape of `zeroCombiner` & `aggregated` is:
//
- // Two tricky parts worth explaining:
- //
- // 1. Feature vectors are summed with the inplace jblas matrix addition operation, thus we
- // chose `foldByKey` instead of `reduceByKey` to avoid modifying original input data.
- //
- // 2. The zero value passed to `foldByKey` contains a `null` rather than a zero vector because
- // the dimension of the feature vector is unknown. Calling `data.first.length` to get the
- // dimension is not preferable since it requires an expensive RDD action.
- val countsAndSummedFeatures = mappedData.foldByKey((0, null)) { (lhs, rhs) =>
- if (lhs._1 == 0) {
- (rhs._1, new DoubleMatrix().copy(rhs._2))
- } else {
- (lhs._1 + rhs._1, lhs._2.addi(rhs._2))
+ // label: Int -> (count: Int, featuresSum: DoubleMatrix)
+ val zeroCombiner = mutable.Map.empty[Int, (Int, DoubleMatrix)]
+ val aggregated = data.aggregate(zeroCombiner)({ (combiner, point) =>
+ point match {
+ case LabeledPoint(label, features) =>
+ val (count, featuresSum) = combiner.getOrElse(label.toInt, (0, DoubleMatrix.zeros(1)))
+ val fs = new DoubleMatrix(features.length, 1, features: _*)
+ combiner += label.toInt -> (count + 1, featuresSum.addi(fs))
}
- }
-
- val collected = countsAndSummedFeatures.mapValues { case (count, summedFeatureVector) =>
- val p = math.log(count + lambda)
- val logDenom = math.log(summedFeatureVector.sum + summedFeatureVector.length * lambda)
- val t = summedFeatureVector
- var i = 0
- while (i < t.length) {
- t.put(i, math.log(t.get(i) + lambda) - logDenom)
- i += 1
+ }, { (lhs, rhs) =>
+ for ((label, (c, fs)) <- rhs) {
+ val (count, featuresSum) = lhs.getOrElse(label, (0, DoubleMatrix.zeros(1)))
+ lhs(label) = (count + c, featuresSum.addi(fs))
}
- (count, p, t)
- }.collectAsMap()
-
- // Total sample count. Calling `data.count` to get `N` is not preferable since it triggers
- // an expensive RDD action
- val N = collected.values.map(_._1).sum
+ lhs
+ })
- // Kinds of label.
- val C = collected.size
+ // Kinds of label
+ val C = aggregated.size
+ // Total sample count
+ val N = aggregated.values.map(_._1).sum
- val logDenom = math.log(N + C * lambda)
val pi = new Array[Double](C)
val theta = new Array[Array[Double]](C)
+ val piLogDenom = math.log(N + C * lambda)
- for ((label, (_, p, t)) <- collected) {
- pi(label) = p - logDenom
- theta(label) = t.toArray
+ for ((label, (count, fs)) <- aggregated) {
+ val thetaLogDenom = math.log(fs.sum() + fs.length * lambda)
+ pi(label) = math.log(count + lambda) - piLogDenom
+ theta(label) = fs.toArray.map(f => math.log(f + lambda) - thetaLogDenom)
}
new NaiveBayesModel(pi, theta)
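
For reference, the loop above computes Laplace-smoothed log estimates. Reconstructing the formulas from the code (this recap is an editorial addition, not text from the commit): with N total samples, C labels, D features, N_c the per-label `count`, F_{c,d} the entries of `fs`, and λ the `lambda` smoothing parameter,

    \pi_c = \log\frac{N_c + \lambda}{N + C\lambda},
    \qquad
    \theta_{c,d} = \log\frac{F_{c,d} + \lambda}{\sum_{d'=1}^{D} F_{c,d'} + D\lambda}

which matches `pi(label) = math.log(count + lambda) - piLogDenom` and the `theta` mapping line for line.
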
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index 18575f410c..b615f76e66 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -27,16 +27,16 @@ import org.apache.spark.SparkContext
object NaiveBayesSuite {
- private def calcLabel(p: Double, weightPerLabel: Array[Double]): Int = {
+ private def calcLabel(p: Double, pi: Array[Double]): Int = {
var sum = 0.0
- for (j <- 0 until weightPerLabel.length) {
- sum += weightPerLabel(j)
+ for (j <- 0 until pi.length) {
+ sum += pi(j)
if (p < sum) return j
}
-1
}
- // Generate input of the form Y = (weightMatrix*x).argmax()
+ // Generate input of the form Y = (theta * x).argmax()
def generateNaiveBayesInput(
pi: Array[Double], // 1XC
theta: Array[Array[Double]], // CXD