From 8f11c6116bf8c7246682cbb2d6f27bf0f1531c6d Mon Sep 17 00:00:00 2001
From: Xiangrui Meng
Date: Thu, 21 May 2015 22:57:33 -0700
Subject: [SPARK-7535] [.0] [MLLIB] Audit the pipeline APIs for 1.4
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Some changes to the pipeline APIs:

1. Estimator/Transformer/ doesn’t need to extend Params since PipelineStage already does.
1. Move Evaluator to ml.evaluation.
1. Mention larger metric values are better.
1. PipelineModel doc. “compiled” -> “fitted”
1. Hide object PolynomialExpansion.
1. Hide object VectorAssembler.
1. Word2Vec.minCount (and other) -> group param
1. ParamValidators -> DeveloperApi
1. Hide MetadataUtils/SchemaUtils.

jkbradley

Author: Xiangrui Meng

Closes #6322 from mengxr/SPARK-7535.0 and squashes the following commits:

9e9c7da [Xiangrui Meng] move JavaEvaluator to ml.evaluation as well
e179480 [Xiangrui Meng] move Evaluation to ml.evaluation in PySpark
08ef61f [Xiangrui Meng] update pipieline APIs
---
 python/pyspark/ml/__init__.py   |  4 +--
 python/pyspark/ml/evaluation.py | 63 +++++++++++++++++++++++++++++++++++++++--
 python/pyspark/ml/pipeline.py   | 37 ------------------------
 python/pyspark/ml/wrapper.py    | 21 +-------------
 4 files changed, 64 insertions(+), 61 deletions(-)

(limited to 'python')

diff --git a/python/pyspark/ml/__init__.py b/python/pyspark/ml/__init__.py
index da793d9db7..327a11b14b 100644
--- a/python/pyspark/ml/__init__.py
+++ b/python/pyspark/ml/__init__.py
@@ -15,6 +15,6 @@
 # limitations under the License.
 #
 
-from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel, Evaluator
+from pyspark.ml.pipeline import Transformer, Estimator, Model, Pipeline, PipelineModel
 
-__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel", "Evaluator"]
+__all__ = ["Transformer", "Estimator", "Model", "Pipeline", "PipelineModel"]
diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py
index f4655c513c..34e1353def 100644
--- a/python/pyspark/ml/evaluation.py
+++ b/python/pyspark/ml/evaluation.py
@@ -15,13 +15,72 @@
 # limitations under the License.
 #
 
-from pyspark.ml.wrapper import JavaEvaluator
+from abc import abstractmethod, ABCMeta
+
+from pyspark.ml.wrapper import JavaWrapper
 from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasLabelCol, HasRawPredictionCol
 from pyspark.ml.util import keyword_only
 from pyspark.mllib.common import inherit_doc
 
-__all__ = ['BinaryClassificationEvaluator']
+__all__ = ['Evaluator', 'BinaryClassificationEvaluator']
+
+
+@inherit_doc
+class Evaluator(Params):
+    """
+    Base class for evaluators that compute metrics from predictions.
+    """
+
+    __metaclass__ = ABCMeta
+
+    @abstractmethod
+    def _evaluate(self, dataset):
+        """
+        Evaluates the output.
+
+        :param dataset: a dataset that contains labels/observations and
+                        predictions
+        :return: metric
+        """
+        raise NotImplementedError()
+
+    def evaluate(self, dataset, params={}):
+        """
+        Evaluates the output with optional parameters.
+
+        :param dataset: a dataset that contains labels/observations and
+                        predictions
+        :param params: an optional param map that overrides embedded
+                       params
+        :return: metric
+        """
+        if isinstance(params, dict):
+            if params:
+                return self.copy(params)._evaluate(dataset)
+            else:
+                return self._evaluate(dataset)
+        else:
+            raise ValueError("Params must be a param map but got %s." % type(params))
+
+
+@inherit_doc
+class JavaEvaluator(Evaluator, JavaWrapper):
+    """
+    Base class for :py:class:`Evaluator`s that wrap Java/Scala
+    implementations.
+    """
+
+    __metaclass__ = ABCMeta
+
+    def _evaluate(self, dataset):
+        """
+        Evaluates the output.
+        :param dataset: a dataset that contains labels/observations and predictions.
+        :return: evaluation metric
+        """
+        self._transfer_params_to_java()
+        return self._java_obj.evaluate(dataset._jdf)
 
 
 @inherit_doc
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 0f38e02127..a563024b2c 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -219,40 +219,3 @@ class PipelineModel(Model):
     def copy(self, extra={}):
         stages = [stage.copy(extra) for stage in self.stages]
         return PipelineModel(stages)
-
-
-class Evaluator(Params):
-    """
-    Base class for evaluators that compute metrics from predictions.
-    """
-
-    __metaclass__ = ABCMeta
-
-    @abstractmethod
-    def _evaluate(self, dataset):
-        """
-        Evaluates the output.
-
-        :param dataset: a dataset that contains labels/observations and
-                        predictions
-        :return: metric
-        """
-        raise NotImplementedError()
-
-    def evaluate(self, dataset, params={}):
-        """
-        Evaluates the output with optional parameters.
-
-        :param dataset: a dataset that contains labels/observations and
-                        predictions
-        :param params: an optional param map that overrides embedded
-                       params
-        :return: metric
-        """
-        if isinstance(params, dict):
-            if params:
-                return self.copy(params)._evaluate(dataset)
-            else:
-                return self._evaluate(dataset)
-        else:
-            raise ValueError("Params must be a param map but got %s." % type(params))
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 4419e16184..7b0893e2cd 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -20,7 +20,7 @@ from abc import ABCMeta
 from pyspark import SparkContext
 from pyspark.sql import DataFrame
 from pyspark.ml.param import Params
-from pyspark.ml.pipeline import Estimator, Transformer, Evaluator, Model
+from pyspark.ml.pipeline import Estimator, Transformer, Model
 from pyspark.mllib.common import inherit_doc, _java2py, _py2java
 
 
@@ -185,22 +185,3 @@ class JavaModel(Model, JavaTransformer):
         sc = SparkContext._active_spark_context
         java_args = [_py2java(sc, arg) for arg in args]
         return _java2py(sc, m(*java_args))
-
-
-@inherit_doc
-class JavaEvaluator(Evaluator, JavaWrapper):
-    """
-    Base class for :py:class:`Evaluator`s that wrap Java/Scala
-    implementations.
-    """
-
-    __metaclass__ = ABCMeta
-
-    def _evaluate(self, dataset):
-        """
-        Evaluates the output.
-        :param dataset: a dataset that contains labels/observations and predictions.
-        :return: evaluation metric
-        """
-        self._transfer_params_to_java()
-        return self._java_obj.evaluate(dataset._jdf)
--
cgit v1.2.3
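
A quick usage sketch of the relocated Evaluator API after this change (illustrative only, not part of the patch). The `predictions` DataFrame and the custom `AccuracyEvaluator` below are hypothetical, and `metricName` is assumed to be the param exposed by BinaryClassificationEvaluator elsewhere in evaluation.py; only the import location and the evaluate(dataset, params) contract come from this commit:

    from pyspark.ml.evaluation import Evaluator, BinaryClassificationEvaluator


    class AccuracyEvaluator(Evaluator):
        """Toy custom evaluator: fraction of rows where prediction equals label."""

        def _evaluate(self, dataset):
            # Assumes `dataset` has "label" and "prediction" columns.
            correct = dataset.filter(dataset.label == dataset.prediction).count()
            return correct / float(dataset.count())


    # `predictions` is an assumed DataFrame with "label", "rawPrediction",
    # and "prediction" columns, e.g. the output of a fitted classifier.
    evaluator = BinaryClassificationEvaluator(rawPredictionCol="rawPrediction",
                                              labelCol="label")
    auc = evaluator.evaluate(predictions)
    # Embedded params can be overridden per call by passing a param map (a dict):
    aupr = evaluator.evaluate(predictions, {evaluator.metricName: "areaUnderPR"})
    acc = AccuracyEvaluator().evaluate(predictions)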