aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/param
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-04-15 12:14:41 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-15 12:14:41 -0700
commit129f2f455da982ec9fab593299fa4021b62827eb (patch)
tree4cb68c4b09db6e572db333acd8ee242a4a4fbcbe /python/pyspark/ml/param
parentd6ae7d4637d23c57c4eeab79d1177216f380ec9c (diff)
downloadspark-129f2f455da982ec9fab593299fa4021b62827eb.tar.gz
spark-129f2f455da982ec9fab593299fa4021b62827eb.tar.bz2
spark-129f2f455da982ec9fab593299fa4021b62827eb.zip
[SPARK-14104][PYSPARK][ML] All Python param setters should use the `_set` method
## What changes were proposed in this pull request? Param setters in python previously accessed the _paramMap directly to update values. The `_set` method now implements type checking, so it should be used to update all parameters. This PR eliminates all direct accesses to `_paramMap` besides the one in the `_set` method to ensure type checking happens. Additional changes: * [SPARK-13068](https://github.com/apache/spark/pull/11663) missed adding type converters in evaluation.py so those are done here * An incorrect `toBoolean` type converter was used for StringIndexer `handleInvalid` param in previous PR. This is fixed here. ## How was this patch tested? Existing unit tests verify that parameters are still set properly. No new functionality is actually added in this PR. Author: sethah <seth.hendrickson16@gmail.com> Closes #11939 from sethah/SPARK-14104.
Diffstat (limited to 'python/pyspark/ml/param')
-rw-r--r--python/pyspark/ml/param/__init__.py22
-rw-r--r--python/pyspark/ml/param/_shared_params_code_gen.py2
-rw-r--r--python/pyspark/ml/param/shared.py2
3 files changed, 21 insertions, 5 deletions
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index a1265294a1..9f0b063aac 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -26,6 +26,8 @@ import copy
import numpy as np
import warnings
+from py4j.java_gateway import JavaObject
+
from pyspark import since
from pyspark.ml.util import Identifiable
from pyspark.mllib.linalg import DenseVector, Vector
@@ -389,8 +391,8 @@ class Params(Identifiable):
if extra is None:
extra = dict()
that = copy.copy(self)
- that._paramMap = self.extractParamMap(extra)
- return that
+ that._paramMap = {}
+ return self._copyValues(that, extra)
def _shouldOwn(self, param):
"""
@@ -439,12 +441,26 @@ class Params(Identifiable):
self._paramMap[p] = value
return self
+ def _clear(self, param):
+ """
+ Clears a param from the param map if it has been explicitly set.
+ """
+ if self.isSet(param):
+ del self._paramMap[param]
+
def _setDefault(self, **kwargs):
"""
Sets default params.
"""
for param, value in kwargs.items():
- self._defaultParamMap[getattr(self, param)] = value
+ p = getattr(self, param)
+ if value is not None and not isinstance(value, JavaObject):
+ try:
+ value = p.typeConverter(value)
+ except TypeError as e:
+ raise TypeError('Invalid default param value given for param "%s". %s'
+ % (p.name, e))
+ self._defaultParamMap[p] = value
return self
def _copyValues(self, to, extra=None):
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index a7615c43be..a2acf956bc 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -131,7 +131,7 @@ if __name__ == "__main__":
"TypeConverters.toFloat"),
("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " +
"out rows with bad values), or error (which will throw an errror). More options may be " +
- "added later.", None, "TypeConverters.toBoolean"),
+ "added later.", None, "TypeConverters.toString"),
("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " +
"the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0",
"TypeConverters.toFloat"),
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index c9e975525c..538c0b718a 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -392,7 +392,7 @@ class HasHandleInvalid(Params):
Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.
"""
- handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", typeConverter=TypeConverters.toBoolean)
+ handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", typeConverter=TypeConverters.toString)
def __init__(self):
super(HasHandleInvalid, self).__init__()