author    Yanbo Liang <ybliang8@gmail.com>    2015-05-11 09:14:20 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-05-11 09:14:20 -0700
commit    042dda3c5c25b5ecb6ae4fd37c85b211b01c187b (patch)
tree      ab4fee18073ed758b01dd62d1ce48434fc1c947c /python
parent    d70a076892e0677acceccaba665908cdf664f1b4 (diff)
[SPARK-6092] [MLLIB] Add RankingMetrics in PySpark/MLlib
Author: Yanbo Liang <ybliang8@gmail.com>

Closes #6044 from yanboliang/spark-6092 and squashes the following commits:

726a9b1 [Yanbo Liang] add newRankingMetrics
33f649c [Yanbo Liang] Add RankingMetrics in PySpark/MLlib
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/mllib/evaluation.py  78
1 file changed, 76 insertions(+), 2 deletions(-)
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index 36914597de..4c777f2180 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -15,9 +15,12 @@
# limitations under the License.
#
-from pyspark.mllib.common import JavaModelWrapper
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
from pyspark.sql import SQLContext
-from pyspark.sql.types import StructField, StructType, DoubleType
+from pyspark.sql.types import StructField, StructType, DoubleType, IntegerType, ArrayType
+
+__all__ = ['BinaryClassificationMetrics', 'RegressionMetrics',
+ 'MulticlassMetrics', 'RankingMetrics']
class BinaryClassificationMetrics(JavaModelWrapper):
@@ -270,6 +273,77 @@ class MulticlassMetrics(JavaModelWrapper):
return self.call("weightedFMeasure", beta)
+class RankingMetrics(JavaModelWrapper):
+ """
+ Evaluator for ranking algorithms.
+
+ >>> predictionAndLabels = sc.parallelize([
+ ... ([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
+ ... ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
+ ... ([1, 2, 3, 4, 5], [])])
+ >>> metrics = RankingMetrics(predictionAndLabels)
+ >>> metrics.precisionAt(1)
+ 0.33...
+ >>> metrics.precisionAt(5)
+ 0.26...
+ >>> metrics.precisionAt(15)
+ 0.17...
+ >>> metrics.meanAveragePrecision
+ 0.35...
+ >>> metrics.ndcgAt(3)
+ 0.33...
+ >>> metrics.ndcgAt(10)
+ 0.48...
+
+ """
+
+ def __init__(self, predictionAndLabels):
+ """
+ :param predictionAndLabels: an RDD of (predicted ranking, ground truth set) pairs.
+ """
+ sc = predictionAndLabels.ctx
+ sql_ctx = SQLContext(sc)
+ df = sql_ctx.createDataFrame(predictionAndLabels,
+ schema=sql_ctx._inferSchema(predictionAndLabels))
+ java_model = callMLlibFunc("newRankingMetrics", df._jdf)
+ super(RankingMetrics, self).__init__(java_model)
+
+ def precisionAt(self, k):
+ """
+ Compute the average precision of all the queries, truncated at ranking position k.
+
+        If, for a query, the ranking algorithm returns n (n < k) results, the precision value
+ will be computed as #(relevant items retrieved) / k. This formula also applies when
+ the size of the ground truth set is less than k.
+
+ If a query has an empty ground truth set, zero will be used as precision together
+ with a log warning.
+ """
+ return self.call("precisionAt", int(k))
+
+ @property
+ def meanAveragePrecision(self):
+ """
+ Returns the mean average precision (MAP) of all the queries.
+ If a query has an empty ground truth set, the average precision will be zero and
+        a log warning is generated.
+ """
+ return self.call("meanAveragePrecision")
+
+ def ndcgAt(self, k):
+ """
+ Compute the average NDCG value of all the queries, truncated at ranking position k.
+ The discounted cumulative gain at position k is computed as:
+            sum_{i=1}^{k} (2^{relevance of i-th item} - 1) / log(i + 1),
+        and the NDCG is obtained by dividing the DCG value by the maximum attainable
+        DCG on the ground truth set.
+ In the current implementation, the relevance value is binary.
+
+ If a query has an empty ground truth set, zero will be used as ndcg together with
+ a log warning.
+ """
+ return self.call("ndcgAt", int(k))
+
+
def _test():
import doctest
from pyspark import SparkContext
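
To make the precisionAt semantics concrete, here is a small pure-Python sketch that reproduces the doctest values above without Spark, following the rule stated in the docstring: truncate the prediction list at k, divide by k even when fewer than k results are returned, and score an empty ground truth set as zero. The helper name precision_at is illustrative, not part of the patch.

def precision_at(prediction_and_labels, k):
    total = 0.0
    for pred, labels in prediction_and_labels:
        relevant = set(labels)
        if relevant:  # an empty ground truth set contributes 0.0
            hits = sum(1 for item in pred[:k] if item in relevant)
            total += hits / float(k)  # divide by k even if len(pred) < k
    return total / len(prediction_and_labels)

data = [([1, 6, 2, 7, 8, 3, 9, 10, 4, 5], [1, 2, 3, 4, 5]),
        ([4, 1, 5, 6, 2, 7, 3, 8, 9, 10], [1, 2, 3]),
        ([1, 2, 3, 4, 5], [])]
print(precision_at(data, 1))  # 0.333... = (1/1 + 0/1 + 0) / 3
print(precision_at(data, 5))  # 0.266... = (2/5 + 2/5 + 0) / 3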
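The meanAveragePrecision doctest value can be checked the same way, assuming the standard MAP definition in which each query's average precision sums the precision at every position holding a relevant item and normalizes by the ground truth set size (an assumption about the JVM implementation, not stated in the docstring). mean_average_precision is an illustrative name; data is reused from the sketch above.

def mean_average_precision(prediction_and_labels):
    total = 0.0
    for pred, labels in prediction_and_labels:
        relevant = set(labels)
        if relevant:  # empty ground truth -> average precision of 0.0
            hits, prec_sum = 0, 0.0
            for i, item in enumerate(pred):
                if item in relevant:
                    hits += 1
                    prec_sum += hits / float(i + 1)  # precision at this hit
            total += prec_sum / len(relevant)  # assumed normalization
    return total / len(prediction_and_labels)

print(mean_average_precision(data))  # 0.355...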
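Likewise for ndcgAt: DCG truncated at k sums (2^relevance - 1) / log(i + 1) over 1-based positions i (with binary relevance the numerator is 1 for a hit), and NDCG divides by the ideal DCG of the ground truth set. The natural logarithm and the truncation rule n = min(max(len(pred), len(labels)), k) are assumptions about the JVM implementation; ndcg_at is an illustrative name, and data is again reused from above.

import math

def ndcg_at(prediction_and_labels, k):
    total = 0.0
    for pred, labels in prediction_and_labels:
        relevant = set(labels)
        if not relevant:  # empty ground truth -> NDCG of 0.0
            continue
        n = min(max(len(pred), len(relevant)), k)  # assumed truncation rule
        dcg = ideal_dcg = 0.0
        for i in range(n):  # i is the 0-based rank, so position i + 1
            gain = 1.0 / math.log(i + 2)  # log(position + 1)
            if i < len(pred) and pred[i] in relevant:
                dcg += gain  # binary relevance: 2^1 - 1 = 1
            if i < len(relevant):
                ideal_dcg += gain  # ideal ranking has a hit at every position
        total += dcg / ideal_dcg
    return total / len(prediction_and_labels)

print(ndcg_at(data, 3))   # 0.333...
print(ndcg_at(data, 10))  # 0.487...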