aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorDominik Dahlem <dominik.dahlem@gmail.com>2015-12-08 18:54:10 -0800
committerJoseph K. Bradley <joseph@databricks.com>2015-12-08 18:54:10 -0800
commita0046e379bee0852c39ece4ea719cde70d350b0e (patch)
tree098dcfda07247f00245451292ef0f6961a95555e /mllib
parent765c67f5f2e0b1367e37883f662d313661e3a0d9 (diff)
downloadspark-a0046e379bee0852c39ece4ea719cde70d350b0e.tar.gz
spark-a0046e379bee0852c39ece4ea719cde70d350b0e.tar.bz2
spark-a0046e379bee0852c39ece4ea719cde70d350b0e.zip
[SPARK-11343][ML] Documentation of float and double prediction/label columns in RegressionEvaluator
felixcheung , mengxr Just added a message to require() Author: Dominik Dahlem <dominik.dahlem@gmail.combination> Closes #9598 from dahlem/ddahlem_regression_evaluator_double_predictions_message_04112015.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala9
1 files changed, 7 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index daaa174a08..b6b25ecd01 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -73,10 +73,15 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.4.0")
override def evaluate(dataset: DataFrame): Double = {
val schema = dataset.schema
+ val predictionColName = $(predictionCol)
val predictionType = schema($(predictionCol)).dataType
- require(predictionType == FloatType || predictionType == DoubleType)
+ require(predictionType == FloatType || predictionType == DoubleType,
+ s"Prediction column $predictionColName must be of type float or double, " +
+ s" but not $predictionType")
+ val labelColName = $(labelCol)
val labelType = schema($(labelCol)).dataType
- require(labelType == FloatType || labelType == DoubleType)
+ require(labelType == FloatType || labelType == DoubleType,
+ s"Label column $labelColName must be of type float or double, but not $labelType")
val predictionAndLabels = dataset
.select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))