aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala9
1 files changed, 7 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index daaa174a08..b6b25ecd01 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -73,10 +73,15 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.4.0")
override def evaluate(dataset: DataFrame): Double = {
val schema = dataset.schema
+ val predictionColName = $(predictionCol)
val predictionType = schema($(predictionCol)).dataType
- require(predictionType == FloatType || predictionType == DoubleType)
+ require(predictionType == FloatType || predictionType == DoubleType,
+ s"Prediction column $predictionColName must be of type float or double, " +
+ s" but not $predictionType")
+ val labelColName = $(labelCol)
val labelType = schema($(labelCol)).dataType
- require(labelType == FloatType || labelType == DoubleType)
+ require(labelType == FloatType || labelType == DoubleType,
+ s"Label column $labelColName must be of type float or double, but not $labelType")
val predictionAndLabels = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))