aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala18
1 files changed, 17 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 8451e60144..42b56754e0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -32,7 +32,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.LongType
class LogisticRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -1776,6 +1777,21 @@ class LogisticRegressionSuite
summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
}
+ test("evaluate with labels that are not doubles") {
+ // Evaluate a test set with Label that is a numeric type other than Double
+ val lr = new LogisticRegression()
+ .setMaxIter(1)
+ .setRegParam(1.0)
+ val model = lr.fit(smallBinaryDataset)
+ val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
+
+ val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType),
+ col(model.getFeaturesCol))
+ val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary]
+
+ assert(summary.areaUnderROC === longSummary.areaUnderROC)
+ }
+
test("statistics on training data") {
// Test that loss is monotonically decreasing.
val lr = new LogisticRegression()