aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBryan Cutler <cutlerb@gmail.com>2016-09-29 16:31:30 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-09-29 16:31:30 -0700
commit2f739567080d804a942cfcca0e22f91ab7cbea36 (patch)
treead2605595a12125ee103ccf5f3481b031594a159
parent39eb3bb1ec29aa993de13a6eba3ab27db6fc5371 (diff)
downloadspark-2f739567080d804a942cfcca0e22f91ab7cbea36.tar.gz
spark-2f739567080d804a942cfcca0e22f91ab7cbea36.tar.bz2
spark-2f739567080d804a942cfcca0e22f91ab7cbea36.zip
[SPARK-17697][ML] Fixed bug in summary calculations that pattern match against label without casting
## What changes were proposed in this pull request? In calling LogisticRegression.evaluate and GeneralizedLinearRegression.evaluate using a Dataset where the Label is not of a double type, calculations pattern match against a double and throw a MatchError. This fix casts the Label column to a DoubleType to ensure there is no MatchError. ## How was this patch tested? Added unit tests to call evaluate with a dataset that has Label as other numeric types. Author: Bryan Cutler <cutlerb@gmail.com> Closes #15288 from BryanCutler/binaryLOR-numericCheck-SPARK-17697.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala11
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala18
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala25
4 files changed, 49 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
index 5ab63d1de9..329961a25d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/LogisticRegression.scala
@@ -1169,7 +1169,7 @@ class BinaryLogisticRegressionSummary private[classification] (
// TODO: Allow the user to vary the number of bins using a setBins method in
// BinaryClassificationMetrics. For now the default is set to 100.
@transient private val binaryMetrics = new BinaryClassificationMetrics(
- predictions.select(probabilityCol, labelCol).rdd.map {
+ predictions.select(col(probabilityCol), col(labelCol).cast(DoubleType)).rdd.map {
case Row(score: Vector, label: Double) => (score(1), label)
}, 100
)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 02b27fb650..bb9e150c49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -992,7 +992,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
} else {
link.unlink(0.0)
}
- predictions.select(col(model.getLabelCol), w).rdd.map {
+ predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map {
case Row(y: Double, weight: Double) =>
family.deviance(y, wtdmu, weight)
}.sum()
@@ -1004,7 +1004,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0")
lazy val deviance: Double = {
val w = weightCol
- predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
+ predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
case Row(label: Double, pred: Double, weight: Double) =>
family.deviance(label, pred, weight)
}.sum()
@@ -1030,9 +1030,10 @@ class GeneralizedLinearRegressionSummary private[regression] (
lazy val aic: Double = {
val w = weightCol
val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0)
- val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
- case Row(label: Double, pred: Double, weight: Double) =>
- (label, pred, weight)
+ val t = predictions.select(
+ col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
+ case Row(label: Double, pred: Double, weight: Double) =>
+ (label, pred, weight)
}
family.aic(t, deviance, numInstances, weightSum) + 2 * rank
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
index 8451e60144..42b56754e0 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/LogisticRegressionSuite.scala
@@ -32,7 +32,8 @@ import org.apache.spark.ml.util.{DefaultReadWriteTest, MLTestingUtils}
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{Dataset, Row}
-import org.apache.spark.sql.functions.lit
+import org.apache.spark.sql.functions.{col, lit}
+import org.apache.spark.sql.types.LongType
class LogisticRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -1776,6 +1777,21 @@ class LogisticRegressionSuite
summary.precisionByThreshold.collect() === sameSummary.precisionByThreshold.collect())
}
+ test("evaluate with labels that are not doubles") {
+ // Evaluate a test set with Label that is a numeric type other than Double
+ val lr = new LogisticRegression()
+ .setMaxIter(1)
+ .setRegParam(1.0)
+ val model = lr.fit(smallBinaryDataset)
+ val summary = model.evaluate(smallBinaryDataset).asInstanceOf[BinaryLogisticRegressionSummary]
+
+ val longLabelData = smallBinaryDataset.select(col(model.getLabelCol).cast(LongType),
+ col(model.getFeaturesCol))
+ val longSummary = model.evaluate(longLabelData).asInstanceOf[BinaryLogisticRegressionSummary]
+
+ assert(summary.areaUnderROC === longSummary.areaUnderROC)
+ }
+
test("statistics on training data") {
// Test that loss is monotonically decreasing.
val lr = new LogisticRegression()
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
index 937aa7d3c2..ac1ef5feb9 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/GeneralizedLinearRegressionSuite.scala
@@ -31,6 +31,7 @@ import org.apache.spark.mllib.random._
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.sql.{DataFrame, Row}
import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.FloatType
class GeneralizedLinearRegressionSuite
extends SparkFunSuite with MLlibTestSparkContext with DefaultReadWriteTest {
@@ -1067,6 +1068,30 @@ class GeneralizedLinearRegressionSuite
idx += 1
}
}
+
+ test("evaluate with labels that are not doubles") {
+ // Evaulate with a dataset that contains Labels not as doubles to verify correct casting
+ val dataset = Seq(
+ Instance(17.0, 1.0, Vectors.dense(0.0, 5.0).toSparse),
+ Instance(19.0, 1.0, Vectors.dense(1.0, 7.0)),
+ Instance(23.0, 1.0, Vectors.dense(2.0, 11.0)),
+ Instance(29.0, 1.0, Vectors.dense(3.0, 13.0))
+ ).toDF()
+
+ val trainer = new GeneralizedLinearRegression()
+ .setMaxIter(1)
+ val model = trainer.fit(dataset)
+ assert(model.hasSummary)
+ val summary = model.summary
+
+ val longLabelDataset = dataset.select(col(model.getLabelCol).cast(FloatType),
+ col(model.getFeaturesCol))
+ val evalSummary = model.evaluate(longLabelDataset)
+ // The calculations below involve pattern matching with Label as a double
+ assert(evalSummary.nullDeviance === summary.nullDeviance)
+ assert(evalSummary.deviance === summary.deviance)
+ assert(evalSummary.aic === summary.aic)
+ }
}
object GeneralizedLinearRegressionSuite {