aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-05-20 07:55:51 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-05-20 07:55:51 -0700
commit 98a46f9dffec294386f6c39acafa7f11adb87a8f (patch)
tree a53b1011e52f440e97780de1bf1adc43624cc276 /mllib
parent 589b12f8e62ec5d10713ce057756ebc791e7ddc6 (diff)
download spark-98a46f9dffec294386f6c39acafa7f11adb87a8f.tar.gz
spark-98a46f9dffec294386f6c39acafa7f11adb87a8f.tar.bz2
spark-98a46f9dffec294386f6c39acafa7f11adb87a8f.zip
[SPARK-6094] [MLLIB] Add MultilabelMetrics in PySpark/MLlib
Add MultilabelMetrics in PySpark/MLlib

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #6276 from yanboliang/spark-6094 and squashes the following commits:

b8e3343 [Yanbo Liang] Add MultilabelMetrics in PySpark/MLlib
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala 8
1 files changed, 8 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
index a8378a76d2..bf6eb1d5bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MultilabelMetrics.scala
@@ -19,6 +19,7 @@ package org.apache.spark.mllib.evaluation
import org.apache.spark.rdd.RDD
import org.apache.spark.SparkContext._
+import org.apache.spark.sql.DataFrame
/**
* Evaluator for multilabel classification.
@@ -27,6 +28,13 @@ import org.apache.spark.SparkContext._
*/
class MultilabelMetrics(predictionAndLabels: RDD[(Array[Double], Array[Double])]) {
+ /**
+ * An auxiliary constructor that converts each DataFrame row into a (prediction, label) pair.
+ * @param predictionAndLabels a DataFrame with double array columns: 0 = prediction, 1 = label
+ */
+ private[mllib] def this(predictionAndLabels: DataFrame) =
+ this(predictionAndLabels.map(r => (r.getSeq[Double](0).toArray, r.getSeq[Double](1).toArray)))
+
private lazy val numDocs: Long = predictionAndLabels.count()
private lazy val numLabels: Long = predictionAndLabels.flatMap { case (_, labels) =>