aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala25
1 files changed, 25 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 937aa7d3c2..ac1ef5feb9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.random._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.FloatType
class GeneralizedLinearRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -1067,6 +1068,30 @@ class GeneralizedLinearRegressionSuite
idx += 1
}
}
+
+ test("evaluate with labels that are not doubles") {
+ // Evaulate with a dataset that contains Labels not as doubles to verify correct casting
+ val dataset = Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)),
+ Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)),
+ Instance(29.0, 1.0, Vectors.dense(3.0, 13.0))
+ ).toDF()
+
+ val trainer = new GeneralizedLinearRegression()
+ .setMaxIter(1)
+ val model = trainer.fit(dataset)
+ assert(model.hasSummary)
+ val summary = model.summary
+
+ val longLabelDataset = dataset.select(col(model.getLabelCol).cast(FloatType),
+ col(model.getFeaturesCol))
+ val evalSummary = model.evaluate(longLabelDataset)
+ // The calculations below involve pattern matching with Label as a double
+ assert(evalSummary.nullDeviance === summary.nullDeviance)
+ assert(evalSummary.deviance === summary.deviance)
+ assert(evalSummary.aic === summary.aic)
+ }
}
object GeneralizedLinearRegressionSuite {