aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala18
1 files changed, 13 insertions, 5 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
index 082a6bcd21..80a46fc70c 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala
@@ -21,17 +21,17 @@ import breeze.linalg.{Vector => BV}
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial}
import org.apache.spark.mllib.classification.NaiveBayesSuite._
import org.apache.spark.mllib.linalg._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
override def beforeAll(): Unit = {
super.beforeAll()
@@ -86,7 +86,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
model: NaiveBayesModel,
modelType: String): Unit = {
featureAndProbabilities.collect().foreach {
- case Row(features: Vector, probability: Vector) => {
+ case Row(features: Vector, probability: Vector) =>
assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10)
val expected = modelType match {
case Multinomial =>
@@ -97,7 +97,6 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
throw new UnknownError(s"Invalid modelType: $modelType.")
}
assert(probability ~== expected relTol 1.0e-10)
- }
}
}
@@ -185,6 +184,15 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa
val nb = new NaiveBayes()
testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val nb = new NaiveBayes()
+ MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes](
+ nb, isClassification = true, sqlContext) { (expected, actual) =>
+ assert(expected.pi === actual.pi)
+ assert(expected.theta === actual.theta)
+ }
+ }
}
object NaiveBayesSuite {