aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2015-03-05 11:50:09 -0800
committerXiangrui Meng <meng@databricks.com>2015-03-05 11:50:09 -0800
commit0bfacd5c5dd7d10a69bcbcbda630f0843d1cf285 (patch)
tree2b13352131bb3dbd88e4214c6c7728d26898d25e /mllib
parentc9cfba0cebe3eb546e3e96f3e5b9b89a74c5b7de (diff)
downloadspark-0bfacd5c5dd7d10a69bcbcbda630f0843d1cf285.tar.gz
spark-0bfacd5c5dd7d10a69bcbcbda630f0843d1cf285.tar.bz2
spark-0bfacd5c5dd7d10a69bcbcbda630f0843d1cf285.zip
[SPARK-6090][MLLIB] add a basic BinaryClassificationMetrics to PySpark/MLlib
A simple wrapper around the Scala implementation. `DataFrame` is used for serialization/deserialization. Methods that return `RDD`s are not supported in this PR. davies If we recognize Scala's `Product`s in Py4J, we can easily add wrappers for Scala methods that returns `RDD[(Double, Double)]`. Is it easy to register serializer for `Product` in PySpark? Author: Xiangrui Meng <meng@databricks.com> Closes #4863 from mengxr/SPARK-6090 and squashes the following commits: 009a3a3 [Xiangrui Meng] provide schema dcddab5 [Xiangrui Meng] add a basic BinaryClassificationMetrics to PySpark/MLlib
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala8
1 files changed, 8 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
index ced042e2f9..c1d1a22481 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/evaluation/BinaryClassificationMetrics.scala
@@ -22,6 +22,7 @@ import org.apache.spark.Logging
import org.apache.spark.SparkContext._
import org.apache.spark.mllib.evaluation.binary._
import org.apache.spark.rdd.{RDD, UnionRDD}
+import org.apache.spark.sql.DataFrame
/**
* :: Experimental ::
@@ -53,6 +54,13 @@ class BinaryClassificationMetrics(
*/
def this(scoreAndLabels: RDD[(Double, Double)]) = this(scoreAndLabels, 0)
+ /**
+ * An auxiliary constructor taking a DataFrame.
+ * @param scoreAndLabels a DataFrame with two double columns: score and label
+ */
+ private[mllib] def this(scoreAndLabels: DataFrame) =
+ this(scoreAndLabels.map(r => (r.getDouble(0), r.getDouble(1))))
+
/** Unpersist intermediate RDDs used in the computation. */
def unpersist() {
cumulativeCounts.unpersist()