diff options
author | Xiangrui Meng <meng@databricks.com> | 2015-05-05 11:45:37 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-05 11:45:37 -0700 |
commit | ee374e89cd1f08730fed9d50b742627d5b19d241 (patch) | |
tree | 9912c353fe5e563bbf7ced6dc0e0c20f20272d5d /python/pyspark/ml/wrapper.py | |
parent | 18340d7be55a6834918956555bf820c96769aa52 (diff) | |
download | spark-ee374e89cd1f08730fed9d50b742627d5b19d241.tar.gz spark-ee374e89cd1f08730fed9d50b742627d5b19d241.tar.bz2 spark-ee374e89cd1f08730fed9d50b742627d5b19d241.zip |
[SPARK-7333] [MLLIB] Add BinaryClassificationEvaluator to PySpark
This PR adds `BinaryClassificationEvaluator` to Python ML Pipelines API, which is a simple wrapper of the Scala implementation. oefirouz
Author: Xiangrui Meng <meng@databricks.com>
Closes #5885 from mengxr/SPARK-7333 and squashes the following commits:
25d7451 [Xiangrui Meng] fix tests in python 3
babdde7 [Xiangrui Meng] fix doc
cb51e6a [Xiangrui Meng] add BinaryClassificationEvaluator in PySpark
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r-- | python/pyspark/ml/wrapper.py | 17 |
1 files changed, 16 insertions, 1 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 394f23c5e9..73741c4b40 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -20,7 +20,7 @@ from abc import ABCMeta from pyspark import SparkContext from pyspark.sql import DataFrame from pyspark.ml.param import Params -from pyspark.ml.pipeline import Estimator, Transformer +from pyspark.ml.pipeline import Estimator, Transformer, Evaluator from pyspark.mllib.common import inherit_doc @@ -147,3 +147,18 @@ class JavaModel(JavaTransformer): def _java_obj(self): return self._java_model + + +@inherit_doc +class JavaEvaluator(Evaluator, JavaWrapper): + """ + Base class for :py:class:`Evaluator`s that wrap Java/Scala + implementations. + """ + + __metaclass__ = ABCMeta + + def evaluate(self, dataset, params={}): + java_obj = self._java_obj() + self._transfer_params_to_java(params, java_obj) + return java_obj.evaluate(dataset._jdf, self._empty_java_param_map()) |