diff options
author | sethah <seth.hendrickson16@gmail.com> | 2016-04-15 12:14:41 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-15 12:14:41 -0700 |
commit | 129f2f455da982ec9fab593299fa4021b62827eb (patch) | |
tree | 4cb68c4b09db6e572db333acd8ee242a4a4fbcbe /python/pyspark/ml/tuning.py | |
parent | d6ae7d4637d23c57c4eeab79d1177216f380ec9c (diff) | |
download | spark-129f2f455da982ec9fab593299fa4021b62827eb.tar.gz spark-129f2f455da982ec9fab593299fa4021b62827eb.tar.bz2 spark-129f2f455da982ec9fab593299fa4021b62827eb.zip |
[SPARK-14104][PYSPARK][ML] All Python param setters should use the `_set` method
## What changes were proposed in this pull request?
Param setters in python previously accessed the _paramMap directly to update values. The `_set` method now implements type checking, so it should be used to update all parameters. This PR eliminates all direct accesses to `_paramMap` besides the one in the `_set` method to ensure type checking happens.
Additional changes:
* [SPARK-13068](https://github.com/apache/spark/pull/11663) missed adding type converters in evaluation.py so those are done here
* An incorrect `toBoolean` type converter was used for StringIndexer `handleInvalid` param in previous PR. This is fixed here.
## How was this patch tested?
Existing unit tests verify that parameters are still set properly. No new functionality is actually added in this PR.
Author: sethah <seth.hendrickson16@gmail.com>
Closes #11939 from sethah/SPARK-14104.
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r-- | python/pyspark/ml/tuning.py | 4 |
1 files changed, 2 insertions, 2 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 456d79d897..5ac539edde 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -228,7 +228,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): """ Sets the value of :py:attr:`numFolds`. """ - self._paramMap[self.numFolds] = value + self._set(numFolds=value) return self @since("1.4.0") @@ -479,7 +479,7 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): """ Sets the value of :py:attr:`trainRatio`. """ - self._paramMap[self.trainRatio] = value + self._set(trainRatio=value) return self @since("2.0.0") |