aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala76
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala55
2 files changed, 113 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index f51ee36d0d..9e379d7d74 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -93,26 +93,70 @@ class NaiveBayesModel private[mllib] (
override def predict(testData: Vector): Double = {
modelType match {
case Multinomial =>
- val prob = thetaMatrix.multiply(testData)
- BLAS.axpy(1.0, piVector, prob)
- labels(prob.argmax)
+ labels(multinomialCalculation(testData).argmax) // label at the index of the largest score
case Bernoulli =>
- testData.foreachActive { (index, value) =>
- if (value != 0.0 && value != 1.0) {
- throw new SparkException(
- s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
- }
- }
- val prob = thetaMinusNegTheta.get.multiply(testData)
- BLAS.axpy(1.0, piVector, prob)
- BLAS.axpy(1.0, negThetaSum.get, prob)
- labels(prob.argmax)
- case _ =>
- // This should never happen.
- throw new UnknownError(s"Invalid modelType: $modelType.")
+ labels(bernoulliCalculation(testData).argmax) // helper rejects non-0/1 features first
+ }
+ }
+
+ /**
+ * Predict posterior class probabilities for every point in the given data set using the
+ * trained model.
+ *
+ * @param testData RDD representing data points to be predicted
+ * @return an RDD[Vector] where each entry contains the predicted posterior class probabilities,
+ * in the same order as class labels
+ */
+ def predictProbabilities(testData: RDD[Vector]): RDD[Vector] = {
+ val bcModel = testData.context.broadcast(this) // broadcast this model via the RDD's context
+ testData.mapPartitions { iter =>
+ val model = bcModel.value // read the broadcast model once per partition
+ iter.map(model.predictProbabilities)
}
}
+ /**
+  * Compute the posterior class probabilities of one data point under the trained model.
+  *
+  * @param testData vector holding the features of a single data point
+  * @return posterior class probabilities from the trained model,
+  *         in the same order as class labels
+  */
+ def predictProbabilities(testData: Vector): Vector = {
+   // Score the point with the calculation matching the model type, then normalize once.
+   val classScores = modelType match {
+     case Multinomial => multinomialCalculation(testData)
+     case Bernoulli => bernoulliCalculation(testData)
+   }
+   posteriorProbabilities(classScores)
+ }
+
+ /**
+  * Per-class scores for the multinomial model: `thetaMatrix * testData + piVector`.
+  */
+ private def multinomialCalculation(testData: Vector) = {
+   val classScores = thetaMatrix.multiply(testData)
+   // axpy updates its last argument in place: classScores += 1.0 * piVector
+   BLAS.axpy(1.0, piVector, classScores)
+   classScores
+ }
+
+ /**
+  * Per-class scores for the Bernoulli model:
+  * `thetaMinusNegTheta * testData + piVector + negThetaSum`.
+  *
+  * Throws a SparkException if any active feature value is neither 0 nor 1.
+  */
+ private def bernoulliCalculation(testData: Vector) = {
+   // Validate every active entry before doing any arithmetic.
+   testData.foreachActive { (_, featureValue) =>
+     val isBinary = featureValue == 0.0 || featureValue == 1.0
+     if (!isBinary) {
+       throw new SparkException(
+         s"Bernoulli naive Bayes requires 0 or 1 feature values but found $testData.")
+     }
+   }
+   val classScores = thetaMinusNegTheta.get.multiply(testData)
+   // Both axpy calls accumulate into classScores in place, in this order.
+   BLAS.axpy(1.0, piVector, classScores)
+   BLAS.axpy(1.0, negThetaSum.get, classScores)
+   classScores
+ }
+
+ /**
+  * Turn per-class log-scores into probabilities that sum to one.
+  *
+  * The largest log-score is subtracted before exponentiating, so the biggest term is
+  * exp(0) = 1; the shift cancels out when dividing by the sum.
+  */
+ private def posteriorProbabilities(logProb: DenseVector) = {
+   val logScores = logProb.toArray
+   val shift = logScores.max
+   val unnormalized = logScores.map(score => math.exp(score - shift))
+   val normalizer = unnormalized.sum
+   new DenseVector(unnormalized.map(_ / normalizer))
+ }
+
override def save(sc: SparkContext, path: String): Unit = {
val data = NaiveBayesModel.SaveLoadV2_0.Data(labels, pi, theta, modelType)
NaiveBayesModel.SaveLoadV2_0.save(sc, path, data)
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index f7fc873060..cffa1ab700 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -19,13 +19,14 @@ package org.apache.spark.mllib.classification
import scala.util.Random
-import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, argmax => brzArgmax, sum => brzSum}
+import breeze.linalg.{DenseMatrix => BDM, DenseVector => BDV, Vector => BV}
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
import org.apache.spark.{SparkException, SparkFunSuite}
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.{LocalClusterSparkContext, MLlibTestSparkContext}
+import org.apache.spark.mllib.util.TestingUtils._
import org.apache.spark.util.Utils
object NaiveBayesSuite {
@@ -154,6 +155,29 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+
+ // Test posteriors
+ validationData.map(_.features).foreach { features =>
+ val predicted = model.predictProbabilities(features).toArray
+ assert(predicted.sum ~== 1.0 relTol 1.0e-10)
+ val expected = expectedMultinomialProbabilities(model, features)
+ expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
+ }
+ }
+
+ /**
+  * Reference computation of multinomial posteriors, done with Breeze from the model's
+  * `pi` and `theta` so it can be compared against `predictProbabilities`.
+  *
+  * @param model Multinomial Naive Bayes model
+  * @param testData input to compute posterior probabilities for
+  * @return posterior class probabilities (in order of labels) for input
+  */
+ private def expectedMultinomialProbabilities(model: NaiveBayesModel, testData: Vector) = {
+   val numClasses = model.theta.length
+   val numFeatures = model.theta(0).length
+   val piBreeze = new BDV(model.pi)
+   // model.theta is row-major; reading the flattened array as a column-major
+   // (numFeatures x numClasses) matrix yields the transpose, so transpose it back.
+   val thetaBreeze = new BDM(numFeatures, numClasses, model.theta.flatten).t
+   val logJoint: BV[Double] = piBreeze + (thetaBreeze * testData.toBreeze)
+   val unnormalized = logJoint.toArray.map(math.exp)
+   val total = unnormalized.sum
+   unnormalized.map(_ / total)
}
test("Naive Bayes Bernoulli") {
@@ -182,6 +206,33 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
// Test prediction on Array.
validatePrediction(validationData.map(row => model.predict(row.features)), validationData)
+
+ // Test posteriors
+ validationData.map(_.features).foreach { features =>
+ val predicted = model.predictProbabilities(features).toArray
+ assert(predicted.sum ~== 1.0 relTol 1.0e-10)
+ val expected = expectedBernoulliProbabilities(model, features)
+ expected.zip(predicted).foreach { case (e, p) => assert(e ~== p relTol 1.0e-10) }
+ }
+ }
+
+ /**
+  * Reference computation of Bernoulli posteriors, done with Breeze from the model's
+  * `pi` and `theta` so it can be compared against `predictProbabilities`.
+  *
+  * @param model Bernoulli Naive Bayes model
+  * @param testData input to compute posterior probabilities for
+  * @return posterior class probabilities (in order of labels) for input
+  */
+ private def expectedBernoulliProbabilities(model: NaiveBayesModel, testData: Vector) = {
+   val numClasses = model.theta.length
+   val numFeatures = model.theta(0).length
+   val flatTheta = model.theta.flatten
+   val piBreeze = new BDV(model.pi)
+   // The flattened row-major theta, read as a column-major (numFeatures x numClasses)
+   // matrix, is the transpose of theta; `.t` restores the original orientation.
+   val thetaBreeze = new BDM(numFeatures, numClasses, flatTheta).t
+   // Entry-wise log(1 - exp(theta)), laid out the same way.
+   val negThetaBreeze =
+     new BDM(numFeatures, numClasses, flatTheta.map(t => math.log(1.0 - math.exp(t)))).t
+   val present = testData.toBreeze
+   // Complement of the feature vector: 1 - x for every position.
+   val absent = new BDV(Array.fill(present.size)(1.0)) - present
+   val presentTerm: BV[Double] = piBreeze + (thetaBreeze * present)
+   val logJoint: BV[Double] = presentTerm + (negThetaBreeze * absent)
+   val unnormalized = logJoint.toArray.map(math.exp)
+   val total = unnormalized.sum
+   unnormalized.map(_ / total)
}
test("detect negative values") {