author     Yanbo Liang <ybliang8@gmail.com>      2015-05-20 07:55:51 -0700
committer  Xiangrui Meng <meng@databricks.com>   2015-05-20 07:56:00 -0700
commit     606ae3e10e76325c032860ad7be1da94921af44a
tree       5a510b00b6cf90d941fd0ef88120c19026211086
parent     996e2d4b38c869c99e1a094ecd30da886064b4d3
[SPARK-6094] [MLLIB] Add MultilabelMetrics in PySpark/MLlib
Add MultilabelMetrics in PySpark/MLlib
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #6276 from yanboliang/spark-6094 and squashes the following commits:
b8e3343 [Yanbo Liang] Add MultilabelMetrics in PySpark/MLlib
(cherry picked from commit 98a46f9dffec294386f6c39acafa7f11adb87a8f)
Signed-off-by: Xiangrui Meng <meng@databricks.com>
 mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala |   8 +
 python/pyspark/mllib/evaluation.py                                             | 117 +
 2 files changed, 125 insertions(+), 0 deletions(-)
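For reference, a minimal usage sketch of the new Python API (not part of the patch): it exercises the same data as the doctest in the diff below and assumes a local SparkContext.

    from pyspark import SparkContext
    from pyspark.mllib.evaluation import MultilabelMetrics

    sc = SparkContext("local", "multilabel-example")  # assumed local setup, not in the patch

    # Each pair is (predicted labels, true labels) for one document.
    predictionAndLabels = sc.parallelize([
        ([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
        ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
        ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])

    metrics = MultilabelMetrics(predictionAndLabels)
    print(metrics.precision(0.0))  # per-label precision for label 0.0: 1.0
    print(metrics.recall())        # label-averaged recall: ~0.64
    print(metrics.hammingLoss)     # ~0.33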
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
index a8378a76d2..bf6eb1d5bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation
 
 import org.apache.spark.rdd.RDD
 import org.apache.spark.SparkContext._
+import org.apache.spark.sql.DataFrame
 
 /**
  * Evaluator for multilabel classification.
@@ -27,6 +28,13 @@ import org.apache.spark.SparkContext._
  */
 class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
 
+  /**
+   * An auxiliary constructor taking a DataFrame.
+   * @param predictionAndLabels a DataFrame with two double array columns: prediction and label
+   */
+  private[mllib] def this(predictionAndLabels: DataFrame) =
+    this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray)))
+
   private lazy val numDocs: Long = predictionAndLabels.count()
 
   private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>
diff --git a/python/pyspark/mllib/evaluation.py b/python/pyspark/mllib/evaluation.py
index a5e5ddc8fe..aab5e5f4b7 100644
--- a/python/pyspark/mllib/evaluation.py
+++ b/python/pyspark/mllib/evaluation.py
@@ -343,6 +343,123 @@ class RankingMetrics(JavaModelWrapper):
         return self.call("ndcgAt", int(k))
 
 
+class MultilabelMetrics(JavaModelWrapper):
+    """
+    Evaluator for multilabel classification.
+
+    >>> predictionAndLabels = sc.parallelize([([0.0, 1.0], [0.0, 2.0]), ([0.0, 2.0], [0.0, 1.0]),
+    ...     ([], [0.0]), ([2.0], [2.0]), ([2.0, 0.0], [2.0, 0.0]),
+    ...     ([0.0, 1.0, 2.0], [0.0, 1.0]), ([1.0], [1.0, 2.0])])
+    >>> metrics = MultilabelMetrics(predictionAndLabels)
+    >>> metrics.precision(0.0)
+    1.0
+    >>> metrics.recall(1.0)
+    0.66...
+    >>> metrics.f1Measure(2.0)
+    0.5
+    >>> metrics.precision()
+    0.66...
+    >>> metrics.recall()
+    0.64...
+    >>> metrics.f1Measure()
+    0.63...
+    >>> metrics.microPrecision
+    0.72...
+    >>> metrics.microRecall
+    0.66...
+    >>> metrics.microF1Measure
+    0.69...
+    >>> metrics.hammingLoss
+    0.33...
+    >>> metrics.subsetAccuracy
+    0.28...
+    >>> metrics.accuracy
+    0.54...
+    """
+
+    def __init__(self, predictionAndLabels):
+        sc = predictionAndLabels.ctx
+        sql_ctx = SQLContext(sc)
+        df = sql_ctx.createDataFrame(predictionAndLabels,
+                                     schema=sql_ctx._inferSchema(predictionAndLabels))
+        java_class = sc._jvm.org.apache.spark.mllib.evaluation.MultilabelMetrics
+        java_model = java_class(df._jdf)
+        super(MultilabelMetrics, self).__init__(java_model)
+
+    def precision(self, label=None):
+        """
+        Returns precision or precision for a given label (category) if specified.
+        """
+        if label is None:
+            return self.call("precision")
+        else:
+            return self.call("precision", float(label))
+
+    def recall(self, label=None):
+        """
+        Returns recall or recall for a given label (category) if specified.
+        """
+        if label is None:
+            return self.call("recall")
+        else:
+            return self.call("recall", float(label))
+
+    def f1Measure(self, label=None):
+        """
+        Returns f1Measure or f1Measure for a given label (category) if specified.
+        """
+        if label is None:
+            return self.call("f1Measure")
+        else:
+            return self.call("f1Measure", float(label))
+
+    @property
+    def microPrecision(self):
+        """
+        Returns micro-averaged label-based precision.
+        (equals to micro-averaged document-based precision)
+        """
+        return self.call("microPrecision")
+
+    @property
+    def microRecall(self):
+        """
+        Returns micro-averaged label-based recall.
+        (equals to micro-averaged document-based recall)
+        """
+        return self.call("microRecall")
+
+    @property
+    def microF1Measure(self):
+        """
+        Returns micro-averaged label-based f1-measure.
+        (equals to micro-averaged document-based f1-measure)
+        """
+        return self.call("microF1Measure")
+
+    @property
+    def hammingLoss(self):
+        """
+        Returns Hamming-loss.
+        """
+        return self.call("hammingLoss")
+
+    @property
+    def subsetAccuracy(self):
+        """
+        Returns subset accuracy.
+        (for equal sets of labels)
+        """
+        return self.call("subsetAccuracy")
+
+    @property
+    def accuracy(self):
+        """
+        Returns accuracy.
+        """
+        return self.call("accuracy")
+
+
 def _test():
     import doctest
     from pyspark import SparkContext
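A note on the design: the Python wrapper converts the input RDD to a DataFrame before crossing into the JVM, since a DataFrame already has a well-defined Py4J handle (df._jdf); the new private[mllib] auxiliary constructor on the Scala side is what lets the wrapper avoid a custom serializer for RDD[(Array[Double], Array[Double])]. To make the metric definitions concrete, the sketch below re-derives three of the doctest values in plain Python from the standard set-based definitions; it mirrors the Scala implementation's documented behavior but is not part of the patch.

    # (Prediction set, label set) per document, same data as the doctest.
    pairs = [({0.0, 1.0}, {0.0, 2.0}), ({0.0, 2.0}, {0.0, 1.0}),
             (set(), {0.0}), ({2.0}, {2.0}), ({2.0, 0.0}, {2.0, 0.0}),
             ({0.0, 1.0, 2.0}, {0.0, 1.0}), ({1.0}, {1.0, 2.0})]
    n_docs = len(pairs)
    # Distinct labels come from the true label sets, as in the Scala numLabels.
    n_labels = len({label for _, labels in pairs for label in labels})

    # Hamming loss: size of the symmetric difference, averaged over docs and labels.
    hamming_loss = sum(len(p ^ l) for p, l in pairs) / float(n_docs * n_labels)

    # Subset accuracy: fraction of documents whose label set is predicted exactly.
    subset_accuracy = sum(p == l for p, l in pairs) / float(n_docs)

    # Accuracy: mean Jaccard overlap between predicted and true label sets.
    accuracy = sum(len(p & l) / float(len(p | l)) for p, l in pairs) / n_docs

    print(hamming_loss, subset_accuracy, accuracy)  # 0.333..., 0.285..., 0.547...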