aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-05-12 16:53:47 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-12 16:53:47 -0700
commit2713bc65af1e0e81edd5fad0338e34fd127391f9 (patch)
treee2ed79db61c13a1b9a4e353372a3b6c52543285d
parent00e7b09a0bee2fcfd0ce34992bd26435758daf26 (diff)
downloadspark-2713bc65af1e0e81edd5fad0338e34fd127391f9.tar.gz
spark-2713bc65af1e0e81edd5fad0338e34fd127391f9.tar.bz2
spark-2713bc65af1e0e81edd5fad0338e34fd127391f9.zip
[SPARK-7528] [MLLIB] make RankingMetrics Java-friendly
`RankingMetrics` contains a ClassTag, which is hard to create in Java. This PR adds a factory method `of` for Java users. coderxiang Author: Xiangrui Meng <meng@databricks.com> Closes #6098 from mengxr/SPARK-7528 and squashes the following commits: e5d57ae [Xiangrui Meng] make RankingMetrics Java-friendly
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala27
-rw-r--r--mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java64
2 files changed, 87 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
index 93a7353e2c..b9b54b93c2 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/RankingMetrics.scala
@@ -17,11 +17,14 @@
package org.apache.spark.mllib.evaluation
+import java.{lang => jl}
+
+import scala.collection.JavaConverters._
import scala.reflect.ClassTag
import org.apache.spark.Logging
-import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.{JavaSparkContext, JavaRDD}
import org.apache.spark.rdd.RDD
/**
@@ -71,7 +74,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
logWarning("Empty ground truth set, check input data")
0.0
}
- }.mean
+ }.mean()
}
/**
@@ -100,7 +103,7 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
logWarning("Empty ground truth set, check input data")
0.0
}
- }.mean
+ }.mean()
}
/**
@@ -146,7 +149,23 @@ class RankingMetrics[T: ClassTag](predictionAndLabels: RDD[(Array[T], Array[T])]
logWarning("Empty ground truth set, check input data")
0.0
}
- }.mean
+ }.mean()
}
}
+
+@Experimental
+object RankingMetrics {
+
+ /**
+ * Creates a [[RankingMetrics]] instance (for Java users).
+ * @param predictionAndLabels a JavaRDD of (predicted ranking, ground truth set) pairs
+ */
+ def of[E, T <: jl.Iterable[E]](predictionAndLabels: JavaRDD[(T, T)]): RankingMetrics[E] = {
+ implicit val tag = JavaSparkContext.fakeClassTag[E]
+ val rdd = predictionAndLabels.rdd.map { case (predictions, labels) =>
+ (predictions.asScala.toArray, labels.asScala.toArray)
+ }
+ new RankingMetrics(rdd)
+ }
+}
diff --git a/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
new file mode 100644
index 0000000000..effc8a1a6d
--- /dev/null
+++ b/mllib/src/test/java/org/apache/spark/mllib/evaluation/JavaRankingMetricsSuite.java
@@ -0,0 +1,64 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.evaluation;
+
+import java.io.Serializable;
+import java.util.ArrayList;
+
+import scala.Tuple2;
+import scala.Tuple2$;
+
+import com.google.common.collect.Lists;
+import org.junit.After;
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+
+import org.apache.spark.api.java.JavaRDD;
+import org.apache.spark.api.java.JavaSparkContext;
+
+public class JavaRankingMetricsSuite implements Serializable {
+ private transient JavaSparkContext sc;
+ private transient JavaRDD<Tuple2<ArrayList<Integer>, ArrayList<Integer>>> predictionAndLabels;
+
+ @Before
+ public void setUp() {
+ sc = new JavaSparkContext("local", "JavaRankingMetricsSuite");
+ predictionAndLabels = sc.parallelize(Lists.newArrayList(
+ Tuple2$.MODULE$.apply(
+ Lists.newArrayList(1, 6, 2, 7, 8, 3, 9, 10, 4, 5), Lists.newArrayList(1, 2, 3, 4, 5)),
+ Tuple2$.MODULE$.apply(
+ Lists.newArrayList(4, 1, 5, 6, 2, 7, 3, 8, 9, 10), Lists.newArrayList(1, 2, 3)),
+ Tuple2$.MODULE$.apply(
+ Lists.newArrayList(1, 2, 3, 4, 5), Lists.<Integer>newArrayList())), 2);
+ }
+
+ @After
+ public void tearDown() {
+ sc.stop();
+ sc = null;
+ }
+
+ @Test
+ public void rankingMetrics() {
+ @SuppressWarnings("unchecked")
+ RankingMetrics<?> metrics = RankingMetrics.of(predictionAndLabels);
+ Assert.assertEquals(0.355026, metrics.meanAveragePrecision(), 1e-5);
+ Assert.assertEquals(0.75 / 3.0, metrics.precisionAt(4), 1e-5);
+ }
+}