author     Sean Owen <sowen@cloudera.com>        2016-04-29 09:21:27 +0200
committer  Nick Pentreath <nickp@za.ibm.com>     2016-04-29 09:21:27 +0200
commit     d1cf320105504f908ee01f33044d0a6b29c3c03f (patch)
tree       b2116c0a0cd1a1a8833dc2fe80b4d07e8423c2c1
parent     24d07e45d4e2e5222043b9a4447aa6c384069d4f (diff)
[SPARK-14886][MLLIB] RankingMetrics.ndcgAt throws java.lang.ArrayIndexOutOfBoundsException
## What changes were proposed in this pull request?

Handle the case where the number of predictions is smaller than the label set size, or smaller than k, in the nDCG computation.

## How was this patch tested?

New unit test; existing tests.

Author: Sean Owen <sowen@cloudera.com>

Closes #12756 from srowen/SPARK-14886.
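The failure is easy to reproduce. The following hypothetical spark-shell session (assuming an active SparkContext `sc`) triggers the bug on an unpatched build: with fewer predictions than labels, the nDCG loop indexes past the end of the prediction array.

    import org.apache.spark.mllib.evaluation.RankingMetrics

    // Three predictions but five relevant labels: before this patch,
    // ndcgAt(k) with k > 3 read past the end of the prediction array.
    val rows = sc.parallelize(Seq(
      (Array(1, 6, 2), Array(1, 2, 3, 4, 5))
    ))
    val metrics = new RankingMetrics(rows)
    metrics.ndcgAt(5)  // previously threw java.lang.ArrayIndexOutOfBoundsException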
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala       2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala  26
2 files changed, 22 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index c45742cebb..4ed4a05894 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -140,7 +140,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
         var i = 0
         while (i < n) {
           val gain = 1.0 / math.log(i + 2)
-          if (labSet.contains(pred(i))) {
+          if (i < pred.length && labSet.contains(pred(i))) {
             dcg += gain
           }
           if (i < labSetSize) {
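For readers without the full file at hand, here is a minimal self-contained sketch of the patched per-row computation. The derivation of `n` is paraphrased from the code surrounding this hunk, so treat it as an approximation of the MLlib source rather than a verbatim excerpt:

    // nDCG@k for one (predictions, labels) row, with the new bounds guard.
    def ndcgAtSketch(pred: Array[Int], lab: Array[Int], k: Int): Double = {
      val labSet = lab.toSet
      val labSetSize = labSet.size
      // n can exceed pred.length when predictions are scarce -- the buggy case.
      val n = math.min(math.max(pred.length, labSetSize), k)
      var dcg = 0.0
      var maxDcg = 0.0
      var i = 0
      while (i < n) {
        val gain = 1.0 / math.log(i + 2)
        if (i < pred.length && labSet.contains(pred(i))) {  // the patched guard
          dcg += gain
        }
        if (i < labSetSize) {
          maxDcg += gain  // ideal DCG assumes all top ranks are relevant
        }
        i += 1
      }
      if (maxDcg == 0.0) 0.0 else dcg / maxDcg
    }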
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
index 77ec49d005..8e9d910e64 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/evaluation/RankingMetricsSuite.scala
@@ -22,14 +22,15 @@ import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
 
 class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
-  test("Ranking metrics: map, ndcg") {
+
+  test("Ranking metrics: MAP, NDCG") {
     val predictionAndLabels = sc.parallelize(
       Seq(
-        (Array[Int](1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array[Int](1, 2, 3, 4, 5)),
-        (Array[Int](4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array[Int](1, 2, 3)),
-        (Array[Int](1, 2, 3, 4, 5), Array[Int]())
+        (Array(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Array(1, 2, 3, 4, 5)),
+        (Array(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Array(1, 2, 3)),
+        (Array(1, 2, 3, 4, 5), Array[Int]())
       ), 2)
-    val eps: Double = 1E-5
+    val eps = 1.0E-5
 
     val metrics = new RankingMetrics(predictionAndLabels)
     val map = metrics.meanAveragePrecision
@@ -48,6 +49,21 @@ class RankingMetricsSuite extends SparkFunSuite with MLlibTestSparkContext {
     assert(metrics.ndcgAt(5) ~== 0.328788 absTol eps)
     assert(metrics.ndcgAt(10) ~== 0.487913 absTol eps)
     assert(metrics.ndcgAt(15) ~== metrics.ndcgAt(10) absTol eps)
 
+  }
+
+  test("MAP, NDCG with few predictions (SPARK-14886)") {
+    val predictionAndLabels = sc.parallelize(
+      Seq(
+        (Array(1, 6, 2), Array(1, 2, 3, 4, 5)),
+        (Array[Int](), Array(1, 2, 3))
+      ), 2)
+    val eps = 1.0E-5
+    val metrics = new RankingMetrics(predictionAndLabels)
+    assert(metrics.precisionAt(1) ~== 0.5 absTol eps)
+    assert(metrics.precisionAt(2) ~== 0.25 absTol eps)
+    assert(metrics.ndcgAt(1) ~== 0.5 absTol eps)
+    assert(metrics.ndcgAt(2) ~== 0.30657 absTol eps)
   }
+
 }
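As a sanity check, the 0.30657 asserted in the new test can be recovered by hand from the sketch above (again assuming the hypothetical ndcgAtSketch, not the exact MLlib internals):

    // Row 1: hit at rank 1, miss at rank 2 => (1/ln 2) / (1/ln 2 + 1/ln 3) ≈ 0.61315
    val row1 = ndcgAtSketch(Array(1, 6, 2), Array(1, 2, 3, 4, 5), 2)
    // Row 2: no predictions at all => 0.0 rather than an exception
    val row2 = ndcgAtSketch(Array.empty[Int], Array(1, 2, 3), 2)
    (row1 + row2) / 2  // ≈ 0.30657, matching metrics.ndcgAt(2) in the suite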