diff options
author | Feynman Liang <fliang@databricks.com> | 2015-08-19 11:35:05 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-08-19 11:35:05 -0700 |
commit | 28a98464ea65aa7b35e24fca5ddaa60c2e5d53ee (patch) | |
tree | 35c507135016a31b2157b15cc1de9aa67212dfe3 /mllib | |
parent | 5fd53c64bb01de74ae57a7068de85b34adc856cf (diff) | |
download | spark-28a98464ea65aa7b35e24fca5ddaa60c2e5d53ee.tar.gz spark-28a98464ea65aa7b35e24fca5ddaa60c2e5d53ee.tar.bz2 spark-28a98464ea65aa7b35e24fca5ddaa60c2e5d53ee.zip |
[SPARK-10097] Adds `shouldMaximize` flag to `ml.evaluation.Evaluator`
Previously, users of evaluator (`CrossValidator` and `TrainValidationSplit`) would only maximize the metric in evaluator, leading to a hacky solution which negated metrics to be minimized and caused erroneous negative values to be reported to the user.
This PR adds a `isLargerBetter` attribute to the `Evaluator` base class, instructing users of `Evaluator` on whether the chosen metric should be maximized or minimized.
CC jkbradley
Author: Feynman Liang <fliang@databricks.com>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #8290 from feynmanliang/SPARK-10097.
Diffstat (limited to 'mllib')
9 files changed, 50 insertions, 20 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala index 5d5cb7e94f..56419a0a15 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala @@ -40,8 +40,11 @@ class BinaryClassificationEvaluator(override val uid: String) * param for metric name in evaluation * @group param */ - val metricName: Param[String] = new Param(this, "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)") + val metricName: Param[String] = { + val allowedParams = ParamValidators.inArray(Array("areaUnderROC", "areaUnderPR")) + new Param( + this, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", allowedParams) + } /** @group getParam */ def getMetricName: String = $(metricName) @@ -76,16 +79,17 @@ class BinaryClassificationEvaluator(override val uid: String) } val metrics = new BinaryClassificationMetrics(scoreAndLabels) val metric = $(metricName) match { - case "areaUnderROC" => - metrics.areaUnderROC() - case "areaUnderPR" => - metrics.areaUnderPR() - case other => - throw new IllegalArgumentException(s"Does not support metric $other.") + case "areaUnderROC" => metrics.areaUnderROC() + case "areaUnderPR" => metrics.areaUnderPR() } metrics.unpersist() metric } + override def isLargerBetter: Boolean = $(metricName) match { + case "areaUnderROC" => true + case "areaUnderPR" => true + } + override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala index e56c946a06..13bd3307f8 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala @@ -46,5 +46,12 @@ abstract class Evaluator extends Params { */ def evaluate(dataset: DataFrame): Double + /** + * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default) + * or minimized (false). + * A given evaluator may support multiple metrics which may be maximized or minimized. + */ + def isLargerBetter: Boolean = true + override def copy(extra: ParamMap): Evaluator } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 44f779c190..f73d234507 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -81,5 +81,13 @@ class MulticlassClassificationEvaluator (override val uid: String) metric } + override def isLargerBetter: Boolean = $(metricName) match { + case "f1" => true + case "precision" => true + case "recall" => true + case "weightedPrecision" => true + case "weightedRecall" => true + } + override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala index 01c000b475..d21c88ab9b 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala @@ -73,17 +73,20 @@ final class RegressionEvaluator(override val uid: String) } val metrics = new RegressionMetrics(predictionAndLabels) val metric = $(metricName) match { - case "rmse" => - -metrics.rootMeanSquaredError - case "mse" => - -metrics.meanSquaredError - case "r2" => - metrics.r2 - case "mae" => - -metrics.meanAbsoluteError + case "rmse" => metrics.rootMeanSquaredError + case "mse" => metrics.meanSquaredError + case "r2" => metrics.r2 + case "mae" => metrics.meanAbsoluteError } metric } + override def isLargerBetter: Boolean = $(metricName) match { + case "rmse" => false + case "mse" => false + case "r2" => true + case "mae" => false + } + override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala index 4792eb0f0a..0679bfd0f3 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala @@ -100,7 +100,9 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM } f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1) logInfo(s"Average cross-validation metrics: ${metrics.toSeq}") - val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + val (bestMetric, bestIndex) = + if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) + else metrics.zipWithIndex.minBy(_._1) logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best cross-validation metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala index c0edc730b6..73a14b8310 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala @@ -99,7 +99,9 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali validationDataset.unpersist() logInfo(s"Train validation split metrics: ${metrics.toSeq}") - val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1) + val (bestMetric, bestIndex) = + if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1) + else metrics.zipWithIndex.minBy(_._1) logInfo(s"Best set of parameters:\n${epm(bestIndex)}") logInfo(s"Best train validation split metric: $bestMetric.") val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]] diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala index 5b20378455..aa722da323 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala @@ -63,7 +63,7 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext // default = rmse val evaluator = new RegressionEvaluator() - assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001) // r2 score evaluator.setMetricName("r2") @@ -71,6 +71,6 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext // mae evaluator.setMetricName("mae") - assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001) + assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001) } } diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala index aaca08bb61..fde02e0c84 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala @@ -143,6 +143,8 @@ object CrossValidatorSuite { throw new UnsupportedOperationException } + override def isLargerBetter: Boolean = true + override val uid: String = "eval" override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index c8e58f216c..ef24e6fb6b 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -132,6 +132,8 @@ object TrainValidationSplitSuite { throw new UnsupportedOperationException } + override def isLargerBetter: Boolean = true + override val uid: String = "eval" override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra) |