author     coderxiang <shuoxiangpub@gmail.com>    2014-10-21 15:45:47 -0700
committer  Xiangrui Meng <meng@databricks.com>    2014-10-21 15:45:47 -0700
commit     814a9cd7fabebf2a06f7e2e5d46b6a2b28b917c2 (patch)
tree       d5c7cf40a503e2d1e6950ebd3e2889d75acf043a /mllib/src
parent     5fdaf52a9df21cac69e2a4612aeb4e760e4424e7 (diff)
SPARK-3568 [mllib] add ranking metrics
Add common metrics for ranking algorithms (http://www-nlp.stanford.edu/IR-book/), including:

- Mean Average Precision
- Precision@n: top-n precision
- Discounted cumulative gain (DCG) and NDCG

The following methods and the corresponding tests are implemented:

```
class RankingMetrics[T](predictionAndLabels: RDD[(Array[T], Array[T])]) {
  /* Returns the precision@k for each query */
  lazy val precAtK: RDD[Array[Double]]

  /**
   * @param k the position to compute the truncated precision
   * @return the average precision at the first k ranking positions
   */
  def precision(k: Int): Double

  /* Returns the average precision for each query */
  lazy val avePrec: RDD[Double]

  /* Returns the mean average precision (MAP) of all the queries */
  lazy val meanAvePrec: Double

  /* Returns the normalized discounted cumulative gain for each query */
  lazy val ndcgAtK: RDD[Array[Double]]

  /**
   * @param k the position to compute the truncated ndcg
   * @return the average ndcg at the first k ranking positions
   */
  def ndcg(k: Int): Double
}
```

Author: coderxiang <shuoxiangpub@gmail.com>

Closes #2667 from coderxiang/rankingmetrics and squashes the following commits:

d881097 [coderxiang] update doc
14d9cd9 [coderxiang] remove unexpected files
d7fb93f [coderxiang] style change and remove ignored files
f113ee1 [coderxiang] modify doc for displaying superscript and subscript
f626896 [coderxiang] improve doc and remove unnecessary computation while labSet is empty
be6645e [coderxiang] set the precision of empty labset to 0.0
d64c120 [coderxiang] add logWarning for empty ground truth set
dfae292 [coderxiang] handle empty labSet for map. add test
62047c4 [coderxiang] style change and add documentation
f66612d [coderxiang] add additional test of precisionAt
b794cb2 [coderxiang] move private members precAtK, ndcgAtK into public methods. style change
77c9e5d [coderxiang] set precAtK and ndcgAtK as private member. Improve documentation
5f87bce [coderxiang] add API to calculate precision and ndcg at each ranking position
b7851cc [coderxiang] Use generic type to represent IDs
e443fee [coderxiang] change style and use alternative builtin methods
3a5a6ff [coderxiang] add ranking metrics
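For reference, a minimal usage sketch of the evaluator as it lands in this patch (method names follow the final code below, not the earlier proposal in the message above; an existing SparkContext `sc` is assumed and the query data is illustrative only):

```
import org.apache.spark.mllib.evaluation.RankingMetrics

// Each element is (predicted ranking, ground truth set) for one query.
val predictionAndLabels = sc.parallelize(Seq(
  (Array(1, 6, 2, 7, 8), Array(1, 2, 3)),
  (Array(4, 1, 5, 6, 2), Array(1, 2))
))

val metrics = new RankingMetrics(predictionAndLabels)
metrics.precisionAt(3)         // precision truncated at rank 3, averaged over queries
metrics.meanAveragePrecision   // MAP over all queries
metrics.ndcgAt(3)              // NDCG truncated at rank 3, averaged over queries
```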
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala       152
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala   54
2 files changed, 206 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
new file mode 100644
index 0000000000..93a7353e2c
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -0,0 +1,152 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import scala.reflect.ClassTag
+
+import org.apache.spark.Logging
+import org.apache.spark.SparkContext._
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.rdd.RDD
+
+/**
+ * ::Experimental::
+ * Evaluator for ranking algorithms.
+ *
+ * @param predictionAndLabels an RDD of (predicted ranking, ground truth set) pairs.
+ */
+@Experimental
+class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])])
+ extends Logging with Serializable {
+
+ /**
+ * Compute the average precision of all the queries, truncated at ranking position k.
+ *
+ * If, for a query, the ranking algorithm returns n (n < k) results, the precision value will be
+ * computed as #(relevant items retrieved) / k. This formula also applies when the size of the
+ * ground truth set is less than k.
+ *
+ * If a query has an empty ground truth set, zero will be used as precision together with
+ * a log warning.
+ *
+ * See the following paper for detail:
+ *
+ * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+ *
+ * @param k the position to compute the truncated precision, must be positive
+ * @return the average precision at the first k ranking positions
+ */
+ def precisionAt(k: Int): Double = {
+ require(k > 0, "ranking position k should be positive")
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ val n = math.min(pred.length, k)
+ var i = 0
+ var cnt = 0
+ while (i < n) {
+ if (labSet.contains(pred(i))) {
+ cnt += 1
+ }
+ i += 1
+ }
+ cnt.toDouble / k
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean()
+ }
+
+ /**
+ * Returns the mean average precision (MAP) of all the queries.
+ * If a query has an empty ground truth set, the average precision will be zero and a log
+ * warning is generated.
+ */
+ lazy val meanAveragePrecision: Double = {
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ var i = 0
+ var cnt = 0
+ var precSum = 0.0
+ val n = pred.length
+ while (i < n) {
+ if (labSet.contains(pred(i))) {
+ cnt += 1
+ precSum += cnt.toDouble / (i + 1)
+ }
+ i += 1
+ }
+ precSum / labSet.size
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean()
+ }
+
+ /**
+ * Compute the average NDCG value of all the queries, truncated at ranking position k.
+ * The discounted cumulative gain at position k is computed as:
+ * sum,,i=1,,^k^ (2^{relevance of ''i''th item}^ - 1) / log(i + 1),
+ * and the NDCG is obtained by dividing the DCG value by the maximum DCG computed on the
+ * ground truth set. In the current implementation, the relevance value is binary.
+ *
+ * If a query has an empty ground truth set, zero will be used as ndcg together with
+ * a log warning.
+ *
+ * See the following paper for detail:
+ *
+ * IR evaluation methods for retrieving highly relevant documents. K. Jarvelin and J. Kekalainen
+ *
+ * @param k the position to compute the truncated ndcg, must be positive
+ * @return the average ndcg at the first k ranking positions
+ */
+ def ndcgAt(k: Int): Double = {
+ require(k > 0, "ranking position k should be positive")
+ predictionAndLabels.map { case (pred, lab) =>
+ val labSet = lab.toSet
+
+ if (labSet.nonEmpty) {
+ val labSetSize = labSet.size
+ val n = math.min(math.max(pred.length, labSetSize), k)
+ var maxDcg = 0.0
+ var dcg = 0.0
+ var i = 0
+ while (i < n) {
+ val gain = 1.0 / math.log(i + 2)
+ if (labSet.contains(pred(i))) {
+ dcg += gain
+ }
+ if (i < labSetSize) {
+ maxDcg += gain
+ }
+ i += 1
+ }
+ dcg / maxDcg
+ } else {
+ logWarning("Empty ground truth set, check input data")
+ 0.0
+ }
+ }.mean()
+ }
+
+}
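For readability, the per-query NDCG loop in ndcgAt above can be read as the following stand-alone sketch. It is an editorial illustration only, not part of the patch; `pred`, `labSet`, and `k` are made-up example inputs.

```
// Per-query NDCG with binary relevance, mirroring the loop in ndcgAt.
val pred = Array(4, 1, 5, 6, 2)    // predicted ranking for one query
val labSet = Set(1, 2, 3)          // ground truth set for that query
val k = 5

val labSetSize = labSet.size
val n = math.min(math.max(pred.length, labSetSize), k)
var maxDcg = 0.0                   // DCG of an ideal ranking (all relevant items first)
var dcg = 0.0
var i = 0
while (i < n) {
  val gain = 1.0 / math.log(i + 2) // binary relevance: (2^1 - 1) / log(i + 2)
  if (labSet.contains(pred(i))) dcg += gain
  if (i < labSetSize) maxDcg += gain // ideal ranking is relevant at every position < |labSet|
  i += 1
}
val ndcg = dcg / maxDcg            // ≈ 0.478 for these inputs
```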
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
new file mode 100644
index 0000000000..a2d4bb4148
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -0,0 +1,54 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation
+
+import org.scalatest.FunSuite
+
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.mllib.util.LocalSparkContext
+
+class RankingMetricsSuite extends FunSuite with LocalSparkContext {
+ test("Ranking metrics: map, ndcg") {
+ val predictionAndLabels = sc.parallelize(
+ Seq(
+ (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
+ (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
+ (Array[Int](1, 2, 3, 4, 5), Array[Int]())
+ ), 2)
+ val eps: Double = 1E-5
+
+ val metrics = new RankingMetrics(predictionAndLabels)
+ val map = metrics.meanAveragePrecision
+
+ assert(metrics.precisionAt(1) ~== 1.0/3 absTol eps)
+ assert(metrics.precisionAt(2) ~== 1.0/3 absTol eps)
+ assert(metrics.precisionAt(3) ~== 1.0/3 absTol eps)
+ assert(metrics.precisionAt(4) ~== 0.75/3 absTol eps)
+ assert(metrics.precisionAt(5) ~== 0.8/3 absTol eps)
+ assert(metrics.precisionAt(10) ~== 0.8/3 absTol eps)
+ assert(metrics.precisionAt(15) ~== 8.0/45 absTol eps)
+
+ assert(map ~== 0.355026 absTol eps)
+
+ assert(metrics.ndcgAt(3) ~== 1.0/3 absTol eps)
+ assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
+ assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
+ assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
+
+ }
+}
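As a sanity check on the expected values in the test above, two of the assertions can be reproduced by hand. This is an editorial aside, not part of the patch; the arithmetic simply follows the definitions in RankingMetrics.scala applied to the three test queries.

```
// precisionAt(5): count relevant items in each query's top 5, divide by k = 5, then average.
//   Query 1: top 5 of (1, 6, 2, 7, 8) contains {1, 2} -> 2.0 / 5
//   Query 2: top 5 of (4, 1, 5, 6, 2) contains {1, 2} -> 2.0 / 5
//   Query 3: empty ground truth set                   -> 0.0
val expectedPrecAt5 = (2.0 / 5 + 2.0 / 5 + 0.0) / 3    // = 0.8 / 3

// meanAveragePrecision: sum precision at each relevant position, divide by |labSet|, then average.
//   Query 1: (1/1 + 2/3 + 3/6 + 4/9 + 5/10) / 5 = 0.622222
//   Query 2: (1/2 + 2/5 + 3/7) / 3              = 0.442857
//   Query 3: 0.0
val expectedMap = (0.622222 + 0.442857 + 0.0) / 3      // ≈ 0.355026
```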