author     leahmcguire <lmcguire@salesforce.com>        2015-05-13 14:13:19 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2015-05-13 14:13:19 -0700
commit     61e05fc58e1245de871c409b60951745b5db3420
tree       5265a0e3646ca4d5d5194a3c3144c16da911b6e8
parent     5db18ba6e1bd8c6307c41549176c53590cf344a0
[SPARK-7545] [MLLIB] Added check in Bernoulli Naive Bayes to make sure that both training and prediction features have values of 0 or 1
Author: leahmcguire <lmcguire@salesforce.com>

Closes #6073 from leahmcguire/binaryCheckNB and squashes the following commits:

b8442c2 [leahmcguire] changed to if else for value checks
911bf83 [leahmcguire] undid reformat
4eedf1e [leahmcguire] moved bernoulli check
9ee9e84 [leahmcguire] fixed style error
3f3b32c [leahmcguire] fixed zero one check so only called in combiner
831fd27 [leahmcguire] got test working
f44bb3c [leahmcguire] removed changes from CV branch
67253f0 [leahmcguire] added check to bernoulli to ensure feature values are zero or one
f191c71 [leahmcguire] fixed name
58d060b [leahmcguire] changed param name and test according to comments
04f0d3c [leahmcguire] Added stats from cross validation as a val in the cross validation model to save them for user access
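In user terms, the change makes the "Bernoulli" model type fail fast with a SparkException whenever a training or prediction feature value is not 0 or 1, instead of silently computing with invalid data. A minimal sketch of how the check surfaces, assuming `sc` is an existing SparkContext (e.g. in spark-shell); the API calls mirror those in the test below, and the example vectors are hypothetical:

import org.apache.spark.SparkException
import org.apache.spark.mllib.classification.NaiveBayes
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint

// A feature value of 2.0 is invalid for Bernoulli Naive Bayes.
val train = Seq(
  LabeledPoint(1.0, Vectors.dense(1.0, 0.0)),
  LabeledPoint(0.0, Vectors.dense(0.0, 2.0)))

try {
  NaiveBayes.train(sc.makeRDD(train, 2), 1.0, "Bernoulli")
} catch {
  case e: SparkException =>
    // e.g. "Bernoulli Naive Bayes requires 0 or 1 feature values but found ..."
    println(e.getMessage)
}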
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala      | 28
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala | 33
2 files changed, 58 insertions(+), 3 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
index c9b3ff0172..b381dc2cb0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/classification/NaiveBayes.scala
@@ -87,12 +87,17 @@ class NaiveBayesModel private[mllib] (
   }
 
   override def predict(testData: Vector): Double = {
+    val brzData = testData.toBreeze
     modelType match {
       case "Multinomial" =>
-        labels (brzArgmax (brzPi + brzTheta * testData.toBreeze) )
+        labels (brzArgmax (brzPi + brzTheta * brzData) )
       case "Bernoulli" =>
+        if (!brzData.forall(v => v == 0.0 || v == 1.0)) {
+          throw new SparkException(
+            s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $testData.")
+        }
         labels (brzArgmax (brzPi +
-          (brzTheta - brzNegTheta.get) * testData.toBreeze + brzNegThetaSum.get))
+          (brzTheta - brzNegTheta.get) * brzData + brzNegThetaSum.get))
       case _ =>
         // This should never happen.
         throw new UnknownError(s"NaiveBayesModel was created with an unknown ModelType: $modelType")
@@ -293,12 +298,29 @@ class NaiveBayes private (
       }
     }
 
+    val requireZeroOneBernoulliValues: Vector => Unit = (v: Vector) => {
+      val values = v match {
+        case SparseVector(size, indices, values) =>
+          values
+        case DenseVector(values) =>
+          values
+      }
+      if (!values.forall(v => v == 0.0 || v == 1.0)) {
+        throw new SparkException(
+          s"Bernoulli Naive Bayes requires 0 or 1 feature values but found $v.")
+      }
+    }
+
     // Aggregates term frequencies per label.
     // TODO: Calling combineByKey and collect creates two stages, we can implement something
     // TODO: similar to reduceByKeyLocally to save one stage.
     val aggregated = data.map(p => (p.label, p.features)).combineByKey[(Long, BDV[Double])](
       createCombiner = (v: Vector) => {
-        requireNonnegativeValues(v)
+        if (modelType == "Bernoulli") {
+          requireZeroOneBernoulliValues(v)
+        } else {
+          requireNonnegativeValues(v)
+        }
         (1L, v.toBreeze.toDenseVector)
       },
       mergeValue = (c: (Long, BDV[Double]), v: Vector) => {
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
index ea89b17b7c..40a79a1f19 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/classification/NaiveBayesSuite.scala
@@ -208,6 +208,39 @@ class NaiveBayesSuite extends FunSuite with MLlibTestSparkContext {
     }
   }
 
+  test("detect non zero or one values in Bernoulli") {
+    val badTrain = Seq(
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(2.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(1.0, Vectors.dense(0.0)))
+
+    intercept[SparkException] {
+      NaiveBayes.train(sc.makeRDD(badTrain, 2), 1.0, "Bernoulli")
+    }
+
+    val okTrain = Seq(
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(0.0, Vectors.dense(0.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0)),
+      LabeledPoint(1.0, Vectors.dense(1.0))
+    )
+
+    val badPredict = Seq(
+      Vectors.dense(1.0),
+      Vectors.dense(2.0),
+      Vectors.dense(1.0),
+      Vectors.dense(0.0))
+
+    val model = NaiveBayes.train(sc.makeRDD(okTrain, 2), 1.0, "Bernoulli")
+    intercept[SparkException] {
+      model.predict(sc.makeRDD(badPredict, 2)).collect()
+    }
+  }
+
   test("model save/load: 2.0 to 2.0") {
     val tempDir = Utils.createTempDir()
     val path = tempDir.toURI.toString