about | summary | refs | log | tree | commit | diff
path: root/mllib/src/test
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala | 54
1 files changed, 52 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 264bde3703..aea3d9b694 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -17,8 +17,11 @@
package org.apache.spark.ml.classification
+import breeze.linalg.{Vector => BV}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -28,6 +31,8 @@ import org.apache.spark.sql.Row
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
+ import NaiveBayes.{Multinomial, Bernoulli}
+
def validatePrediction(predictionAndLabels: DataFrame): Unit = {
val numOfErrorPredictions = predictionAndLabels.collect().count {
case Row(prediction: Double, label: Double) =>
@@ -46,6 +51,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
}
+ /**
+  * Computes the class probabilities the test expects for one feature vector under the
+  * multinomial variant: `model.pi` plus the product of `model.theta` and `feature` is
+  * exponentiated element-wise and then scaled so that the entries sum to one.
+  *
+  * @param model fitted model supplying `pi` and `theta`
+  * @param feature feature vector to score
+  * @return dense vector of normalized per-class values
+  */
+ def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+   val weightedFeature = model.theta.multiply(feature)
+   // The explicit BV[Double] annotation keeps the Breeze `+` operator resolvable.
+   val logJoint: BV[Double] = model.pi.toBreeze + weightedFeature.toBreeze
+   val unnormalized = logJoint.toArray.map(score => math.exp(score))
+   val normalizer = unnormalized.sum
+   Vectors.dense(unnormalized.map(p => p / normalizer))
+ }
+
+ /**
+  * Computes the class probabilities the test expects for one feature vector under the
+  * Bernoulli variant. Two terms are summed before exponentiating and normalizing:
+  * `pi + theta * feature`, and `log(1 - exp(theta)) * (1 - feature)`.
+  *
+  * @param model fitted model supplying `pi` and `theta`
+  * @param feature feature vector to score (presumably 0/1 valued; not checked here)
+  * @return dense vector of normalized per-class values
+  */
+ def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+   // Term driven by the feature values themselves.
+   val presentTerm: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
+   // Term driven by the complement (1 - x) of every feature value.
+   val complementFeature = Vectors.dense(feature.toArray.map(x => 1.0 - x))
+   val logComplementTheta = model.theta.map(t => math.log(1.0 - math.exp(t)))
+   val absentTerm: BV[Double] = logComplementTheta.multiply(complementFeature).toBreeze
+   val logJoint: BV[Double] = presentTerm + absentTerm
+   val unnormalized = logJoint.toArray.map(score => math.exp(score))
+   val normalizer = unnormalized.sum
+   Vectors.dense(unnormalized.map(p => p / normalizer))
+ }
+
+ /**
+  * Checks every (features, probability) row: the probability vector must sum to one and
+  * must match the value recomputed from the model by the helper for `modelType`.
+  *
+  * An unrecognized `modelType` is reported through ScalaTest's `fail`, so it shows up as
+  * an ordinary test failure. The model type is resolved once, before any row is read,
+  * so a bad value is reported even when the DataFrame has no rows.
+  *
+  * @param featureAndProbabilities rows holding a feature vector and a probability vector
+  * @param model fitted model used to recompute the expected probabilities
+  * @param modelType one of the `Multinomial` / `Bernoulli` names imported from `NaiveBayes`
+  */
+ def validateProbabilities(
+     featureAndProbabilities: DataFrame,
+     model: NaiveBayesModel,
+     modelType: String): Unit = {
+   // Pick the reference computation up front; it does not depend on the row.
+   val expectedFor: Vector => Vector = modelType match {
+     case Multinomial =>
+       (features: Vector) => expectedMultinomialProbabilities(model, features)
+     case Bernoulli =>
+       (features: Vector) => expectedBernoulliProbabilities(model, features)
+     case _ =>
+       // java.lang.UnknownError is a VirtualMachineError meant for JVM faults; a bad
+       // argument in a test helper should surface as a normal test failure instead.
+       fail(s"Invalid modelType: $modelType.")
+   }
+   featureAndProbabilities.collect().foreach {
+     case Row(features: Vector, probability: Vector) =>
+       assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
+       assert(probability ~== expectedFor(features) relTol 1.0e-10)
+   }
+ }
+
test("params") {
ParamsSuite.checkParams(new NaiveBayes)
val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
@@ -83,9 +125,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 17, "multinomial"))
- val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
validatePrediction(predictionAndLabels)
+
+ val featureAndProbabilities = model.transform(validationDataset)
+ .select("features", "probability")
+ validateProbabilities(featureAndProbabilities, model, "multinomial")
}
test("Naive Bayes Bernoulli") {
@@ -109,8 +155,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 20, "bernoulli"))
- val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
validatePrediction(predictionAndLabels)
+
+ val featureAndProbabilities = model.transform(validationDataset)
+ .select("features", "probability")
+ validateProbabilities(featureAndProbabilities, model, "bernoulli")
}
}