author    Frank Dai <soulmachine@gmail.com>    2013-12-25 16:50:42 +0800
committer Frank Dai <soulmachine@gmail.com>    2013-12-25 16:50:42 +0800
commit    3dc655aa19f678219e5d999fe97ab769567ffb1c (patch)
tree      4daa85039ddcd82d0e262a4786d4c63c2ca4b747
parent    85a344b4f0cd149c6e6f06f8b942c34146b302be (diff)
standard Naive Bayes classifier
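
Implements a standard multinomial Naive Bayes classifier with additive
(Laplace) smoothing: training computes smoothed log-priors per label and
smoothed log-conditional probabilities per label/feature, and prediction
returns the label with the highest log-posterior score.

A minimal usage sketch (assuming an existing SparkContext `sc`; the data
below is illustrative only):

    import org.apache.spark.mllib.classification.NaiveBayes
    import org.apache.spark.mllib.regression.LabeledPoint

    val training = sc.parallelize(Seq(
      LabeledPoint(0.0, Array(1.0, 0.0, 0.0, 0.0)),
      LabeledPoint(1.0, Array(0.0, 1.0, 0.0, 0.0)),
      LabeledPoint(2.0, Array(0.0, 0.0, 1.0, 0.0))))

    // C = 3 label classes, D = 4 features, default lambda = 1.0.
    val model = NaiveBayes.train(3, 4, training)
    val predicted = model.predict(Array(1.0, 0.0, 0.0, 0.0))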
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala       103
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala   92
2 files changed, 195 insertions(+), 0 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
new file mode 100644
index 0000000000..f1b0e6ee6a
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.classification
+
+import scala.collection.mutable
+
+import org.apache.spark.Logging
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+import org.jblas.DoubleMatrix
+
+/**
+ * Model for Naive Bayes classifiers.
+ *
+ * @param weightPerLabel Weights computed for every label, whose dimension is C.
+ * @param weightMatrix Weights computed for every label and feature, whose dimension is C-by-D.
+ */
+class NaiveBayesModel(val weightPerLabel: Array[Double],
+ val weightMatrix: Array[Array[Double]])
+ extends ClassificationModel with Serializable {
+
+ // Create a column vector that can be used for predictions
+ private val _weightPerLabel = new DoubleMatrix(weightPerLabel.length, 1, weightPerLabel:_*)
+ private val _weightMatrix = new DoubleMatrix(weightMatrix)
+
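+  // Prediction computes weightPerLabel + weightMatrix * x, i.e. the
+  // log-posterior of each label up to a shared constant, and returns the
+  // index (= label) of the largest entry.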
+ def predict(testData: RDD[Array[Double]]): RDD[Double] = testData.map(predict)
+
+ def predict(testData: Array[Double]): Double = {
+ val dataMatrix = new DoubleMatrix(testData.length, 1, testData: _*)
+ val result = _weightPerLabel.add(_weightMatrix.mmul(dataMatrix))
+ result.argmax()
+ }
+}
+
+/**
+ * Trains a Naive Bayes model.
+ *
+ * @param lambda The smoothing parameter.
+ */
+class NaiveBayes private (val lambda: Double = 1.0)
+ extends Serializable with Logging {
+
+  /**
+   * Run the algorithm with the configured parameters on an input
+   * RDD of LabeledPoint entries.
+   *
+   * @param C Number of label classes; labels are consecutive integers starting
+   *          from 0, so the maximal label is C-1.
+   * @param D Dimension of the feature vectors.
+   * @param data RDD of (label, array of features) pairs.
+   */
+ def run(C: Int, D: Int, data: RDD[LabeledPoint]): NaiveBayesModel = {
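+    // Group the feature vectors by label so that per-label counts and
+    // per-label feature sums can be computed below.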
+ val groupedData = data.map(p => p.label.toInt -> p.features).groupByKey()
+
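+    // Smoothed log-prior of each label:
+    // log P(c) = log((count(c) + lambda) / (N + C * lambda)).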
+ val countPerLabel = groupedData.mapValues(_.size)
+ val logDenominator = math.log(data.count() + C * lambda)
+ val weightPerLabel = countPerLabel.mapValues {
+ count => math.log(count + lambda) - logDenominator
+ }
+
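+    // Element-wise sum of all feature vectors observed for each label.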
+ val summedObservations = groupedData.mapValues(_.reduce {
+ (lhs, rhs) => lhs.zip(rhs).map(pair => pair._1 + pair._2)
+ })
+
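+    // Smoothed log-conditional probability of each feature given a label:
+    // log P(j | c) = log((weights(j) + lambda) / (weights.sum + D * lambda)).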
+ val weightsMatrix = summedObservations.mapValues { weights =>
+ val sum = weights.sum
+ val logDenom = math.log(sum + D * lambda)
+ weights.map(w => math.log(w + lambda) - logDenom)
+ }
+
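+    // Collect both tables to the driver, ordered by label, so that row c of
+    // the model corresponds to label c.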
+    val labelWeights = weightPerLabel.collect().sortBy(_._1).map(_._2)
+    val weightsMat = weightsMatrix.collect().sortBy(_._1).map(_._2)
+
+ new NaiveBayesModel(labelWeights, weightsMat)
+ }
+}
+
+object NaiveBayes {
+  /**
+   * Trains a Naive Bayes model given an RDD of (label, features) pairs.
+   *
+   * @param C Number of label classes; the maximal label is C-1.
+   * @param D Dimension of the feature vectors.
+   * @param input RDD of (label, array of features) pairs.
+   * @param lambda The smoothing parameter.
+   */
+ def train(C: Int, D: Int, input: RDD[LabeledPoint],
+ lambda: Double = 1.0): NaiveBayesModel = {
+ new NaiveBayes(lambda).run(C, D, input)
+ }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
new file mode 100644
index 0000000000..d871ed3672
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -0,0 +1,92 @@
+package org.apache.spark.mllib.classification
+
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.SparkContext
+
+object NaiveBayesSuite {
+
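+  // Draw a label from the categorical distribution given by weightPerLabel
+  // (plain probabilities summing to one) using inverse-CDF sampling.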
+ private def calcLabel(p: Double, weightPerLabel: Array[Double]): Int = {
+ var sum = 0.0
+ for (j <- 0 until weightPerLabel.length) {
+ sum += weightPerLabel(j)
+ if (p < sum) return j
+ }
+ -1
+ }
+
+  // Generate input by drawing the label of each point from the prior
+  // distribution and each binary feature from that label's conditional
+  // distribution.
+  def generateNaiveBayesInput(
+    weightPerLabel: Array[Double],        // 1 x C, log of the label priors
+    weightsMatrix: Array[Array[Double]],  // C x D, log of the conditional probabilities
+    nPoints: Int,
+    seed: Int): Seq[LabeledPoint] = {
+ val D = weightsMatrix(0).length
+ val rnd = new Random(seed)
+
+    val _weightPerLabel = weightPerLabel.map(math.exp)
+    val _weightMatrix = weightsMatrix.map(row => row.map(math.exp))
+
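+    // Draw the label from the prior, then set feature j to 1 with probability
+    // equal to the label's conditional probability for feature j.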
+ for (i <- 0 until nPoints) yield {
+ val y = calcLabel(rnd.nextDouble(), _weightPerLabel)
+ val xi = Array.tabulate[Double](D) { j =>
+ if (rnd.nextDouble() < _weightMatrix(y)(j)) 1 else 0
+ }
+
+ LabeledPoint(y, xi)
+ }
+ }
+}
+
+class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ sc.stop()
+ System.clearProperty("spark.driver.port")
+ }
+
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ val numOffPredictions = predictions.zip(input).count {
+ case (prediction, expected) =>
+ prediction != expected.label
+ }
+    // At least 80% of the predictions should be correct.
+ assert(numOffPredictions < input.length / 5)
+ }
+
+ test("Naive Bayes") {
+ val nPoints = 10000
+
+ val weightPerLabel = Array(math.log(0.5), math.log(0.3), math.log(0.2))
+ val weightsMatrix = Array(
+ Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0
+ Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1
+ Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2
+ )
+
+    val testData =
+      NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42)
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ val model = NaiveBayes.train(3, 4, testRDD)
+
+    val validationData =
+      NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+ }
+}