aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorAlexander Ulanov <nashb@yandex.ru>2014-10-31 18:31:03 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-31 18:31:03 -0700
commit62d01d255c001a6d397cc166a10aba3894f43459 (patch)
tree3c7407e1bd3b47607dd14e85e21b36649bff2cab /mllib
parent23f73f525ce3d2b4a614e60f4f9170c860ab93da (diff)
downloadspark-62d01d255c001a6d397cc166a10aba3894f43459.tar.gz
spark-62d01d255c001a6d397cc166a10aba3894f43459.tar.bz2
spark-62d01d255c001a6d397cc166a10aba3894f43459.zip
[MLLIB] SPARK-2329 Add multi-label evaluation metrics
Implementation of various multi-label classification measures, including: Hamming-loss, strict and default Accuracy, macro-averaged Precision, Recall and F1-measure based on documents and labels, micro-averaged measures: https://issues.apache.org/jira/browse/SPARK-2329 Multi-class measures are currently in the following pull request: https://github.com/apache/spark/pull/1155 Author: Alexander Ulanov <nashb@yandex.ru> Author: avulanov <nashb@yandex.ru> Closes #1270 from avulanov/multilabelmetrics and squashes the following commits: fc8175e [Alexander Ulanov] Merge with previous updates 43a613e [Alexander Ulanov] Addressing reviewers comments: change Set to Array 517a594 [avulanov] Addressing reviewers comments: Scala style cf4222bc [avulanov] Addressing reviewers comments: renaming. Added label method that returns the list of labels 1843f73 [Alexander Ulanov] Scala style fix 79e8476 [Alexander Ulanov] Replacing fold(_ + _) with sum as suggested by srowen ca46765 [Alexander Ulanov] Cosmetic changes: Apache header and parameter explanation 40593f5 [Alexander Ulanov] Multi-label metrics: Hamming-loss, strict and normal accuracy, fix to macro measures, bunch of tests ad62df0 [Alexander Ulanov] Comments and scala style check 154164b [Alexander Ulanov] Multilabel evaluation metics and tests: macro precision and recall averaged by docs, micro and per-class precision and recall averaged by class
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala157
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala103
2 files changed, 260 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
new file mode 100644
index 0000000000..ea10bde5fa
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -0,0 +1,157 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.apache.spark.rdd.RDD
+import org.apache.spark.SparkContext._
+
+/**
+ * Evaluator for multilabel classification.
+ * @param predictionAndLabels an RDD of (predictions, labels) pairs,
+ * both are non-null (possibly empty) Arrays, each with unique elements.
+ */
+class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
+
+  private lazy val numDocs: Long = predictionAndLabels.count()
+
+  // Number of distinct labels occurring in the true label sets.
+  private lazy val numLabels: Long =
+    predictionAndLabels.flatMap { case (_, labels) => labels }.distinct().count()
+
+  /**
+   * Returns subset accuracy: the fraction of documents whose predicted label set equals
+   * the true label set (compared as sets, so the order of elements does not matter)
+   */
+  lazy val subsetAccuracy: Double = predictionAndLabels.filter { case (predictions, labels) =>
+    predictions.toSet == labels.toSet
+  }.count().toDouble / numDocs
+
+  /**
+   * Returns accuracy: per-document |intersection| / |union|, averaged over the documents
+   * (a document whose predictions and labels are both empty contributes 0)
+   */
+  lazy val accuracy: Double = predictionAndLabels.map { case (predictions, labels) =>
+    val common = labels.intersect(predictions).size
+    val union = labels.size + predictions.size - common
+    if (union > 0) common.toDouble / union else 0.0
+  }.sum / numDocs
+
+  /**
+   * Returns Hamming-loss: total size of the symmetric differences between predicted and
+   * true label sets, divided by (number of documents * number of distinct true labels)
+   */
+  lazy val hammingLoss: Double = predictionAndLabels.map { case (predictions, labels) =>
+    labels.size + predictions.size - 2 * labels.intersect(predictions).size
+  }.sum / (numDocs * numLabels)
+
+  /**
+   * Returns document-based precision averaged by the number of documents
+   * (a document without predictions contributes 0)
+   */
+  lazy val precision: Double = predictionAndLabels.map { case (predictions, labels) =>
+    val numPredicted = predictions.size
+    if (numPredicted > 0) predictions.intersect(labels).size.toDouble / numPredicted else 0.0
+  }.sum / numDocs
+
+  /**
+   * Returns document-based recall averaged by the number of documents
+   * (a document without true labels contributes 0 instead of NaN)
+   */
+  lazy val recall: Double = predictionAndLabels.map { case (predictions, labels) =>
+    val numTrue = labels.size
+    if (numTrue > 0) labels.intersect(predictions).size.toDouble / numTrue else 0.0
+  }.sum / numDocs
+
+  /**
+   * Returns document-based f1-measure averaged by the number of documents
+   * (a document whose predictions and labels are both empty contributes 0 instead of NaN)
+   */
+  lazy val f1Measure: Double = predictionAndLabels.map { case (predictions, labels) =>
+    val total = predictions.size + labels.size
+    if (total > 0) 2.0 * predictions.intersect(labels).size / total else 0.0
+  }.sum / numDocs
+
+  // Per-label counts of true positives, false positives and false negatives.
+  private lazy val tpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
+    predictions.intersect(labels)
+  }.countByValue()
+  private lazy val fpPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
+    predictions.diff(labels)
+  }.countByValue()
+  private lazy val fnPerClass = predictionAndLabels.flatMap { case (predictions, labels) =>
+    labels.diff(predictions)
+  }.countByValue()
+
+  /**
+   * Returns precision for a given label (category), 0 if the label was never predicted
+   * @param label the label.
+   */
+  def precision(label: Double): Double = {
+    val tp = tpPerClass.getOrElse(label, 0L)
+    val fp = fpPerClass.getOrElse(label, 0L)
+    if (tp + fp == 0) 0.0 else tp.toDouble / (tp + fp)
+  }
+
+  /**
+   * Returns recall for a given label (category), 0 if the label never occurs as a true label
+   * @param label the label.
+   */
+  def recall(label: Double): Double = {
+    val tp = tpPerClass.getOrElse(label, 0L)
+    val fn = fnPerClass.getOrElse(label, 0L)
+    if (tp + fn == 0) 0.0 else tp.toDouble / (tp + fn)
+  }
+
+  /**
+   * Returns f1-measure for a given label (category), 0 if precision and recall are both 0
+   * @param label the label.
+   */
+  def f1Measure(label: Double): Double = {
+    val p = precision(label)
+    val r = recall(label)
+    if ((p + r) == 0) 0.0 else 2 * p * r / (p + r)
+  }
+
+  private lazy val sumTp = tpPerClass.foldLeft(0L) { case (sum, (_, tp)) => sum + tp }
+  private lazy val sumFpClass = fpPerClass.foldLeft(0L) { case (sum, (_, fp)) => sum + fp }
+  private lazy val sumFnClass = fnPerClass.foldLeft(0L) { case (sum, (_, fn)) => sum + fn }
+
+  /**
+   * Returns micro-averaged label-based precision
+   * (equals to micro-averaged document-based precision)
+   */
+  lazy val microPrecision: Double = sumTp.toDouble / (sumTp + sumFpClass)
+
+  /**
+   * Returns micro-averaged label-based recall
+   * (equals to micro-averaged document-based recall)
+   */
+  lazy val microRecall: Double = sumTp.toDouble / (sumTp + sumFnClass)
+
+  /**
+   * Returns micro-averaged label-based f1-measure
+   * (equals to micro-averaged document-based f1-measure)
+   */
+  lazy val microF1Measure: Double = 2.0 * sumTp / (2 * sumTp + sumFnClass + sumFpClass)
+
+  /**
+   * Returns the sequence of true labels in ascending order, including labels that were
+   * never predicted correctly (every true label is either a true positive or a false negative)
+   */
+  lazy val labels: Array[Double] = (tpPerClass.keys ++ fnPerClass.keys).toArray.distinct.sorted
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
new file mode 100644
index 0000000000..342baa0274
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.rdd.RDD
+
+class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
+  test("Multilabel evaluation metrics") {
+    /*
+     * Seven documents over three classes; each row lists predicted vs. true labels:
+     *   doc 0 - predicted {0, 1}    - true {0, 2}
+     *   doc 1 - predicted {0, 2}    - true {0, 1}
+     *   doc 2 - predicted {}        - true {0}
+     *   doc 3 - predicted {2}       - true {2}
+     *   doc 4 - predicted {2, 0}    - true {2, 0}
+     *   doc 5 - predicted {0, 1, 2} - true {0, 1}
+     *   doc 6 - predicted {1}       - true {1, 2}
+     *
+     * Documents per predicted class:
+     *   class 0 - docs 0, 1, 4, 5 (4 in total)
+     *   class 1 - docs 0, 5, 6    (3 in total)
+     *   class 2 - docs 1, 3, 4, 5 (4 in total)
+     *
+     * Documents per true class:
+     *   class 0 - docs 0, 1, 2, 4, 5 (5 in total)
+     *   class 1 - docs 1, 5, 6       (3 in total)
+     *   class 2 - docs 0, 3, 4, 6    (4 in total)
+     */
+    val predictionAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(
+      Seq((Array(0.0, 1.0), Array(0.0, 2.0)),
+        (Array(0.0, 2.0), Array(0.0, 1.0)),
+        (Array(), Array(0.0)),
+        (Array(2.0), Array(2.0)),
+        (Array(2.0, 0.0), Array(2.0, 0.0)),
+        (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)),
+        (Array(1.0), Array(1.0, 2.0))), 2)
+    val metrics = new MultilabelMetrics(predictionAndLabels)
+    val tolerance = 0.00001
+    // True when the two values agree up to the absolute tolerance.
+    def near(actual: Double, expected: Double): Boolean =
+      math.abs(actual - expected) < tolerance
+
+    // Label-based expectations, indexed by class: tp / (tp + fp) and tp / (tp + fn).
+    val classPrecision = Seq(4.0 / (4 + 0), 2.0 / (2 + 1), 2.0 / (2 + 2))
+    val classRecall = Seq(4.0 / (4 + 1), 2.0 / (2 + 1), 2.0 / (2 + 2))
+    val classF1 = classPrecision.zip(classRecall).map { case (p, r) => 2 * p * r / (p + r) }
+
+    // Micro-averaged expectations from the counts pooled over all classes.
+    val totalTp = 4 + 2 + 2
+    assert(totalTp == (1 + 1 + 0 + 1 + 2 + 2 + 1))
+    val expectedMicroPrecision = totalTp.toDouble / (4 + 0 + 2 + 1 + 2 + 2)
+    val expectedMicroRecall = totalTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2)
+    val expectedMicroF1 = 2.0 * totalTp.toDouble /
+      (2 * totalTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2))
+
+    // Document-based expectations: per-document ratios averaged over the 7 documents.
+    val expectedDocPrecision = 1.0 / 7 *
+      (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
+    val expectedDocRecall = 1.0 / 7 *
+      (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)
+    val expectedDocF1 = (1.0 / 7) *
+      2 * (1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 + 1.0 / (1 + 1) +
+      2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2))
+    val expectedHammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
+    val expectedSubsetAccuracy = 2.0 / 7
+    val expectedAccuracy =
+      1.0 / 7 * (1.0 / 3 + 1.0 / 3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
+
+    for (cls <- 0 to 2) {
+      assert(near(metrics.precision(cls.toDouble), classPrecision(cls)))
+      assert(near(metrics.recall(cls.toDouble), classRecall(cls)))
+      assert(near(metrics.f1Measure(cls.toDouble), classF1(cls)))
+    }
+    assert(near(metrics.microPrecision, expectedMicroPrecision))
+    assert(near(metrics.microRecall, expectedMicroRecall))
+    assert(near(metrics.microF1Measure, expectedMicroF1))
+    assert(near(metrics.precision, expectedDocPrecision))
+    assert(near(metrics.recall, expectedDocRecall))
+    assert(near(metrics.f1Measure, expectedDocF1))
+    assert(near(metrics.hammingLoss, expectedHammingLoss))
+    assert(near(metrics.subsetAccuracy, expectedSubsetAccuracy))
+    assert(near(metrics.accuracy, expectedAccuracy))
+    assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0)))
+  }
+}