aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2015-06-20 13:01:59 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-06-20 13:01:59 -0700
commit0b8995168f02bb55afb0a5b7dbdb941c3c89cb4c (patch)
tree64a27502be793519bed306017f558f1a3fb15044 /mllib
parent1b6fe9b1a70aa3f81448c2705ea3a4b501cbda9d (diff)
downloadspark-0b8995168f02bb55afb0a5b7dbdb941c3c89cb4c.tar.gz
spark-0b8995168f02bb55afb0a5b7dbdb941c3c89cb4c.tar.bz2
spark-0b8995168f02bb55afb0a5b7dbdb941c3c89cb4c.zip
[SPARK-8468] [ML] Take the negative of some metrics in RegressionEvaluator to get correct cross validation
JIRA: https://issues.apache.org/jira/browse/SPARK-8468 Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #6905 from viirya/cv_min and squashes the following commits: 930d3db [Liang-Chi Hsieh] Fix python unit test and add document. d632135 [Liang-Chi Hsieh] Merge remote-tracking branch 'upstream/master' into cv_min 16e3b2c [Liang-Chi Hsieh] Take the negative instead of reciprocal. c3dd8d9 [Liang-Chi Hsieh] For comments. b5f52c1 [Liang-Chi Hsieh] Add param to CrossValidator for choosing whether to maximize evaulation value.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala10
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/params.scala2
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala4
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala35
4 files changed, 43 insertions, 8 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 8670e9679d..01c000b475 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -37,6 +37,10 @@ final class RegressionEvaluator(override val uid: String)
/**
* param for metric name in evaluation (supports `"rmse"` (default), `"mse"`, `"r2"`, and `"mae"`)
+ *
+ * Because we will maximize evaluation value (ref: `CrossValidator`),
+ * when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`),
+ * we take and output the negative of this metric.
* @group param
*/
val metricName: Param[String] = {
@@ -70,13 +74,13 @@ final class RegressionEvaluator(override val uid: String)
val metrics = new RegressionMetrics(predictionAndLabels)
val metric = $(metricName) match {
case "rmse" =>
- metrics.rootMeanSquaredError
+ -metrics.rootMeanSquaredError
case "mse" =>
- metrics.meanSquaredError
+ -metrics.meanSquaredError
case "r2" =>
metrics.r2
case "mae" =>
- metrics.meanAbsoluteError
+ -metrics.meanAbsoluteError
}
metric
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
index 15ebad8838..50c0d85506 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/params.scala
@@ -297,7 +297,7 @@ class DoubleArrayParam(parent: Params, name: String, doc: String, isValid: Array
/**
* :: Experimental ::
- * A param amd its value.
+ * A param and its value.
*/
@Experimental
case class ParamPair[T](param: Param[T], value: T) {
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index aa722da323..5b20378455 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -63,7 +63,7 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
// default = rmse
val evaluator = new RegressionEvaluator()
- assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001)
+ assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001)
// r2 score
evaluator.setMetricName("r2")
@@ -71,6 +71,6 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
// mae
evaluator.setMetricName("mae")
- assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
+ assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index 36af4b34a9..db64511a76 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -20,11 +20,12 @@ package org.apache.spark.ml.tuning
import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.classification.LogisticRegression
-import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator}
+import org.apache.spark.ml.evaluation.{BinaryClassificationEvaluator, Evaluator, RegressionEvaluator}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.param.shared.HasInputCol
+import org.apache.spark.ml.regression.LinearRegression
import org.apache.spark.mllib.classification.LogisticRegressionSuite.generateLogisticInput
-import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.{LinearDataGenerator, MLlibTestSparkContext}
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.sql.types.StructType
@@ -58,6 +59,36 @@ class CrossValidatorSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(cvModel.avgMetrics.length === lrParamMaps.length)
}
+ test("cross validation with linear regression") {
+ val dataset = sqlContext.createDataFrame(
+ sc.parallelize(LinearDataGenerator.generateLinearInput(
+ 6.3, Array(4.7, 7.2), Array(0.9, -1.3), Array(0.7, 1.2), 100, 42, 0.1), 2))
+
+ val trainer = new LinearRegression
+ val lrParamMaps = new ParamGridBuilder()
+ .addGrid(trainer.regParam, Array(1000.0, 0.001))
+ .addGrid(trainer.maxIter, Array(0, 10))
+ .build()
+ val eval = new RegressionEvaluator()
+ val cv = new CrossValidator()
+ .setEstimator(trainer)
+ .setEstimatorParamMaps(lrParamMaps)
+ .setEvaluator(eval)
+ .setNumFolds(3)
+ val cvModel = cv.fit(dataset)
+ val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
+ assert(parent.getRegParam === 0.001)
+ assert(parent.getMaxIter === 10)
+ assert(cvModel.avgMetrics.length === lrParamMaps.length)
+
+ eval.setMetricName("r2")
+ val cvModel2 = cv.fit(dataset)
+ val parent2 = cvModel2.bestModel.parent.asInstanceOf[LinearRegression]
+ assert(parent2.getRegParam === 0.001)
+ assert(parent2.getMaxIter === 10)
+ assert(cvModel2.avgMetrics.length === lrParamMaps.length)
+ }
+
test("validateParams should check estimatorParamMaps") {
import CrossValidatorSuite._