diff options
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala | 28 |
1 files changed, 25 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala index c9b3ff0172..b381dc2cb0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala @@ -87,12 +87,17 @@ class NaiveBayesModel private[mllib] ( } override def predict(testData: Vector): Double = { + val brzData = testData.toBreeze modelType match { case "Multinomial" => - labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) ) + labels (brzArgmax (brzPi + brzTheta * brzData) ) case "Bernoulli" => + if (!brzData.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException( + s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.") + } labels (brzArgmax (brzPi + - (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get)) + (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get)) case _ => // This should never happen. throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType") @@ -293,12 +298,29 @@ class NaiveBayes private ( } } + val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => { + val values = v match { + case SparseVector(size, indices, values) => + values + case DenseVector(values) => + values + } + if (!values.forall(v => v == 0.0 || v == 1.0)) { + throw new SparkException( + s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.") + } + } + // Aggregates term frequencies per label. // TODO: Calling combineByKey and collect creates two stages, we can implement something // TODO: similar to reduceByKeyLocally to save one stage. val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])]( createCombiner = (v: Vector) => { - requireNonnegativeValues(v) + if (modelType == "Bernoulli") { + requireZeroOneBernoulliValues(v) + } else { + requireNonnegativeValues(v) + } (1L, v.toBreeze.toDenseVector) }, mergeValue = (c: (Long, BDV[Double]), v: Vector) => { |