aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala65
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala54
2 files changed, 101 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 5be35fe209..b46b676204 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -69,7 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
* The input feature values must be nonnegative.
*/
class NaiveBayes(override val uid: String)
- extends Predictor[Vector, NaiveBayes, NaiveBayesModel]
+ extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams {
def this() = this(Identifiable.randomUID("nb"))
@@ -106,7 +106,7 @@ class NaiveBayesModel private[ml] (
override val uid: String,
val pi: Vector,
val theta: Matrix)
- extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+ extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
import OldNaiveBayes.{Bernoulli, Multinomial}
@@ -129,29 +129,62 @@ class NaiveBayesModel private[ml] (
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
- override protected def predict(features: Vector): Double = {
+ override val numClasses: Int = pi.size
+
+ private def multinomialCalculation(features: Vector) = {
+ val prob = theta.multiply(features)
+ BLAS.axpy(1.0, pi, prob)
+ prob
+ }
+
+ private def bernoulliCalculation(features: Vector) = {
+ features.foreachActive((_, value) =>
+ if (value != 0.0 && value != 1.0) {
+ throw new SparkException(
+ s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.")
+ }
+ )
+ val prob = thetaMinusNegTheta.get.multiply(features)
+ BLAS.axpy(1.0, pi, prob)
+ BLAS.axpy(1.0, negThetaSum.get, prob)
+ prob
+ }
+
+ override protected def predictRaw(features: Vector): Vector = {
$(modelType) match {
case Multinomial =>
- val prob = theta.multiply(features)
- BLAS.axpy(1.0, pi, prob)
- prob.argmax
+ multinomialCalculation(features)
case Bernoulli =>
- features.foreachActive{ (index, value) =>
- if (value != 0.0 && value != 1.0) {
- throw new SparkException(
- s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features")
- }
- }
- val prob = thetaMinusNegTheta.get.multiply(features)
- BLAS.axpy(1.0, pi, prob)
- BLAS.axpy(1.0, negThetaSum.get, prob)
- prob.argmax
+ bernoulliCalculation(features)
case _ =>
// This should never happen.
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
}
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ var i = 0
+ val size = dv.size
+ val maxLog = dv.values.max
+ while (i < size) {
+ dv.values(i) = math.exp(dv.values(i) - maxLog)
+ i += 1
+ }
+ val probSum = dv.values.sum
+ i = 0
+ while (i < size) {
+ dv.values(i) = dv.values(i) / probSum
+ i += 1
+ }
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in NaiveBayesModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
+ }
+
override def copy(extra: ParamMap): NaiveBayesModel = {
copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 264bde3703..aea3d9b694 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -17,8 +17,11 @@
package org.apache.spark.ml.classification
+import breeze.linalg.{Vector => BV}
+
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
@@ -28,6 +31,8 @@ import org.apache.spark.sql.Row
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
+ import NaiveBayes.{Multinomial, Bernoulli}
+
def validatePrediction(predictionAndLabels: DataFrame): Unit = {
val numOfErrorPredictions = predictionAndLabels.collect().count {
case Row(prediction: Double, label: Double) =>
@@ -46,6 +51,43 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(model.theta.map(math.exp) ~== thetaData.map(math.exp) absTol 0.05, "theta mismatch")
}
+ def expectedMultinomialProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+ val logClassProbs: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
+ val classProbs = logClassProbs.toArray.map(math.exp)
+ val classProbsSum = classProbs.sum
+ Vectors.dense(classProbs.map(_ / classProbsSum))
+ }
+
+ def expectedBernoulliProbabilities(model: NaiveBayesModel, feature: Vector): Vector = {
+ val negThetaMatrix = model.theta.map(v => math.log(1.0 - math.exp(v)))
+ val negFeature = Vectors.dense(feature.toArray.map(v => 1.0 - v))
+ val piTheta: BV[Double] = model.pi.toBreeze + model.theta.multiply(feature).toBreeze
+ val logClassProbs: BV[Double] = piTheta + negThetaMatrix.multiply(negFeature).toBreeze
+ val classProbs = logClassProbs.toArray.map(math.exp)
+ val classProbsSum = classProbs.sum
+ Vectors.dense(classProbs.map(_ / classProbsSum))
+ }
+
+ def validateProbabilities(
+ featureAndProbabilities: DataFrame,
+ model: NaiveBayesModel,
+ modelType: String): Unit = {
+ featureAndProbabilities.collect().foreach {
+ case Row(features: Vector, probability: Vector) => {
+ assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
+ val expected = modelType match {
+ case Multinomial =>
+ expectedMultinomialProbabilities(model, features)
+ case Bernoulli =>
+ expectedBernoulliProbabilities(model, features)
+ case _ =>
+ throw new UnknownError(s"Invalid modelType: $modelType.")
+ }
+ assert(probability ~== expected relTol 1.0e-10)
+ }
+ }
+ }
+
test("params") {
ParamsSuite.checkParams(new NaiveBayes)
val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
@@ -83,9 +125,13 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 17, "multinomial"))
- val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
validatePrediction(predictionAndLabels)
+
+ val featureAndProbabilities = model.transform(validationDataset)
+ .select("features", "probability")
+ validateProbabilities(featureAndProbabilities, model, "multinomial")
}
test("Naive Bayes Bernoulli") {
@@ -109,8 +155,12 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext {
val validationDataset = sqlContext.createDataFrame(generateNaiveBayesInput(
piArray, thetaArray, nPoints, 20, "bernoulli"))
- val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
+ val predictionAndLabels = model.transform(validationDataset).select("prediction", "label")
validatePrediction(predictionAndLabels)
+
+ val featureAndProbabilities = model.transform(validationDataset)
+ .select("features", "probability")
+ validateProbabilities(featureAndProbabilities, model, "bernoulli")
}
}