about | summary | refs | log | tree | commit | diff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala | 69
1 files changed, 67 insertions, 2 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 597428d036..e934e5ea42 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -22,10 +22,10 @@ import scala.util.Random
import breeze.linalg.{DenseVector => BDV, Vector => BV}
import breeze.stats.distributions.{Multinomial => BrzMultinomial}
-import org.apache.spark.SparkFunSuite
+import org.apache.spark.{SparkException, SparkFunSuite}
import org.apache.spark.ml.classification.NaiveBayes.{Bernoulli, Multinomial}
import org.apache.spark.ml.classification.NaiveBayesSuite._
-import org.apache.spark.ml.feature.{Instance, LabeledPoint}
+import org.apache.spark.ml.feature.LabeledPoint
import org.apache.spark.ml.linalg._
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
@@ -106,6 +106,11 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
}
}
+ test("model types") { // Pins the string values of the Multinomial / Bernoulli constants imported from NaiveBayes.
+ assert(Multinomial === "multinomial") // a change to either literal fails this test
+ assert(Bernoulli === "bernoulli")
+ }
+
test("params") {
ParamsSuite.checkParams(new NaiveBayes)
val model = new NaiveBayesModel("nb", pi = Vectors.dense(Array(0.2, 0.8)),
@@ -228,6 +233,66 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
validateProbabilities(featureAndProbabilities, model, "bernoulli")
}
+ test("detect negative values") { // fit() with default params must throw SparkException on negative or NaN features.
+ val dense = spark.createDataFrame(Seq( // case 1: dense vectors
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(-1.0)), // the single negative feature value
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0))))
+ intercept[SparkException] { // the test fails if fit completes without throwing
+ new NaiveBayes().fit(dense) // NOTE(review): model type is left at its default -- presumably multinomial; confirm
+ }
+ val sparse = spark.createDataFrame(Seq( // case 2: same values as the dense case, sparse encoding
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(-1.0))), // negative value stored explicitly
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty)))) // size-1 sparse vector with no stored entries
+ intercept[SparkException] {
+ new NaiveBayes().fit(sparse)
+ }
+ val nan = spark.createDataFrame(Seq( // case 3: NaN instead of a negative value (also covered here despite the test name)
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(0.0, Vectors.sparse(1, Array(0), Array(Double.NaN))), // the NaN feature value
+ LabeledPoint(1.0, Vectors.sparse(1, Array(0), Array(1.0))),
+ LabeledPoint(1.0, Vectors.sparse(1, Array.empty, Array.empty))))
+ intercept[SparkException] {
+ new NaiveBayes().fit(nan)
+ }
+ }
+
+ test("detect non zero or one values in Bernoulli") { // Bernoulli model: a feature other than 0/1 must raise SparkException in fit and in transform.
+ val badTrain = spark.createDataFrame(Seq( // training data containing one invalid feature
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(2.0)), // 2.0 is the offending value
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0))))
+
+ intercept[SparkException] { // training-time check
+ new NaiveBayes().setModelType(Bernoulli).setSmoothing(1.0).fit(badTrain)
+ }
+
+ val okTrain = spark.createDataFrame(Seq( // only 0.0 / 1.0 features; the fit below is expected to succeed
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(0.0, Vectors.dense(0.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(1.0))))
+
+ val model = new NaiveBayes().setModelType(Bernoulli).setSmoothing(1.0).fit(okTrain) // not wrapped in intercept: any exception here fails the test
+
+ val badPredict = spark.createDataFrame(Seq( // prediction input containing one invalid feature
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(2.0)), // 2.0 is the offending value
+ LabeledPoint(1.0, Vectors.dense(1.0)),
+ LabeledPoint(1.0, Vectors.dense(0.0))))
+
+ intercept[SparkException] { // prediction-time check
+ model.transform(badPredict).collect() // NOTE(review): collect() presumably forces evaluation of a lazy transform -- confirm
+ }
+ }
+
test("read/write") {
def checkModelData(model: NaiveBayesModel, model2: NaiveBayesModel): Unit = {
assert(model.pi === model2.pi)