about summary refs log tree commit diff
path: root/mllib/src/main
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 28
1 files changed, 25 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index c9b3ff0172..b381dc2cb0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -87,12 +87,17 @@ class NaiveBayesModel private[mllib] (
}
override def predict(testData: Vector): Double = {
+ val brzData = testData.toBreeze
modelType match {
case "Multinomial" =>
- labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
+ labels (brzArgmax (brzPi + brzTheta * brzData) )
case "Bernoulli" =>
+ if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
+ throw new SparkException(
+ s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+ }
labels (brzArgmax (brzPi +
- (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
+ (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
case _ =>
// This should never happen.
throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
@@ -293,12 +298,29 @@ class NaiveBayes private (
}
}
+ val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
+ val values = v match {
+ case SparseVector(size, indices, values) =>
+ values
+ case DenseVector(values) =>
+ values
+ }
+ if (!values.forall(v => v == 0.0 || v == 1.0)) {
+ throw new SparkException(
+ s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
+ }
+ }
+
// Aggregates term frequencies per label.
// TODO: Calling combineByKey and collect creates two stages, we can implement something
// TODO: similar to reduceByKeyLocally to save one stage.
val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
createCombiner = (v: Vector) => {
- requireNonnegativeValues(v)
+ if (modelType == "Bernoulli") {
+ requireZeroOneBernoulliValues(v)
+ } else {
+ requireNonnegativeValues(v)
+ }
(1L, v.toBreeze.toDenseVector)
},
mergeValue = (c: (Long, BDV[Double]), v: Vector) => {