aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorFrank Dai <soulmachine@gmail.com>2013-12-25 16:50:42 +0800
committerFrank Dai <soulmachine@gmail.com>2013-12-25 16:50:42 +0800
commit3dc655aa19f678219e5d999fe97ab769567ffb1c (patch)
tree4daa85039ddcd82d0e262a4786d4c63c2ca4b747 /mllib/src/test
parent85a344b4f0cd149c6e6f06f8b942c34146b302be (diff)
downloadspark-3dc655aa19f678219e5d999fe97ab769567ffb1c.tar.gz
spark-3dc655aa19f678219e5d999fe97ab769567ffb1c.tar.bz2
spark-3dc655aa19f678219e5d999fe97ab769567ffb1c.zip
standard Naive Bayes classifier
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala92
1 files changed, 92 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
new file mode 100644
index 0000000000..d871ed3672
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -0,0 +1,92 @@
+package org.apache.spark.mllib.classification
+
+import scala.collection.JavaConversions._
+import scala.util.Random
+
+import org.scalatest.BeforeAndAfterAll
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.SparkContext
+
+object NaiveBayesSuite {
+
+ // Maps a uniform draw p in [0, 1) to a label index by walking the running
+ // sum of weightPerLabel and returning the first index whose cumulative mass
+ // exceeds p. Returns -1 only when p is not covered by the total mass
+ // (i.e. the weights do not sum to at least p).
+ private def calcLabel(p: Double, weightPerLabel: Array[Double]): Int = {
+ var sum = 0.0
+ for (j <- 0 until weightPerLabel.length) {
+ sum += weightPerLabel(j)
+ if (p < sum) return j
+ }
+ -1
+ }
+
+ // Generate input of the form Y = (weightMatrix*x).argmax()
+ // NOTE(review): both weight arguments appear to be log-probabilities (the
+ // caller passes math.log(...) values and they are exponentiated below before
+ // use) — confirm this contract with callers.
+ def generateNaiveBayesInput(
+ weightPerLabel: Array[Double], // 1XC
+ weightsMatrix: Array[Array[Double]], // CXD
+ nPoints: Int,
+ seed: Int): Seq[LabeledPoint] = {
+ // D = feature dimension, taken from the first row of the matrix.
+ val D = weightsMatrix(0).length
+ val rnd = new Random(seed)
+
+ // Undo the log: math.pow(math.E, x) == exp(x), recovering probabilities.
+ val _weightPerLabel = weightPerLabel.map(math.pow(math.E, _))
+ val _weightMatrix = weightsMatrix.map(row => row.map(math.pow(math.E, _)))
+
+ for (i <- 0 until nPoints) yield {
+ // Draw the label from the (exponentiated) class prior...
+ val y = calcLabel(rnd.nextDouble(), _weightPerLabel)
+ // ...then draw each binary feature j with probability _weightMatrix(y)(j).
+ val xi = Array.tabulate[Double](D) { j =>
+ if (rnd.nextDouble() < _weightMatrix(y)(j)) 1 else 0
+ }
+
+ LabeledPoint(y, xi)
+ }
+ }
+}
+
+class NaiveBayesSuite extends FunSuite with BeforeAndAfterAll {
+ // Local SparkContext shared by all tests in this suite; created in
+ // beforeAll and torn down in afterAll.
+ @transient private var sc: SparkContext = _
+
+ override def beforeAll() {
+ sc = new SparkContext("local", "test")
+ }
+
+ override def afterAll() {
+ sc.stop()
+ // Clear the driver port so a subsequent suite can bind a fresh one.
+ System.clearProperty("spark.driver.port")
+ }
+
+ // Counts mismatches between predicted and expected labels and asserts the
+ // error rate stays strictly below 20% of the input size.
+ def validatePrediction(predictions: Seq[Double], input: Seq[LabeledPoint]) {
+ val numOffPredictions = predictions.zip(input).count {
+ case (prediction, expected) =>
+ prediction != expected.label
+ }
+ // At least 80% of the predictions should be on.
+ assert(numOffPredictions < input.length / 5)
+ }
+
+ test("Naive Bayes") {
+ val nPoints = 10000
+
+ // Class priors and per-class feature probabilities, stored as logs; the
+ // generator exponentiates them back to probabilities.
+ val weightPerLabel = Array(math.log(0.5), math.log(0.3), math.log(0.2))
+ val weightsMatrix = Array(
+ Array(math.log(0.91), math.log(0.03), math.log(0.03), math.log(0.03)), // label 0
+ Array(math.log(0.03), math.log(0.91), math.log(0.03), math.log(0.03)), // label 1
+ Array(math.log(0.03), math.log(0.03), math.log(0.91), math.log(0.03)) // label 2
+ )
+
+ val testData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 42)
+ val testRDD = sc.parallelize(testData, 2)
+ testRDD.cache()
+
+ // NOTE(review): 3 and 4 presumably are the label count and feature
+ // dimension matching the arrays above — confirm against NaiveBayes.train's
+ // signature; named arguments would make this self-documenting.
+ val model = NaiveBayes.train(3, 4, testRDD)
+
+ // Validation set drawn from the same distribution with a different seed.
+ val validationData = NaiveBayesSuite.generateNaiveBayesInput(weightPerLabel, weightsMatrix, nPoints, 17)
+ val validationRDD = sc.parallelize(validationData, 2)
+
+ // Test prediction on RDD.
+ validatePrediction(model.predict(validationRDD.map(_.features)).collect(), validationData)
+
+ // Test prediction on Array.
+ validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+ }
+}