From 129f2f455da982ec9fab593299fa4021b62827eb Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 15 Apr 2016 12:14:41 -0700 Subject: [SPARK-14104][PYSPARK][ML] All Python param setters should use the `_set` method ## What changes were proposed in this pull request? Param setters in Python previously accessed `_paramMap` directly to update values. The `_set` method now implements type checking, so it should be used to update all parameters. This PR eliminates all direct accesses to `_paramMap` besides the one in the `_set` method to ensure type checking happens. Additional changes: * [SPARK-13068](https://github.com/apache/spark/pull/11663) missed adding type converters in evaluation.py, so those are added here * An incorrect `toBoolean` type converter was used for the StringIndexer `handleInvalid` param in a previous PR. This is fixed here. ## How was this patch tested? Existing unit tests verify that parameters are still set properly. No new functionality is actually added in this PR. Author: sethah Closes #11939 from sethah/SPARK-14104. 
--- python/pyspark/ml/param/__init__.py | 22 +++++++++++++++++++--- python/pyspark/ml/param/_shared_params_code_gen.py | 2 +- python/pyspark/ml/param/shared.py | 2 +- 3 files changed, 21 insertions(+), 5 deletions(-) (limited to 'python/pyspark/ml/param') diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index a1265294a1..9f0b063aac 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -26,6 +26,8 @@ import copy import numpy as np import warnings +from py4j.java_gateway import JavaObject + from pyspark import since from pyspark.ml.util import Identifiable from pyspark.mllib.linalg import DenseVector, Vector @@ -389,8 +391,8 @@ class Params(Identifiable): if extra is None: extra = dict() that = copy.copy(self) - that._paramMap = self.extractParamMap(extra) - return that + that._paramMap = {} + return self._copyValues(that, extra) def _shouldOwn(self, param): """ @@ -439,12 +441,26 @@ class Params(Identifiable): self._paramMap[p] = value return self + def _clear(self, param): + """ + Clears a param from the param map if it has been explicitly set. + """ + if self.isSet(param): + del self._paramMap[param] + def _setDefault(self, **kwargs): """ Sets default params. """ for param, value in kwargs.items(): - self._defaultParamMap[getattr(self, param)] = value + p = getattr(self, param) + if value is not None and not isinstance(value, JavaObject): + try: + value = p.typeConverter(value) + except TypeError as e: + raise TypeError('Invalid default param value given for param "%s". 
%s' + % (p.name, e)) + self._defaultParamMap[p] = value return self def _copyValues(self, to, extra=None): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index a7615c43be..a2acf956bc 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -131,7 +131,7 @@ if __name__ == "__main__": "TypeConverters.toFloat"), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + "out rows with bad values), or error (which will throw an errror). More options may be " + - "added later.", None, "TypeConverters.toBoolean"), + "added later.", None, "TypeConverters.toString"), ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0", "TypeConverters.toFloat"), diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index c9e975525c..538c0b718a 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -392,7 +392,7 @@ class HasHandleInvalid(Params): Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. """ - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", typeConverter=TypeConverters.toBoolean) + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", typeConverter=TypeConverters.toString) def __init__(self): super(HasHandleInvalid, self).__init__() -- cgit v1.2.3