path: root/mllib/src/test
author     Alexander Ulanov <nashb@yandex.ru>        2014-10-31 18:31:03 -0700
committer  Xiangrui Meng <meng@databricks.com>       2014-10-31 18:31:03 -0700
commit     62d01d255c001a6d397cc166a10aba3894f43459 (patch)
tree       3c7407e1bd3b47607dd14e85e21b36649bff2cab /mllib/src/test
parent     23f73f525ce3d2b4a614e60f4f9170c860ab93da (diff)
[MLLIB] SPARK-2329 Add multi-label evaluation metrics
Implementation of various multi-label classification measures, including: Hamming-loss, strict and default Accuracy, macro-averaged Precision, Recall and F1-measure based on documents and labels, micro-averaged measures: https://issues.apache.org/jira/browse/SPARK-2329

Multi-class measures are currently in the following pull request: https://github.com/apache/spark/pull/1155

Author: Alexander Ulanov <nashb@yandex.ru>
Author: avulanov <nashb@yandex.ru>

Closes #1270 from avulanov/multilabelmetrics and squashes the following commits:

fc8175e [Alexander Ulanov] Merge with previous updates
43a613e [Alexander Ulanov] Addressing reviewers comments: change Set to Array
517a594 [avulanov] Addressing reviewers comments: Scala style
cf4222bc [avulanov] Addressing reviewers comments: renaming. Added label method that returns the list of labels
1843f73 [Alexander Ulanov] Scala style fix
79e8476 [Alexander Ulanov] Replacing fold(_ + _) with sum as suggested by srowen
ca46765 [Alexander Ulanov] Cosmetic changes: Apache header and parameter explanation
40593f5 [Alexander Ulanov] Multi-label metrics: Hamming-loss, strict and normal accuracy, fix to macro measures, bunch of tests
ad62df0 [Alexander Ulanov] Comments and scala style check
154164b [Alexander Ulanov] Multilabel evaluation metrics and tests: macro precision and recall averaged by docs, micro and per-class precision and recall averaged by class
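For orientation, a minimal usage sketch of the new MultilabelMetrics API, using only the methods exercised by the test added in this commit and assuming a running SparkContext named sc:

    import org.apache.spark.mllib.evaluation.MultilabelMetrics
    import org.apache.spark.rdd.RDD

    // Each record pairs the predicted label set with the true label set for one document.
    val predictionAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(Seq(
      (Array(0.0, 1.0), Array(0.0, 2.0)),
      (Array(2.0), Array(2.0))), 2)

    val metrics = new MultilabelMetrics(predictionAndLabels)
    metrics.hammingLoss      // fraction of (document, label) decisions that are wrong
    metrics.subsetAccuracy   // fraction of documents whose label set matches exactly
    metrics.microF1Measure   // micro-averaged F1 over all label decisions
    metrics.precision(0.0)   // per-label precision for label 0.0
    metrics.labels           // distinct labels seen in predictions or ground truth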
Diffstat (limited to 'mllib/src/test')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala  103
1 file changed, 103 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
new file mode 100644
index 0000000000..342baa0274
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/MultilabelMetricsSuite.scala
@@ -0,0 +1,103 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.LocalSparkContext
+import org.apache.spark.rdd.RDD
+
+class MultilabelMetricsSuite extends FunSuite with LocalSparkContext {
+ test("Multilabel evaluation metrics") {
+ /*
+ * Documents true labels (5x class0, 3x class1, 4x class2):
+ * doc 0 - predict 0, 1 - class 0, 2
+ * doc 1 - predict 0, 2 - class 0, 1
+ * doc 2 - predict none - class 0
+ * doc 3 - predict 2 - class 2
+ * doc 4 - predict 2, 0 - class 2, 0
+ * doc 5 - predict 0, 1, 2 - class 0, 1
+ * doc 6 - predict 1 - class 1, 2
+ *
+ * predicted classes
+ * class 0 - doc 0, 1, 4, 5 (total 4)
+ * class 1 - doc 0, 5, 6 (total 3)
+ * class 2 - doc 1, 3, 4, 5 (total 4)
+ *
+ * true classes
+ * class 0 - doc 0, 1, 2, 4, 5 (total 5)
+ * class 1 - doc 1, 5, 6 (total 3)
+ * class 2 - doc 0, 3, 4, 6 (total 4)
+ *
+ */
+ val scoreAndLabels: RDD[(Array[Double], Array[Double])] = sc.parallelize(
+ Seq((Array(0.0, 1.0), Array(0.0, 2.0)),
+ (Array(0.0, 2.0), Array(0.0, 1.0)),
+ (Array.empty[Double], Array(0.0)),
+ (Array(2.0), Array(2.0)),
+ (Array(2.0, 0.0), Array(2.0, 0.0)),
+ (Array(0.0, 1.0, 2.0), Array(0.0, 1.0)),
+ (Array(1.0), Array(1.0, 2.0))), 2)
+ val metrics = new MultilabelMetrics(scoreAndLabels)
+ val delta = 0.00001
+ val precision0 = 4.0 / (4 + 0)
+ val precision1 = 2.0 / (2 + 1)
+ val precision2 = 2.0 / (2 + 2)
+ val recall0 = 4.0 / (4 + 1)
+ val recall1 = 2.0 / (2 + 1)
+ val recall2 = 2.0 / (2 + 2)
+ val f1measure0 = 2 * precision0 * recall0 / (precision0 + recall0)
+ val f1measure1 = 2 * precision1 * recall1 / (precision1 + recall1)
+ val f1measure2 = 2 * precision2 * recall2 / (precision2 + recall2)
+ val sumTp = 4 + 2 + 2
+ assert(sumTp == (1 + 1 + 0 + 1 + 2 + 2 + 1))
+ val microPrecisionClass = sumTp.toDouble / (4 + 0 + 2 + 1 + 2 + 2)
+ val microRecallClass = sumTp.toDouble / (4 + 1 + 2 + 1 + 2 + 2)
+ val microF1MeasureClass = 2.0 * sumTp.toDouble /
+ (2 * sumTp.toDouble + (1 + 1 + 2) + (0 + 1 + 2))
+ val macroPrecisionDoc = 1.0 / 7 *
+ (1.0 / 2 + 1.0 / 2 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 1.0)
+ val macroRecallDoc = 1.0 / 7 *
+ (1.0 / 2 + 1.0 / 2 + 0 / 1 + 1.0 / 1 + 2.0 / 2 + 2.0 / 2 + 1.0 / 2)
+ val macroF1MeasureDoc = (1.0 / 7) *
+ 2 * (1.0 / (2 + 2) + 1.0 / (2 + 2) + 0 + 1.0 / (1 + 1) +
+ 2.0 / (2 + 2) + 2.0 / (3 + 2) + 1.0 / (1 + 2))
+ val hammingLoss = (1.0 / (7 * 3)) * (2 + 2 + 1 + 0 + 0 + 1 + 1)
+ val strictAccuracy = 2.0 / 7
+ val accuracy = 1.0 / 7 * (1.0 / 3 + 1.0 / 3 + 0 + 1.0 / 1 + 2.0 / 2 + 2.0 / 3 + 1.0 / 2)
+ assert(math.abs(metrics.precision(0.0) - precision0) < delta)
+ assert(math.abs(metrics.precision(1.0) - precision1) < delta)
+ assert(math.abs(metrics.precision(2.0) - precision2) < delta)
+ assert(math.abs(metrics.recall(0.0) - recall0) < delta)
+ assert(math.abs(metrics.recall(1.0) - recall1) < delta)
+ assert(math.abs(metrics.recall(2.0) - recall2) < delta)
+ assert(math.abs(metrics.f1Measure(0.0) - f1measure0) < delta)
+ assert(math.abs(metrics.f1Measure(1.0) - f1measure1) < delta)
+ assert(math.abs(metrics.f1Measure(2.0) - f1measure2) < delta)
+ assert(math.abs(metrics.microPrecision - microPrecisionClass) < delta)
+ assert(math.abs(metrics.microRecall - microRecallClass) < delta)
+ assert(math.abs(metrics.microF1Measure - microF1MeasureClass) < delta)
+ assert(math.abs(metrics.precision - macroPrecisionDoc) < delta)
+ assert(math.abs(metrics.recall - macroRecallDoc) < delta)
+ assert(math.abs(metrics.f1Measure - macroF1MeasureDoc) < delta)
+ assert(math.abs(metrics.hammingLoss - hammingLoss) < delta)
+ assert(math.abs(metrics.subsetAccuracy - strictAccuracy) < delta)
+ assert(math.abs(metrics.accuracy - accuracy) < delta)
+ assert(metrics.labels.sameElements(Array(0.0, 1.0, 2.0)))
+ }
+}
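For reference, the aggregate expectations asserted above reduce to the following fractions, derived directly from the counts worked out in the test's comment block and expected-value computations:

    microPrecision  = 8 / 11  ≈ 0.7273
    microRecall     = 8 / 12  ≈ 0.6667
    microF1Measure  = 16 / 23 ≈ 0.6957
    hammingLoss     = 7 / 21  ≈ 0.3333
    subsetAccuracy  = 2 / 7   ≈ 0.2857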