From bdff299f9e51b06b809fe505bda466009e759831 Mon Sep 17 00:00:00 2001 From: "wm624@hotmail.com" Date: Fri, 13 May 2016 08:29:37 +0100 Subject: [SPARK-14900][ML] spark.ml classification metrics should include accuracy ## What changes were proposed in this pull request? (Please fill in changes proposed in this fix) Add accuracy to MulticlassMetrics class and add corresponding code in MulticlassClassificationEvaluator. ## How was this patch tested? (Please explain how this patch was tested. E.g. unit tests, integration tests, manual tests) Scala Unit tests in ml.evaluation Author: wm624@hotmail.com Closes #12882 from wangmiao1981/accuracy. --- .../evaluation/MulticlassClassificationEvaluator.scala | 15 +++++---------- .../spark/mllib/evaluation/MulticlassMetrics.scala | 17 ++++++++++++++--- .../spark/mllib/evaluation/MulticlassMetricsSuite.scala | 9 +++++---- 3 files changed, 24 insertions(+), 17 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala index 3d89843a0b..8408516751 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala @@ -40,15 +40,15 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid /** * param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`, - * `"weightedPrecision"`, `"weightedRecall"`) + * `"weightedPrecision"`, `"weightedRecall"`, `"accuracy"`) * @group param */ @Since("1.5.0") val metricName: Param[String] = { val allowedParams = ParamValidators.inArray(Array("f1", "precision", - "recall", "weightedPrecision", "weightedRecall")) + "recall", "weightedPrecision", "weightedRecall", "accuracy")) new Param(this, "metricName", "metric name in evaluation " + - "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams) + "(f1|precision|recall|weightedPrecision|weightedRecall|accuracy)", allowedParams) } /** @group getParam */ @@ -86,18 +86,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid case "recall" => metrics.recall case "weightedPrecision" => metrics.weightedPrecision case "weightedRecall" => metrics.weightedRecall + case "accuracy" => metrics.accuracy } metric } @Since("1.5.0") - override def isLargerBetter: Boolean = $(metricName) match { - case "f1" => true - case "precision" => true - case "recall" => true - case "weightedPrecision" => true - case "weightedRecall" => true - } + override def isLargerBetter: Boolean = true @Since("1.5.0") override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala index 5dde2bdb17..719695a338 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala @@ -139,7 +139,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * Returns precision */ @Since("1.1.0") - lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount + @deprecated("Use accuracy.", "2.0.0") + lazy val precision: Double = accuracy /** * Returns recall @@ -148,14 +149,24 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl * of all false negatives) */ @Since("1.1.0") - lazy val recall: Double = precision + @deprecated("Use accuracy.", "2.0.0") + lazy val recall: Double = accuracy /** * Returns f-measure * (equals to precision and recall because precision equals recall) */ @Since("1.1.0") - lazy val fMeasure: Double = precision + @deprecated("Use accuracy.", "2.0.0") + lazy val fMeasure: Double = accuracy + + /** + * Returns accuracy + * (equals to the total number of correctly classified instances + * out of the total number of instances.) + */ + @Since("2.0.0") + lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount /** * Returns weighted true positive rate diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala index d55bc8c3ec..f316c67234 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala @@ -69,11 +69,12 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext { assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta) assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta) - assert(math.abs(metrics.recall - + assert(math.abs(metrics.accuracy - (2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta) - assert(math.abs(metrics.recall - metrics.precision) < delta) - assert(math.abs(metrics.recall - metrics.fMeasure) < delta) - assert(math.abs(metrics.recall - metrics.weightedRecall) < delta) + assert(math.abs(metrics.accuracy - metrics.precision) < delta) + assert(math.abs(metrics.accuracy - metrics.recall) < delta) + assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta) + assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta) assert(math.abs(metrics.weightedFalsePositiveRate - ((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta) assert(math.abs(metrics.weightedPrecision - -- cgit v1.2.3