diff options
author | Frank Dai <soulmachine@gmail.com> | 2013-12-25 16:50:42 +0800 |
---|---|---|
committer | Frank Dai <soulmachine@gmail.com> | 2013-12-25 16:50:42 +0800 |
commit | 3dc655aa19f678219e5d999fe97ab769567ffb1c (patch) | |
tree | 4daa85039ddcd82d0e262a4786d4c63c2ca4b747 /mllib | |
parent | 85a344b4f0cd149c6e6f06f8b942c34146b302be (diff) | |
download | spark-3dc655aa19f678219e5d999fe97ab769567ffb1c.tar.gz spark-3dc655aa19f678219e5d999fe97ab769567ffb1c.tar.bz2 spark-3dc655aa19f678219e5d999fe97ab769567ffb1c.zip |
standard Naive Bayes classifier
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 103 | ||||
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala | 92 |
2 files changed, 195 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala new file mode 100644 index 0000000000..f1b0e6ee6a --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -0,0 +1,103 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.classification + +import scala.collection.mutable + +import org.apache.spark.Logging +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.rdd.RDD +import org.apache.spark.SparkContext._ +import org.jblas.DoubleMatrix + +/** + * Model for Naive Bayes Classifiers. + * + * @param weightPerLabel Weights computed for every label, which's dimension is C. + * @param weightMatrix Weights computed for every label and feature, which's dimension is CXD + */ +class NaiveBayesModel(val weightPerLabel: Array[Double], + val weightMatrix: Array[Array[Double]]) + extends ClassificationModel with Serializable { + + // Create a column vector that can be used for predictions + private val _weightPerLabel = new DoubleMatrix(weightPerLabel.length, 1, weightPerLabel:_*) + private val _weightMatrix = new DoubleMatrix(weightMatrix) + + def predict(testData: RDD[Array[Double]]): RDD[Double] = testData.map(predict) + + def predict(testData: Array[Double]): Double = { + val dataMatrix = new DoubleMatrix(testData.length, 1, testData: _*) + val result = _weightPerLabel.add(_weightMatrix.mmul(dataMatrix)) + result.argmax() + } +} + + + +class NaiveBayes private (val lambda: Double = 1.0) // smoothing parameter + extends Serializable with Logging { + + /** + * Run the algorithm with the configured parameters on an input + * RDD of LabeledPoint entries. + * + * @param C kind of labels, labels are continuous integers and the maximal label is C-1 + * @param D dimension of feature vectors + * @param data RDD of (label, array of features) pairs. + */ + def run(C: Int, D: Int, data: RDD[LabeledPoint]): NaiveBayesModel = { + val groupedData = data.map(p => p.label.toInt -> p.features).groupByKey() + + val countPerLabel = groupedData.mapValues(_.size) + val logDenominator = math.log(data.count() + C * lambda) + val weightPerLabel = countPerLabel.mapValues { + count => math.log(count + lambda) - logDenominator + } + + val summedObservations = groupedData.mapValues(_.reduce { + (lhs, rhs) => lhs.zip(rhs).map(pair => pair._1 + pair._2) + }) + + val weightsMatrix = summedObservations.mapValues { weights => + val sum = weights.sum + val logDenom = math.log(sum + D * lambda) + weights.map(w => math.log(w + lambda) - logDenom) + } + + val labelWeights = weightPerLabel.collect().sorted.map(_._2) + val weightsMat = weightsMatrix.collect().sortBy(_._1).map(_._2) + + new NaiveBayesModel(labelWeights, weightsMat) + } +} + +object NaiveBayes { + /** + * Train a naive bayes model given an RDD of (label, features) pairs. + * + * @param C kind of labels, the maximal label is C-1 + * @param D dimension of feature vectors + * @param input RDD of (label, array of features) pairs. + * @param lambda smooth parameter + */ + def train(C: Int, D: Int, input: RDD[LabeledPoint], + lambda: Double = 1.0): NaiveBayesModel = { + new NaiveBayes(lambda).run(C, D, input) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala new file mode 100644 index 0000000000..d871ed3672 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala @@ -0,0 +1,92 @@ +package org.apache.spark.mllib.classification + +import scala.collection.JavaConversions._ +import scala.util.Random + +import org.scalatest.BeforeAndAfterAll +import org.scalatest.FunSuite + +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.SparkContext + +object NaiveBayesSuite { + + private def calcLabel(p: Double, weightPerLabel: Array[Double]): Int = { + var sum = 0.0 + for (j <- 0 until weightPerLabel.length) { + sum += weightPerLabel(j) + if (p < sum) return j + } + -1 + } + + // Generate input of the form Y = (weightMatrix*x).argmax() + def generateNaiveBayesInput( + weightPerLabel: Array[Double], // 1XC + weightsMatrix: Array[Array[Double]], // CXD + nPoints: Int, + seed: Int): Seq[LabeledPoint] = { + val D = weightsMatrix(0).length + val rnd = new Random(seed) + + val _weightPerLabel = weightPerLabel.map(math.pow(math.E, _)) + val _weightMatrix = weightsMatrix.map(row => row.map(math.pow(math.E, _))) + + for (i <- 0 until nPoints) yield { + val y = calcLabel(rnd.nextDouble(), _weightPerLabel) + val xi = Array.tabulate[Double](D) { j => + if (rnd.nextDouble() < _weightMatrix(y)(j)) 1 else 0 + } + + LabeledPoint(y, xi) + } + } +} + +class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll { + @transient private var sc: SparkContext = _ + + override def beforeAll() { + sc = new SparkContext("local", "test") + } + + override def afterAll() { + sc.stop() + System.clearProperty("spark.driver.port") + } + + def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) { + val numOffPredictions = predictions.zip(input).count { + case (prediction, expected) => + prediction != expected.label + } + // At least 80% of the predictions should be on. + assert(numOffPredictions < input.length / 5) + } + + test("Naive Bayes") { + val nPoints = 10000 + + val weightPerLabel = Array(math.log(0.5), math.log(0.3), math.log(0.2)) + val weightsMatrix = Array( + Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0 + Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1 + Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2 + ) + + val testData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42) + val testRDD = sc.parallelize(testData, 2) + testRDD.cache() + + val model = NaiveBayes.train(3, 4, testRDD) + + val validationData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 17) + val validationRDD = sc.parallelize(validationData, 2) + + // Test prediction on RDD. + validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData) + + // Test prediction on Array. + validatePrediction(validationData.map(row => model.predict(row.features)), validationData) + } +} |