aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--python/pyspark/ml/wrapper.py17
1 files changed, 16 insertions, 1 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 394f23c5e9..73741c4b40 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -20,7 +20,7 @@ from abc import ABCMeta
from pyspark import SparkContext
from pyspark.sql import DataFrame
from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer
+from pyspark.ml.pipeline import Estimator, Transformer, Evaluator
from pyspark.mllib.common import inherit_doc
@@ -147,3 +147,18 @@ class JavaModel(JavaTransformer):
def _java_obj(self):
return self._java_model
+
+
+@inherit_doc
+class JavaEvaluator(Evaluator, JavaWrapper):
+ """
+ Base class for :py:class:`Evaluator`s that wrap Java/Scala
+ implementations.
+ """
+
+ __metaclass__ = ABCMeta
+
+ def evaluate(self, dataset, params={}):
+ java_obj = self._java_obj()
+ self._transfer_params_to_java(params, java_obj)
+ return java_obj.evaluate(dataset._jdf, self._empty_java_param_map())