about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
authorwm624@hotmail.com <wm624@hotmail.com>2016-05-13 08:29:37 +0100
committerSean Owen <sowen@cloudera.com>2016-05-13 08:29:37 +0100
commitbdff299f9e51b06b809fe505bda466009e759831 (patch)
tree9cbc4bede62660c76708474d1e1db1365c36fbac /mllib
parente1dc853737fc1739fbb5377ffe31fb2d89935b1f (diff)
downloadspark-bdff299f9e51b06b809fe505bda466009e759831.tar.gz
spark-bdff299f9e51b06b809fe505bda466009e759831.tar.bz2
spark-bdff299f9e51b06b809fe505bda466009e759831.zip
[SPARK-14900][ML] spark.ml classification metrics should include accuracy
## What changes were proposed in this pull request? Add accuracy to the MulticlassMetrics class and add corresponding code in MulticlassClassificationEvaluator. ## How was this patch tested? Scala unit tests in ml.evaluation. Author: wm624@hotmail.com <wm624@hotmail.com> Closes #12882 from wangmiao1981/accuracy.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala15
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala17
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala9
3 files changed, 24 insertions, 17 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
index 3d89843a0b..8408516751 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/evaluation/MulticlassClassificationEvaluator.scala
@@ -40,15 +40,15 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
/**
* param for metric name in evaluation (supports `"f1"` (default), `"precision"`, `"recall"`,
- * `"weightedPrecision"`, `"weightedRecall"`)
+ * `"weightedPrecision"`, `"weightedRecall"`, `"accuracy"`)
* @group param
*/
@Since("1.5.0")
val metricName: Param[String] = {
val allowedParams = ParamValidators.inArray(Array("f1", "precision",
- "recall", "weightedPrecision", "weightedRecall"))
+ "recall", "weightedPrecision", "weightedRecall", "accuracy"))
new Param(this, "metricName", "metric name in evaluation " +
- "(f1|precision|recall|weightedPrecision|weightedRecall)", allowedParams)
+ "(f1|precision|recall|weightedPrecision|weightedRecall|accuracy)", allowedParams)
}
/** @group getParam */
@@ -86,18 +86,13 @@ class MulticlassClassificationEvaluator @Since("1.5.0") (@Since("1.5.0") overrid
case "recall" => metrics.recall
case "weightedPrecision" => metrics.weightedPrecision
case "weightedRecall" => metrics.weightedRecall
+ case "accuracy" => metrics.accuracy
}
metric
}
@Since("1.5.0")
- override def isLargerBetter: Boolean = $(metricName) match {
- case "f1" => true
- case "precision" => true
- case "recall" => true
- case "weightedPrecision" => true
- case "weightedRecall" => true
- }
+ override def isLargerBetter: Boolean = true
@Since("1.5.0")
override def copy(extra: ParamMap): MulticlassClassificationEvaluator = defaultCopy(extra)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index 5dde2bdb17..719695a338 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -139,7 +139,8 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
* Returns precision
*/
@Since("1.1.0")
- lazy val precision: Double = tpByClass.values.sum.toDouble / labelCount
+ @deprecated("Use accuracy.", "2.0.0")
+ lazy val precision: Double = accuracy
/**
* Returns recall
@@ -148,14 +149,24 @@ class MulticlassMetrics @Since("1.1.0") (predictionAndLabels: RDD[(Double, Doubl
* of all false negatives)
*/
@Since("1.1.0")
- lazy val recall: Double = precision
+ @deprecated("Use accuracy.", "2.0.0")
+ lazy val recall: Double = accuracy
/**
* Returns f-measure
* (equals to precision and recall because precision equals recall)
*/
@Since("1.1.0")
- lazy val fMeasure: Double = precision
+ @deprecated("Use accuracy.", "2.0.0")
+ lazy val fMeasure: Double = accuracy
+
+ /**
+ * Returns accuracy
+ * (equals to the total number of correctly classified instances
+ * out of the total number of instances.)
+ */
+ @Since("2.0.0")
+ lazy val accuracy: Double = tpByClass.values.sum.toDouble / labelCount
/**
* Returns weighted true positive rate
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
index d55bc8c3ec..f316c67234 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MulticlassMetricsSuite.scala
@@ -69,11 +69,12 @@ class MulticlassMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
assert(math.abs(metrics.fMeasure(1.0, 2.0) - f2measure1) < delta)
assert(math.abs(metrics.fMeasure(2.0, 2.0) - f2measure2) < delta)
- assert(math.abs(metrics.recall -
+ assert(math.abs(metrics.accuracy -
(2.0 + 3.0 + 1.0) / ((2 + 3 + 1) + (1 + 1 + 1))) < delta)
- assert(math.abs(metrics.recall - metrics.precision) < delta)
- assert(math.abs(metrics.recall - metrics.fMeasure) < delta)
- assert(math.abs(metrics.recall - metrics.weightedRecall) < delta)
+ assert(math.abs(metrics.accuracy - metrics.precision) < delta)
+ assert(math.abs(metrics.accuracy - metrics.recall) < delta)
+ assert(math.abs(metrics.accuracy - metrics.fMeasure) < delta)
+ assert(math.abs(metrics.accuracy - metrics.weightedRecall) < delta)
assert(math.abs(metrics.weightedFalsePositiveRate -
((4.0 / 9) * fpRate0 + (4.0 / 9) * fpRate1 + (1.0 / 9) * fpRate2)) < delta)
assert(math.abs(metrics.weightedPrecision -