aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala11
1 files changed, 6 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
index 02b27fb650..bb9e150c49 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GeneralizedLinearRegression.scala
@@ -992,7 +992,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
} else {
link.unlink(0.0)
}
- predictions.select(col(model.getLabelCol), w).rdd.map {
+ predictions.select(col(model.getLabelCol).cast(DoubleType), w).rdd.map {
case Row(y: Double, weight: Double) =>
family.deviance(y, wtdmu, weight)
}.sum()
@@ -1004,7 +1004,7 @@ class GeneralizedLinearRegressionSummary private[regression] (
@Since("2.0.0")
lazy val deviance: Double = {
val w = weightCol
- predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
+ predictions.select(col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
case Row(label: Double, pred: Double, weight: Double) =>
family.deviance(label, pred, weight)
}.sum()
@@ -1030,9 +1030,10 @@ class GeneralizedLinearRegressionSummary private[regression] (
lazy val aic: Double = {
val w = weightCol
val weightSum = predictions.select(w).agg(sum(w)).first().getDouble(0)
- val t = predictions.select(col(model.getLabelCol), col(predictionCol), w).rdd.map {
- case Row(label: Double, pred: Double, weight: Double) =>
- (label, pred, weight)
+ val t = predictions.select(
+ col(model.getLabelCol).cast(DoubleType), col(predictionCol), w).rdd.map {
+ case Row(label: Double, pred: Double, weight: Double) =>
+ (label, pred, weight)
}
family.aic(t, deviance, numInstances, weightSum) + 2 * rank
}