aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org
diff options
context:
space:
mode:
authorDominik Dahlem <dominik.dahlem@gmail.combination>2015-11-02 16:11:42 -0800
committerXiangrui Meng <meng@databricks.com>2015-11-02 16:11:42 -0800
commitec03866a7ef2d0826520755d47c8c9480148a76c (patch)
tree3179499967f7916522d7869b8ec9885163c2620e /mllib/src/main/scala/org
parentecfb3e73fd0a99f0be96034710974e78b6f9d624 (diff)
downloadspark-ec03866a7ef2d0826520755d47c8c9480148a76c.tar.gz
spark-ec03866a7ef2d0826520755d47c8c9480148a76c.tar.bz2
spark-ec03866a7ef2d0826520755d47c8c9480148a76c.zip
[SPARK-11343][ML] Allow float and double prediction/label columns in RegressionEvaluator
mengxr, felixcheung This pull request just relaxes the type of the prediction/label columns to be float and double. Internally, these columns are casted to double. The other evaluators might need to be changed also. Author: Dominik Dahlem <dominik.dahlem@gmail.combination> Closes #9296 from dahlem/ddahlem_regression_evaluator_double_predictions_27102015.
Diffstat (limited to 'mllib/src/main/scala/org')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala12
1 files changed, 8 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 3fd34d8571..ba012f444d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -23,7 +23,8 @@ import org.apache.spark.ml.param.shared.{HasLabelCol, HasPredictionCol}
import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
import org.apache.spark.mllib.evaluation.RegressionMetrics
import org.apache.spark.sql.{DataFrame, Row}
-import org.apache.spark.sql.types.DoubleType
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.{DoubleType, FloatType}
/**
* :: Experimental ::
@@ -72,10 +73,13 @@ final class RegressionEvaluator @Since("1.4.0") (@Since("1.4.0") override val ui
@Since("1.4.0")
override def evaluate(dataset: DataFrame): Double = {
val schema = dataset.schema
- SchemaUtils.checkColumnType(schema, $(predictionCol), DoubleType)
- SchemaUtils.checkColumnType(schema, $(labelCol), DoubleType)
+ val predictionType = schema($(predictionCol)).dataType
+ require(predictionType == FloatType || predictionType == DoubleType)
+ val labelType = schema($(labelCol)).dataType
+ require(labelType == FloatType || labelType == DoubleType)
- val predictionAndLabels = dataset.select($(predictionCol), $(labelCol))
+ val predictionAndLabels = dataset
+ .select(col($(predictionCol)).cast(DoubleType), col($(labelCol)).cast(DoubleType))
.map { case Row(prediction: Double, label: Double) =>
(prediction, label)
}