authorFeynman Liang <fliang@databricks.com>2015-08-19 11:35:05 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-08-19 11:35:05 -0700
commit28a98464ea65aa7b35e24fca5ddaa60c2e5d53ee (patch)
tree35c507135016a31b2157b15cc1de9aa67212dfe3 /mllib
parent5fd53c64bb01de74ae57a7068de85b34adc856cf (diff)
[SPARK-10097] Adds `shouldMaximize` flag to `ml.evaluation.Evaluator`
Previously, callers of an evaluator (`CrossValidator` and `TrainValidationSplit`) always maximized the metric, which forced a hacky workaround: metrics that should be minimized were negated, so erroneous negative values were reported to the user.

This PR adds an `isLargerBetter` attribute to the `Evaluator` base class, telling callers of `Evaluator` whether the chosen metric should be maximized or minimized.

CC jkbradley

Author: Feynman Liang <fliang@databricks.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #8290 from feynmanliang/SPARK-10097.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala    | 20
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala                        |  7
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala|  8
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala              | 19
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala                       |  4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala                 |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala         |  4
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala                  |  2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala            |  2
9 files changed, 50 insertions, 20 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
index 5d5cb7e94f..56419a0a15 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/BinaryClassificationEvaluator.scala
@@ -40,8 +40,11 @@ class BinaryClassificationEvaluator(override val uid: String)
* param for metric name in evaluation
* @group param
*/
- val metricName: Param[String] = new Param(this, "metricName",
- "metric name in evaluation (areaUnderROC|areaUnderPR)")
+ val metricName: Param[String] = {
+ val allowedParams = ParamValidators.inArray(Array("areaUnderROC", "areaUnderPR"))
+ new Param(
+ this, "metricName", "metric name in evaluation (areaUnderROC|areaUnderPR)", allowedParams)
+ }
/** @group getParam */
def getMetricName: String = $(metricName)
@@ -76,16 +79,17 @@ class BinaryClassificationEvaluator(override val uid: String)
}
val metrics = new BinaryClassificationMetrics(scoreAndLabels)
val metric = $(metricName) match {
- case "areaUnderROC" =>
- metrics.areaUnderROC()
- case "areaUnderPR" =>
- metrics.areaUnderPR()
- case other =>
- throw new IllegalArgumentException(s"Does not support metric $other.")
+ case "areaUnderROC" => metrics.areaUnderROC()
+ case "areaUnderPR" => metrics.areaUnderPR()
}
metrics.unpersist()
metric
}
+ override def isLargerBetter: Boolean = $(metricName) match {
+ case "areaUnderROC" => true
+ case "areaUnderPR" => true
+ }
+
override def copy(extra: ParamMap): BinaryClassificationEvaluator = defaultCopy(extra)
}
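With the `ParamValidators.inArray` guard above, an unsupported metric name is rejected when the param is set (param values are validated as the `ParamPair` is built), so the `evaluate()` match no longer needs a catch-all case. A minimal sketch, assuming a DataFrame `predictions` with the default "rawPrediction" and "label" columns (variable names are illustrative):

    import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator

    val evaluator = new BinaryClassificationEvaluator()
      .setMetricName("areaUnderPR") // checked against the allowed array here
    val aupr = evaluator.evaluate(predictions)

    // A typo now fails fast at set time instead of deep inside evaluate():
    // evaluator.setMetricName("areaUnderFoo") // IllegalArgumentException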
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
index e56c946a06..13bd3307f8 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/Evaluator.scala
@@ -46,5 +46,12 @@ abstract class Evaluator extends Params {
*/
def evaluate(dataset: DataFrame): Double
+ /**
+ * Indicates whether the metric returned by [[evaluate()]] should be maximized (true, default)
+ * or minimized (false).
+ * A given evaluator may support multiple metrics which may be maximized or minimized.
+ */
+ def isLargerBetter: Boolean = true
+
override def copy(extra: ParamMap): Evaluator
}
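Custom evaluators that wrap a loss can now opt out of the maximizing default. A sketch of a hypothetical minimizing evaluator (the class name and metric are illustrative, not part of this patch):

    import org.apache.spark.ml.evaluation.Evaluator
    import org.apache.spark.ml.param.ParamMap
    import org.apache.spark.ml.util.Identifiable
    import org.apache.spark.sql.DataFrame

    class HingeLossEvaluator(override val uid: String) extends Evaluator {

      def this() = this(Identifiable.randomUID("hingeLossEval"))

      override def evaluate(dataset: DataFrame): Double = {
        // ... compute mean hinge loss over "prediction"/"label" columns ...
        0.0
      }

      // A loss: smaller is better, so opt out of the maximizing default.
      override def isLargerBetter: Boolean = false

      override def copy(extra: ParamMap): HingeLossEvaluator = defaultCopy(extra)
    }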
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 44f779c190..f73d234507 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -81,5 +81,13 @@ class MulticlassClassificationEvaluator (override val uid: String)
metric
}
+ override def isLargerBetter: Boolean = $(metricName) match {
+ case "f1" => true
+ case "precision" => true
+ case "recall" => true
+ case "weightedPrecision" => true
+ case "weightedRecall" => true
+ }
+
override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
}
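All five metrics here are better when larger, so the override agrees with the base-class default; the exhaustive match documents that and fails loudly if a metric is ever added without a matching case. A quick check, with illustrative names:

    import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator

    val mcEval = new MulticlassClassificationEvaluator().setMetricName("weightedRecall")
    assert(mcEval.isLargerBetter) // true for every supported multiclass metric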
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
index 01c000b475..d21c88ab9b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/RegressionEvaluator.scala
@@ -73,17 +73,20 @@ final class RegressionEvaluator(override val uid: String)
}
val metrics = new RegressionMetrics(predictionAndLabels)
val metric = $(metricName) match {
- case "rmse" =>
- -metrics.rootMeanSquaredError
- case "mse" =>
- -metrics.meanSquaredError
- case "r2" =>
- metrics.r2
- case "mae" =>
- -metrics.meanAbsoluteError
+ case "rmse" => metrics.rootMeanSquaredError
+ case "mse" => metrics.meanSquaredError
+ case "r2" => metrics.r2
+ case "mae" => metrics.meanAbsoluteError
}
metric
}
+ override def isLargerBetter: Boolean = $(metricName) match {
+ case "rmse" => false
+ case "mse" => false
+ case "r2" => true
+ case "mae" => false
+ }
+
override def copy(extra: ParamMap): RegressionEvaluator = defaultCopy(extra)
}
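After this change, error metrics are reported with their natural (positive) sign, and callers consult `isLargerBetter` for the direction. A sketch, assuming a `predictions` DataFrame with "prediction" and "label" columns:

    import org.apache.spark.ml.evaluation.RegressionEvaluator

    val rmseEval = new RegressionEvaluator().setMetricName("rmse")
    val rmse = rmseEval.evaluate(predictions) // now e.g. 0.102 rather than -0.102
    assert(!rmseEval.isLargerBetter)          // tuning code should minimize this value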
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 4792eb0f0a..0679bfd0f3 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -100,7 +100,9 @@ class CrossValidator(override val uid: String) extends Estimator[CrossValidatorM
}
f2jBLAS.dscal(numModels, 1.0 / $(numFolds), metrics, 1)
logInfo(s"Average cross-validation metrics: ${metrics.toSeq}")
- val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
+ val (bestMetric, bestIndex) =
+ if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
+ else metrics.zipWithIndex.minBy(_._1)
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best cross-validation metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
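End to end, a minimizing metric now flows through model selection without sign tricks. A sketch wiring the pieces together (the estimator choice and grid values are illustrative; assumes a `training` DataFrame):

    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.ml.regression.LinearRegression
    import org.apache.spark.ml.tuning.{CrossValidator, ParamGridBuilder}

    val lr = new LinearRegression()
    val grid = new ParamGridBuilder()
      .addGrid(lr.regParam, Array(0.01, 0.1, 1.0))
      .build()
    val cv = new CrossValidator()
      .setEstimator(lr)
      .setEstimatorParamMaps(grid)
      .setEvaluator(new RegressionEvaluator().setMetricName("rmse")) // isLargerBetter = false
      .setNumFolds(3)
    val cvModel = cv.fit(training) // picks the params with the *lowest* rmse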
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index c0edc730b6..73a14b8310 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -99,7 +99,9 @@ class TrainValidationSplit(override val uid: String) extends Estimator[TrainVali
validationDataset.unpersist()
logInfo(s"Train validation split metrics: ${metrics.toSeq}")
- val (bestMetric, bestIndex) = metrics.zipWithIndex.maxBy(_._1)
+ val (bestMetric, bestIndex) =
+ if (eval.isLargerBetter) metrics.zipWithIndex.maxBy(_._1)
+ else metrics.zipWithIndex.minBy(_._1)
logInfo(s"Best set of parameters:\n${epm(bestIndex)}")
logInfo(s"Best train validation split metric: $bestMetric.")
val bestModel = est.fit(dataset, epm(bestIndex)).asInstanceOf[Model[_]]
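`TrainValidationSplit` gets the identical fix, so swapping it in for `CrossValidator` is a small change to the sketch above (reusing `lr`, `grid`, and `training` from that sketch):

    import org.apache.spark.ml.evaluation.RegressionEvaluator
    import org.apache.spark.ml.tuning.TrainValidationSplit

    val tvs = new TrainValidationSplit()
      .setEstimator(lr)
      .setEstimatorParamMaps(grid)
      .setEvaluator(new RegressionEvaluator().setMetricName("rmse"))
      .setTrainRatio(0.75) // a single split instead of k folds
    val tvsModel = tvs.fit(training) // likewise minimizes rmse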
diff --git a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
index 5b20378455..aa722da323 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/evaluation/RegressionEvaluatorSuite.scala
@@ -63,7 +63,7 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
// default = rmse
val evaluator = new RegressionEvaluator()
- assert(evaluator.evaluate(predictions) ~== -0.1019382 absTol 0.001)
+ assert(evaluator.evaluate(predictions) ~== 0.1019382 absTol 0.001)
// r2 score
evaluator.setMetricName("r2")
@@ -71,6 +71,6 @@ class RegressionEvaluatorSuite extends SparkFunSuite with MLlibTestSparkContext
// mae
evaluator.setMetricName("mae")
- assert(evaluator.evaluate(predictions) ~== -0.08036075 absTol 0.001)
+ assert(evaluator.evaluate(predictions) ~== 0.08036075 absTol 0.001)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
index aaca08bb61..fde02e0c84 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/CrossValidatorSuite.scala
@@ -143,6 +143,8 @@ object CrossValidatorSuite {
throw new UnsupportedOperationException
}
+ override def isLargerBetter: Boolean = true
+
override val uid: String = "eval"
override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index c8e58f216c..ef24e6fb6b 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -132,6 +132,8 @@ object TrainValidationSplitSuite {
throw new UnsupportedOperationException
}
+ override def isLargerBetter: Boolean = true
+
override val uid: String = "eval"
override def copy(extra: ParamMap): MyEvaluator = defaultCopy(extra)