diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala | 18 |
1 files changed, 13 insertions, 5 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala index 082a6bcd21..80a46fc70c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/NaiveBayesSuite.scala @@ -21,17 +21,17 @@ import breeze.linalg.{Vector => BV} import org.apache.spark.SparkFunSuite import org.apache.spark.ml.param.ParamsSuite -import org.apache.spark.ml.util.DefaultReadWriteTest +import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils} import org.apache.spark.mllib.classification.NaiveBayes.{Bernoulli, Multinomial} import org.apache.spark.mllib.classification.NaiveBayesSuite._ import org.apache.spark.mllib.linalg._ import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ override def beforeAll(): Unit = { super.beforeAll() @@ -86,7 +86,7 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa model: NaiveBayesModel, modelType: String): Unit = { featureAndProbabilities.collect().foreach { - case Row(features: Vector, probability: Vector) => { + case Row(features: Vector, probability: Vector) => assert(probability.toArray.sum ~== 1.0 relTol 1.0e-10) val expected = modelType match { case Multinomial => @@ -97,7 +97,6 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa throw new UnknownError(s"Invalid modelType: $modelType.") } assert(probability ~== expected relTol 1.0e-10) - } } } @@ -185,6 +184,15 @@ class NaiveBayesSuite extends SparkFunSuite with MLlibTestSparkContext with Defa val nb = new NaiveBayes() testEstimatorAndModelReadWrite(nb, dataset, NaiveBayesSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val nb = new NaiveBayes() + MLTestingUtils.checkNumericTypes[NaiveBayesModel, NaiveBayes]( + nb, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.pi === actual.pi) + assert(expected.theta === actual.theta) + } + } } object NaiveBayesSuite { |