aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-07-31 13:11:42 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-07-31 13:11:42 -0700
commitfbef566a107b47e5fddde0ea65b8587d5039062d (patch)
tree8df3e894582f62558df7fb8efb29a8ae4f63a9f2 /mllib/src/main
parent060c79aab58efd4ce7353a1b00534de0d9e1de0b (diff)
downloadspark-fbef566a107b47e5fddde0ea65b8587d5039062d.tar.gz
spark-fbef566a107b47e5fddde0ea65b8587d5039062d.tar.bz2
spark-fbef566a107b47e5fddde0ea65b8587d5039062d.zip
[SPARK-9308] [ML] ml.NaiveBayesModel support predicting class probabilities
Make NaiveBayesModel support predicting class probabilities, inherit from ProbabilisticClassificationModel. Author: Yanbo Liang <ybliang8@gmail.com> Closes #7672 from yanboliang/spark-9308 and squashes the following commits: 25e224c [Yanbo Liang] raw2probabilityInPlace should operate in-place 3ee56d6 [Yanbo Liang] change predictRaw and raw2probabilityInPlace c07e7a2 [Yanbo Liang] ml.NaiveBayesModel support predicting class probabilities
Diffstat (limited to 'mllib/src/main')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala65
1 files changed, 49 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
index 5be35fe209..b46b676204 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/NaiveBayes.scala
@@ -69,7 +69,7 @@ private[ml] trait NaiveBayesParams extends PredictorParams {
* The input feature values must be nonnegative.
*/
class NaiveBayes(override val uid: String)
- extends Predictor[Vector, NaiveBayes, NaiveBayesModel]
+ extends ProbabilisticClassifier[Vector, NaiveBayes, NaiveBayesModel]
with NaiveBayesParams {
def this() = this(Identifiable.randomUID("nb"))
@@ -106,7 +106,7 @@ class NaiveBayesModel private[ml] (
override val uid: String,
val pi: Vector,
val theta: Matrix)
- extends PredictionModel[Vector, NaiveBayesModel] with NaiveBayesParams {
+ extends ProbabilisticClassificationModel[Vector, NaiveBayesModel] with NaiveBayesParams {
import OldNaiveBayes.{Bernoulli, Multinomial}
@@ -129,29 +129,62 @@ class NaiveBayesModel private[ml] (
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
- override protected def predict(features: Vector): Double = {
+ override val numClasses: Int = pi.size
+
+ private def multinomialCalculation(features: Vector) = {
+ val prob = theta.multiply(features)
+ BLAS.axpy(1.0, pi, prob)
+ prob
+ }
+
+ private def bernoulliCalculation(features: Vector) = {
+ features.foreachActive((_, value) =>
+ if (value != 0.0 && value != 1.0) {
+ throw new SparkException(
+ s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features.")
+ }
+ )
+ val prob = thetaMinusNegTheta.get.multiply(features)
+ BLAS.axpy(1.0, pi, prob)
+ BLAS.axpy(1.0, negThetaSum.get, prob)
+ prob
+ }
+
+ override protected def predictRaw(features: Vector): Vector = {
$(modelType) match {
case Multinomial =>
- val prob = theta.multiply(features)
- BLAS.axpy(1.0, pi, prob)
- prob.argmax
+ multinomialCalculation(features)
case Bernoulli =>
- features.foreachActive{ (index, value) =>
- if (value != 0.0 && value != 1.0) {
- throw new SparkException(
- s"Bernoulli naive Bayes requires 0 or 1 feature values but found $features")
- }
- }
- val prob = thetaMinusNegTheta.get.multiply(features)
- BLAS.axpy(1.0, pi, prob)
- BLAS.axpy(1.0, negThetaSum.get, prob)
- prob.argmax
+ bernoulliCalculation(features)
case _ =>
// This should never happen.
throw new UnknownError(s"Invalid modelType: ${$(modelType)}.")
}
}
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ var i = 0
+ val size = dv.size
+ val maxLog = dv.values.max
+ while (i < size) {
+ dv.values(i) = math.exp(dv.values(i) - maxLog)
+ i += 1
+ }
+ val probSum = dv.values.sum
+ i = 0
+ while (i < size) {
+ dv.values(i) = dv.values(i) / probSum
+ i += 1
+ }
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in NaiveBayesModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
+ }
+
override def copy(extra: ParamMap): NaiveBayesModel = {
copyValues(new NaiveBayesModel(uid, pi, theta).setParent(this.parent), extra)
}