diff options
author | Frank Dai <soulmachine@gmail.com> | 2013-12-25 16:50:42 +0800 |
---|---|---|
committer | Frank Dai <soulmachine@gmail.com> | 2013-12-25 16:50:42 +0800 |
commit | 3dc655aa19f678219e5d999fe97ab769567ffb1c (patch) | |
tree | 4daa85039ddcd82d0e262a4786d4c63c2ca4b747 /mllib/src/test | |
parent | 85a344b4f0cd149c6e6f06f8b942c34146b302be (diff) | |
download | spark-3dc655aa19f678219e5d999fe97ab769567ffb1c.tar.gz spark-3dc655aa19f678219e5d999fe97ab769567ffb1c.tar.bz2 spark-3dc655aa19f678219e5d999fe97ab769567ffb1c.zip |
standard Naive Bayes classifier
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala | 92 |
1 file changed, 92 insertions, 0 deletions
package org.apache.spark.mllib.classification

import scala.util.Random

import org.scalatest.BeforeAndAfterAll
import org.scalatest.FunSuite

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.SparkContext

object NaiveBayesSuite {

  /**
   * Samples a class index from the categorical distribution whose
   * (unnormalized) probabilities are `weightPerLabel`, using `p` as a
   * uniform draw in [0, 1).
   *
   * @return the first index whose cumulative weight exceeds `p`, or -1 if
   *         none does (cannot happen when the weights sum to at least 1)
   */
  private def calcLabel(p: Double, weightPerLabel: Array[Double]): Int = {
    // Running totals of the weights; `indexWhere` yields -1 when no
    // cumulative weight exceeds p, preserving the original sentinel and
    // avoiding a nonlocal `return` from inside a loop.
    val cumulative = weightPerLabel.scanLeft(0.0)(_ + _).tail
    cumulative.indexWhere(p < _)
  }

  /**
   * Generates `nPoints` labeled points of the form Y = (weightMatrix * x).argmax().
   *
   * @param weightPerLabel log prior probabilities, one entry per class (1 x C)
   * @param weightsMatrix  log class-conditional feature probabilities (C x D)
   * @param nPoints        number of points to generate
   * @param seed           RNG seed, so test data is deterministic
   * @return generated points with 0/1 feature vectors
   */
  def generateNaiveBayesInput(
      weightPerLabel: Array[Double], // 1XC
      weightsMatrix: Array[Array[Double]], // CXD
      nPoints: Int,
      seed: Int): Seq[LabeledPoint] = {
    val D = weightsMatrix(0).length
    val rnd = new Random(seed)

    // Undo the log transform: math.exp is clearer and more direct than
    // math.pow(math.E, _).
    val _weightPerLabel = weightPerLabel.map(math.exp)
    val _weightMatrix = weightsMatrix.map(row => row.map(math.exp))

    for (i <- 0 until nPoints) yield {
      val y = calcLabel(rnd.nextDouble(), _weightPerLabel)
      // Each feature is a Bernoulli draw with the class-conditional probability.
      val xi = Array.tabulate[Double](D) { j =>
        if (rnd.nextDouble() < _weightMatrix(y)(j)) 1 else 0
      }

      LabeledPoint(y, xi)
    }
  }
}

class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
  @transient private var sc: SparkContext = _

  override def beforeAll() {
    sc = new SparkContext("local", "test")
  }

  override def afterAll() {
    sc.stop()
    // Clear the driver port so later suites can start a fresh SparkContext.
    System.clearProperty("spark.driver.port")
  }

  /** Asserts that at least 80% of `predictions` match the expected labels. */
  def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
    val numOffPredictions = predictions.zip(input).count {
      case (prediction, expected) =>
        prediction != expected.label
    }
    // At least 80% of the predictions should be on.
    assert(numOffPredictions < input.length / 5)
  }

  test("Naive Bayes") {
    val nPoints = 10000

    // Log priors and log class-conditional probabilities for a
    // 3-class, 4-feature model; each class strongly favors one feature.
    val weightPerLabel = Array(math.log(0.5), math.log(0.3), math.log(0.2))
    val weightsMatrix = Array(
      Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0
      Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1
      Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2
    )

    val testData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42)
    val testRDD = sc.parallelize(testData, 2)
    testRDD.cache()

    val model = NaiveBayes.train(3, 4, testRDD)

    // Validate on data generated with a different seed than training.
    val validationData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 17)
    val validationRDD = sc.parallelize(validationData, 2)

    // Test prediction on RDD.
    validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)

    // Test prediction on Array.
    validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
  }
}