diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala | 15 |
1 files changed, 12 insertions, 3 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala index afeeaf7fb5..48db428130 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala @@ -29,13 +29,13 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ -import org.apache.spark.sql.{DataFrame, Row} +import org.apache.spark.sql.{DataFrame, Dataset, Row} import org.apache.spark.sql.functions.lit class LogisticRegressionSuite extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest { - @transient var dataset: DataFrame = _ + @transient var dataset: Dataset[_] = _ @transient var binaryDataset: DataFrame = _ private val eps: Double = 1e-5 @@ -103,7 +103,7 @@ class LogisticRegressionSuite assert(model.hasSummary) // Validate that we re-insert a probability column for evaluation val fieldNames = model.summary.predictions.schema.fieldNames - assert((dataset.schema.fieldNames.toSet).subsetOf( + assert(dataset.schema.fieldNames.toSet.subsetOf( fieldNames.toSet)) assert(fieldNames.exists(s => s.startsWith("probability_"))) } @@ -934,6 +934,15 @@ class LogisticRegressionSuite testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings, checkModelData) } + + test("should support all NumericType labels and not support other types") { + val lr = new LogisticRegression().setMaxIter(1) + MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression]( + lr, isClassification = true, sqlContext) { (expected, actual) => + assert(expected.intercept === actual.intercept) + assert(expected.coefficients.toArray === actual.coefficients.toArray) + } + } } object LogisticRegressionSuite { |