aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-07-31 13:11:42 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-31 13:11:42 -0700
commitfbef566a107b47e5fddde0ea65b8587d5039062d (patch)
tree8df3e894582f62558df7fb8efb29a8ae4f63a9f2 /mllib/src/test
parent060c79aab58efd4ce7353a1b00534de0d9e1de0b (diff)
downloadspark-fbef566a107b47e5fddde0ea65b8587d5039062d.tar.gz
spark-fbef566a107b47e5fddde0ea65b8587d5039062d.tar.bz2
spark-fbef566a107b47e5fddde0ea65b8587d5039062d.zip
[SPARK-9308] [ML] ml.NaiveBayesModel support predicting class probabilities
Make NaiveBayesModel support predicting class probabilities, inherit from ProbabilisticClassificationModel. Author: Yanbo Liang <ybliang8@gmail.com> Closes #7672 from yanboliang/spark-9308 and squashes the following commits: 25e224c [Yanbo Liang] raw2probabilityInPlace should operate in-place 3ee56d6 [Yanbo Liang] change predictRaw and raw2probabilityInPlace c07e7a2 [Yanbo Liang] ml.NaiveBayesModel support predicting class probabilities
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala54
1 files changed, 52 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 264bde3703..aea3d9b694 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -17,8 +17,11 @@
package org.apache.spark.ml.classification
+import breeze.linalg.{Vector => BV}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -28,6 +31,8 @@ import org.apache.spark.sql.Row
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
+ import NaiveBayes.{Multinomial, Bernoulli}
+
def validatePrediction(predictionAndLabels: DataFrame): Unit = {
val numOfErrorPredictions = predictionAndLabels.collect().count {
case Row(prediction: Double, label: Double) =>
@@ -46,6 +51,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
}
+ def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+ val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
+ val classProbs = logClassProbs.toArray.map(math.exp)
+ val classProbsSum = classProbs.sum
+ Vectors.dense(classProbs.map(_ / classProbsSum))
+ }
+
+ def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+ val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v)))
+ val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v))
+ val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
+ val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze
+ val classProbs = logClassProbs.toArray.map(math.exp)
+ val classProbsSum = classProbs.sum
+ Vectors.dense(classProbs.map(_ / classProbsSum))
+ }
+
+ def validateProbabilities(
+ featureAndProbabilities: DataFrame,
+ model: NaiveBayesModel,
+ modelType: String): Unit = {
+ featureAndProbabilities.collect().foreach {
+ case Row(features: Vector, probability: Vector) => {
+ assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
+ val expected = modelType match {
+ case Multinomial =>
+ expectedMultinomialProbabilities(model, features)
+ case Bernoulli =>
+ expectedBernoulliProbabilities(model, features)
+ case _ =>
+ throw new UnknownError(s"Invalid modelType: $modelType.")
+ }
+ assert(probability ~== expected relTol 1.0e-10)
+ }
+ }
+ }
+
test("params") {
ParamsSuite.checkParams(new NaiveBayes)
val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
@@ -83,9 +125,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 17, "multinomial"))
- val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
validatePrediction(predictionAndLabels)
+
+ val featureAndProbabilities = model.transform(validationDataset)
+ .select("features", "probability")
+ validateProbabilities(featureAndProbabilities, model, "multinomial")
}
test("Naive Bayes Bernoulli") {
@@ -109,8 +155,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 20, "bernoulli"))
- val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
validatePrediction(predictionAndLabels)
+
+ val featureAndProbabilities = model.transform(validationDataset)
+ .select("features", "probability")
+ validateProbabilities(featureAndProbabilities, model, "bernoulli")
}
}