aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-05-10 00:57:14 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-10 00:57:14 -0700
commitbf7e81a51cd81706570615cd67362c86602dec88 (patch)
treedc3d55d57d58606fe4c10f8bb2ec0be428ec24b6 /mllib
parentb13162b364aeff35e3bdeea9c9a31e5ce66f8c9a (diff)
downloadspark-bf7e81a51cd81706570615cd67362c86602dec88.tar.gz
spark-bf7e81a51cd81706570615cd67362c86602dec88.tar.bz2
spark-bf7e81a51cd81706570615cd67362c86602dec88.zip
[SPARK-6091] [MLLIB] Add MulticlassMetrics in PySpark/MLlib
https://issues.apache.org/jira/browse/SPARK-6091 Author: Yanbo Liang <ybliang8@gmail.com> Closes #6011 from yanboliang/spark-6091 and squashes the following commits: bb3e4ba [Yanbo Liang] trigger jenkins 53c045d [Yanbo Liang] keep compatibility for python 2.6 972d5ac [Yanbo Liang] Add MulticlassMetrics in PySpark/MLlib
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala8
1 files changed, 8 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
index 666362ae67..4628dc5690 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/MulticlassMetrics.scala
@@ -23,6 +23,7 @@ import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.mllib.linalg.{Matrices, Matrix}
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
/**
* ::Experimental::
@@ -33,6 +34,13 @@ import org.apache.spark.rdd.RDD
@Experimental
class MulticlassMetrics(predictionAndLabels: RDD[(Double, Double)]) {
+ /**
+ * An auxiliary constructor taking a DataFrame.
+ * @param predictionAndLabels a DataFrame with two double columns: prediction and label
+ */
+ private[mllib] def this(predictionAndLabels: DataFrame) =
+ this(predictionAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))
+
private lazy val labelCountByClass: Map[Double, Long] = predictionAndLabels.values.countByValue()
private lazy val labelCount: Long = labelCountByClass.values.sum
private lazy val tpByClass: Map[Double, Int] = predictionAndLabels