aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-05-11 09:14:20 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-11 09:14:20 -0700
commit042dda3c5c25b5ecb6ae4fd37c85b211b01c187b (patch)
treeab4fee18073ed758b01dd62d1ce48434fc1c947c /mllib
parentd70a076892e0677acceccaba665908cdf664f1b4 (diff)
downloadspark-042dda3c5c25b5ecb6ae4fd37c85b211b01c187b.tar.gz
spark-042dda3c5c25b5ecb6ae4fd37c85b211b01c187b.tar.bz2
spark-042dda3c5c25b5ecb6ae4fd37c85b211b01c187b.zip
[SPARK-6092] [MLLIB] Add RankingMetrics in PySpark/MLlib
Author: Yanbo Liang <ybliang8@gmail.com> Closes #6044 from yanboliang/spark-6092 and squashes the following commits: 726a9b1 [Yanbo Liang] add newRankingMetrics 33f649c [Yanbo Liang] Add RankingMetrics in PySpark/MLlib
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala10
1 files changed, 10 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
index 8c30ad4b39..f4c4775965 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/api/python/PythonMLLibAPI.scala
@@ -32,6 +32,7 @@ import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.api.python.SerDeUtil
import org.apache.spark.mllib.classification._
import org.apache.spark.mllib.clustering._
+import org.apache.spark.mllib.evaluation.RankingMetrics
import org.apache.spark.mllib.feature._
import org.apache.spark.mllib.fpm.{FPGrowth, FPGrowthModel}
import org.apache.spark.mllib.linalg._
@@ -50,6 +51,7 @@ import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTree
import org.apache.spark.mllib.tree.{DecisionTree, GradientBoostedTrees, RandomForest}
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.DataFrame
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.Utils
@@ -923,6 +925,14 @@ private[python] class PythonMLLibAPI extends Serializable {
RG.gammaVectorRDD(jsc.sc, shape, scale, numRows, numCols, parts, s)
}
+ /**
+ * Java stub for the constructor of Python mllib RankingMetrics
+ */
+ def newRankingMetrics(predictionAndLabels: DataFrame): RankingMetrics[Any] = {
+ new RankingMetrics(predictionAndLabels.map(
+ r => (r.getSeq(0).toArray[Any], r.getSeq(1).toArray[Any])))
+ }
+
}