aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala15
1 files changed, 12 insertions, 3 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index afeeaf7fb5..48db428130 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -29,13 +29,13 @@ import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.mllib.util.TestingUtils._
-import org.apache.spark.sql.{DataFrame, Row}
+import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions.lit
class LogisticRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
- @transient var dataset: DataFrame = _
+ @transient var dataset: Dataset[_] = _
@transient var binaryDataset: DataFrame = _
private val eps: Double = 1e-5
@@ -103,7 +103,7 @@ class LogisticRegressionSuite
assert(model.hasSummary)
// Validate that we re-insert a probability column for evaluation
val fieldNames = model.summary.predictions.schema.fieldNames
- assert((dataset.schema.fieldNames.toSet).subsetOf(
+ assert(dataset.schema.fieldNames.toSet.subsetOf(
fieldNames.toSet))
assert(fieldNames.exists(s => s.startsWith("probability_")))
}
@@ -934,6 +934,15 @@ class LogisticRegressionSuite
testEstimatorAndModelReadWrite(lr, dataset, LogisticRegressionSuite.allParamSettings,
checkModelData)
}
+
+ test("should support all NumericType labels and not support other types") {
+ val lr = new LogisticRegression().setMaxIter(1)
+ MLTestingUtils.checkNumericTypes[LogisticRegressionModel, LogisticRegression](
+ lr, isClassification = true, sqlContext) { (expected, actual) =>
+ assert(expected.intercept === actual.intercept)
+ assert(expected.coefficients.toArray === actual.coefficients.toArray)
+ }
+ }
}
object LogisticRegressionSuite {