aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLian, Cheng <rhythm.mail@gmail.com>2013-12-30 22:46:32 +0800
committerLian, Cheng <rhythm.mail@gmail.com>2013-12-30 22:46:32 +0800
commit6d0e2e86dfbca88abc847d3babac2d1f82d61aaf (patch)
tree982302a5b1b2485ad08b992d9468e2b7c9eb4cc9
parentf150b6e76c56ed6f604e6dbda7bce6b6278929fb (diff)
downloadspark-6d0e2e86dfbca88abc847d3babac2d1f82d61aaf.tar.gz
spark-6d0e2e86dfbca88abc847d3babac2d1f82d61aaf.tar.bz2
spark-6d0e2e86dfbca88abc847d3babac2d1f82d61aaf.zip
Response to comments from Reynold, Ameet and Evan
* Arguments renamed according to Ameet's suggestion
* Using DoubleMatrix instead of Array[Double] in computation
* Removed arguments C (kinds of label) and D (dimension of feature vector) from NaiveBayes.train()
* Replaced reduceByKey with foldByKey to avoid modifying original input data
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala120
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala32
2 files changed, 90 insertions, 62 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index d0f3a368e8..9fd1adddb0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -27,87 +27,115 @@ import org.apache.spark.SparkContext._
/**
* Model for Naive Bayes Classifiers.
*
- * @param weightPerLabel Weights computed for every label, whose dimension is C.
- * @param weightMatrix Weights computed for every label and feature, whose dimension is CXD
+ * @param pi Log of class priors, whose dimension is C.
+ * @param theta Log of class conditional probabilities, whose dimension is CXD.
*/
-class NaiveBayesModel(
- @transient val weightPerLabel: Array[Double],
- @transient val weightMatrix: Array[Array[Double]])
+class NaiveBayesModel(pi: Array[Double], theta: Array[Array[Double]])
extends ClassificationModel with Serializable {
// Create a column vector that can be used for predictions
- private val _weightPerLabel = new DoubleMatrix(weightPerLabel.length, 1, weightPerLabel:_*)
- private val _weightMatrix = new DoubleMatrix(weightMatrix)
+ private val _pi = new DoubleMatrix(pi.length, 1, pi: _*)
+ private val _theta = new DoubleMatrix(theta)
def predict(testData: RDD[Array[Double]]): RDD[Double] = testData.map(predict)
def predict(testData: Array[Double]): Double = {
val dataMatrix = new DoubleMatrix(testData.length, 1, testData: _*)
- val result = _weightPerLabel.add(_weightMatrix.mmul(dataMatrix))
+ val result = _pi.add(_theta.mmul(dataMatrix))
result.argmax()
}
}
-class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter
+/**
+ * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
+ *
+ * @param lambda The smoothing parameter
+ */
+class NaiveBayes private (val lambda: Double = 1.0)
extends Serializable with Logging {
- private def vectorAdd(v1: Array[Double], v2: Array[Double]) = {
- var i = 0
- while (i < v1.length) {
- v1(i) += v2(i)
- i += 1
- }
- v1
- }
-
/**
- * Run the algorithm with the configured parameters on an input
- * RDD of LabeledPoint entries.
+ * Run the algorithm with the configured parameters on an input RDD of LabeledPoint entries.
*
- * @param C kind of labels, labels are continuous integers and the maximal label is C-1
- * @param D dimension of feature vectors
* @param data RDD of (label, array of features) pairs.
*/
- def run(C: Int, D: Int, data: RDD[LabeledPoint]) = {
- val countsAndSummedFeatures = data.map { case LabeledPoint(label, features) =>
- label.toInt -> (1, features)
- }.reduceByKey { (lhs, rhs) =>
- (lhs._1 + rhs._1, vectorAdd(lhs._2, rhs._2))
+ def run(data: RDD[LabeledPoint]) = {
+ // Prepares input data, the shape of the resulting RDD is:
+ //
+ // label: Int -> (count: Int, features: DoubleMatrix)
+ //
+ // The added count field is initialized to 1 to enable the following `foldByKey` transformation.
+ val mappedData = data.map { case LabeledPoint(label, features) =>
+ label.toInt -> (1, new DoubleMatrix(features.length, 1, features: _*))
+ }
+
+ // Gets a map from labels to their corresponding sample point counts and summed feature vectors.
+ // Shape of the resulting RDD is:
+ //
+ // label: Int -> (count: Int, summedFeatureVector: DoubleMatrix)
+ //
+ // Two tricky parts worth explaining:
+ //
+ // 1. Feature vectors are summed with the in-place jblas matrix addition operation, so we
+ // chose `foldByKey` instead of `reduceByKey` to avoid modifying the original input data.
+ //
+ // 2. The zero value passed to `foldByKey` contains a `null` rather than a zero vector because
+ // the dimension of the feature vector is unknown. Calling `data.first.length` to get the
+ // dimension is not preferable since it requires an expensive RDD action.
+ val countsAndSummedFeatures = mappedData.foldByKey((0, null)) { (lhs, rhs) =>
+ if (lhs._1 == 0) {
+ (rhs._1, new DoubleMatrix().copy(rhs._2))
+ } else {
+ (lhs._1 + rhs._1, lhs._2.addi(rhs._2))
+ }
}
val collected = countsAndSummedFeatures.mapValues { case (count, summedFeatureVector) =>
- val labelWeight = math.log(count + lambda)
- val logDenom = math.log(summedFeatureVector.sum + D * lambda)
- val weights = summedFeatureVector.map(w => math.log(w + lambda) - logDenom)
- (count, labelWeight, weights)
+ val p = math.log(count + lambda)
+ val logDenom = math.log(summedFeatureVector.sum + summedFeatureVector.length * lambda)
+ val t = summedFeatureVector
+ var i = 0
+ while (i < t.length) {
+ t.put(i, math.log(t.get(i) + lambda) - logDenom)
+ i += 1
+ }
+ (count, p, t)
}.collectAsMap()
- // We can simply call `data.count` to get `N`, but that triggers another RDD action, which is
- // considerably expensive.
+ // Total sample count. Calling `data.count` to get `N` is not preferable since it would
+ // trigger another expensive RDD action.
val N = collected.values.map(_._1).sum
+
+ // Number of distinct labels.
+ val C = collected.size
+
val logDenom = math.log(N + C * lambda)
- val weightPerLabel = new Array[Double](C)
- val weightMatrix = new Array[Array[Double]](C)
+ val pi = new Array[Double](C)
+ val theta = new Array[Array[Double]](C)
- for ((label, (_, labelWeight, weights)) <- collected) {
- weightPerLabel(label) = labelWeight - logDenom
- weightMatrix(label) = weights
+ for ((label, (_, p, t)) <- collected) {
+ pi(label) = p - logDenom
+ theta(label) = t.toArray
}
- new NaiveBayesModel(weightPerLabel, weightMatrix)
+ new NaiveBayesModel(pi, theta)
}
}
object NaiveBayes {
/**
- * Train a naive bayes model given an RDD of (label, features) pairs.
+ * Trains a Naive Bayes model given an RDD of `(label, features)` pairs.
+ *
+ * This is the Multinomial NB ([[http://tinyurl.com/lsdw6p]]) which can handle all kinds of
+ * discrete data. For example, by converting documents into TF-IDF vectors, it can be used for
+ * document classification. By making every vector a 0-1 vector, it can also be used as
+ * Bernoulli NB ([[http://tinyurl.com/p7c96j6]]).
*
- * @param C kind of labels, the maximal label is C-1
- * @param D dimension of feature vectors
- * @param input RDD of (label, array of features) pairs.
- * @param lambda smooth parameter
+ * @param input RDD of `(label, array of features)` pairs. Every vector should be a frequency
+ * vector or a count vector.
+ * @param lambda The smoothing parameter
*/
- def train(C: Int, D: Int, input: RDD[LabeledPoint], lambda: Double = 1.0): NaiveBayesModel = {
- new NaiveBayes(lambda).run(C, D, input)
+ def train(input: RDD[LabeledPoint], lambda: Double = 1.0): NaiveBayesModel = {
+ new NaiveBayes(lambda).run(input)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index a2821347a7..18575f410c 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -38,20 +38,20 @@ object NaiveBayesSuite {
// Generate input of the form Y = (weightMatrix*x).argmax()
def generateNaiveBayesInput(
- weightPerLabel: Array[Double], // 1XC
- weightsMatrix: Array[Array[Double]], // CXD
+ pi: Array[Double], // 1XC
+ theta: Array[Array[Double]], // CXD
nPoints: Int,
seed: Int): Seq[LabeledPoint] = {
- val D = weightsMatrix(0).length
+ val D = theta(0).length
val rnd = new Random(seed)
- val _weightPerLabel = weightPerLabel.map(math.pow(math.E, _))
- val _weightMatrix = weightsMatrix.map(row => row.map(math.pow(math.E, _)))
+ val _pi = pi.map(math.pow(math.E, _))
+ val _theta = theta.map(row => row.map(math.pow(math.E, _)))
for (i <- 0 until nPoints) yield {
- val y = calcLabel(rnd.nextDouble(), _weightPerLabel)
+ val y = calcLabel(rnd.nextDouble(), _pi)
val xi = Array.tabulate[Double](D) { j =>
- if (rnd.nextDouble() < _weightMatrix(y)(j)) 1 else 0
+ if (rnd.nextDouble() < _theta(y)(j)) 1 else 0
}
LabeledPoint(y, xi)
@@ -83,20 +83,20 @@ class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
test("Naive Bayes") {
val nPoints = 10000
- val weightPerLabel = Array(math.log(0.5), math.log(0.3), math.log(0.2))
- val weightsMatrix = Array(
- Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0
- Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1
- Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2
- )
+ val pi = Array(0.5, 0.3, 0.2).map(math.log)
+ val theta = Array(
+ Array(0.91, 0.03, 0.03, 0.03), // label 0
+ Array(0.03, 0.91, 0.03, 0.03), // label 1
+ Array(0.03, 0.03, 0.91, 0.03) // label 2
+ ).map(_.map(math.log))
- val testData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42)
+ val testData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 42)
val testRDD = sc.parallelize(testData, 2)
testRDD.cache()
- val model = NaiveBayes.train(3, 4, testRDD)
+ val model = NaiveBayes.train(testRDD)
- val validationData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 17)
+ val validationData = NaiveBayesSuite.generateNaiveBayesInput(pi, theta, nPoints, 17)
val validationRDD = sc.parallelize(validationData, 2)
// Test prediction on RDD.