diff options
author | sethah <seth.hendrickson16@gmail.com> | 2016-04-15 12:14:41 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-15 12:14:41 -0700 |
commit | 129f2f455da982ec9fab593299fa4021b62827eb (patch) | |
tree | 4cb68c4b09db6e572db333acd8ee242a4a4fbcbe /python/pyspark | |
parent | d6ae7d4637d23c57c4eeab79d1177216f380ec9c (diff) | |
download | spark-129f2f455da982ec9fab593299fa4021b62827eb.tar.gz spark-129f2f455da982ec9fab593299fa4021b62827eb.tar.bz2 spark-129f2f455da982ec9fab593299fa4021b62827eb.zip |
[SPARK-14104][PYSPARK][ML] All Python param setters should use the `_set` method
## What changes were proposed in this pull request?
Param setters in Python previously accessed `_paramMap` directly to update values. The `_set` method now implements type checking, so it should be used to update all parameters. This PR eliminates all direct accesses to `_paramMap` except the one in the `_set` method to ensure type checking happens.
Additional changes:
* [SPARK-13068](https://github.com/apache/spark/pull/11663) did not add type converters in `evaluation.py`, so they are added here.
* An incorrect `toBoolean` type converter was used for the StringIndexer `handleInvalid` param in a previous PR. This is fixed here by using `toString` instead.
## How was this patch tested?
Existing unit tests verify that parameters are still set properly. No new functionality is actually added in this PR.
Author: sethah <seth.hendrickson16@gmail.com>
Closes #11939 from sethah/SPARK-14104.
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/ml/classification.py | 22 | ||||
-rw-r--r-- | python/pyspark/ml/clustering.py | 10 | ||||
-rw-r--r-- | python/pyspark/ml/evaluation.py | 17 | ||||
-rw-r--r-- | python/pyspark/ml/feature.py | 72 | ||||
-rw-r--r-- | python/pyspark/ml/param/__init__.py | 22 | ||||
-rw-r--r-- | python/pyspark/ml/param/_shared_params_code_gen.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/param/shared.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/pipeline.py | 2 | ||||
-rw-r--r-- | python/pyspark/ml/recommendation.py | 22 | ||||
-rw-r--r-- | python/pyspark/ml/regression.py | 24 | ||||
-rw-r--r-- | python/pyspark/ml/tuning.py | 4 | ||||
-rw-r--r-- | python/pyspark/ml/wrapper.py | 2 |
12 files changed, 110 insertions, 91 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 6ef119a426..7051798485 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -142,9 +142,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti Sets the value of :py:attr:`threshold`. Clears value of :py:attr:`thresholds` if it has been set. """ - self._paramMap[self.threshold] = value - if self.isSet(self.thresholds): - del self._paramMap[self.thresholds] + self._set(threshold=value) + self._clear(self.thresholds) return self @since("1.4.0") @@ -169,9 +168,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti Sets the value of :py:attr:`thresholds`. Clears value of :py:attr:`threshold` if it has been set. """ - self._paramMap[self.thresholds] = value - if self.isSet(self.threshold): - del self._paramMap[self.threshold] + self._set(thresholds=value) + self._clear(self.threshold) return self @since("1.5.0") @@ -471,7 +469,7 @@ class TreeClassifierParams(object): """ Sets the value of :py:attr:`impurity`. """ - self._paramMap[self.impurity] = value + self._set(impurity=value) return self @since("1.6.0") @@ -833,7 +831,7 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol """ Sets the value of :py:attr:`lossType`. """ - self._paramMap[self.lossType] = value + self._set(lossType=value) return self @since("1.4.0") @@ -963,7 +961,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H """ Sets the value of :py:attr:`smoothing`. """ - self._paramMap[self.smoothing] = value + self._set(smoothing=value) return self @since("1.5.0") @@ -978,7 +976,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H """ Sets the value of :py:attr:`modelType`. 
""" - self._paramMap[self.modelType] = value + self._set(modelType=value) return self @since("1.5.0") @@ -1108,7 +1106,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, """ Sets the value of :py:attr:`layers`. """ - self._paramMap[self.layers] = value + self._set(layers=value) return self @since("1.6.0") @@ -1123,7 +1121,7 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, """ Sets the value of :py:attr:`blockSize`. """ - self._paramMap[self.blockSize] = value + self._set(blockSize=value) return self @since("1.6.0") diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index f071c597c8..64c4bf1b92 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -130,7 +130,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol """ Sets the value of :py:attr:`k`. """ - self._paramMap[self.k] = value + self._set(k=value) return self @since("1.5.0") @@ -145,7 +145,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol """ Sets the value of :py:attr:`initMode`. """ - self._paramMap[self.initMode] = value + self._set(initMode=value) return self @since("1.5.0") @@ -160,7 +160,7 @@ class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol """ Sets the value of :py:attr:`initSteps`. """ - self._paramMap[self.initSteps] = value + self._set(initSteps=value) return self @since("1.5.0") @@ -280,7 +280,7 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte """ Sets the value of :py:attr:`k`. """ - self._paramMap[self.k] = value + self._set(k=value) return self @since("2.0.0") @@ -295,7 +295,7 @@ class BisectingKMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIte """ Sets the value of :py:attr:`minDivisibleClusterSize`. 
""" - self._paramMap[self.minDivisibleClusterSize] = value + self._set(minDivisibleClusterSize=value) return self @since("2.0.0") diff --git a/python/pyspark/ml/evaluation.py b/python/pyspark/ml/evaluation.py index 4b0bade102..52a3fe8985 100644 --- a/python/pyspark/ml/evaluation.py +++ b/python/pyspark/ml/evaluation.py @@ -19,7 +19,7 @@ from abc import abstractmethod, ABCMeta from pyspark import since from pyspark.ml.wrapper import JavaParams -from pyspark.ml.param import Param, Params +from pyspark.ml.param import Param, Params, TypeConverters from pyspark.ml.param.shared import HasLabelCol, HasPredictionCol, HasRawPredictionCol from pyspark.ml.util import keyword_only from pyspark.mllib.common import inherit_doc @@ -125,7 +125,8 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction """ metricName = Param(Params._dummy(), "metricName", - "metric name in evaluation (areaUnderROC|areaUnderPR)") + "metric name in evaluation (areaUnderROC|areaUnderPR)", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, rawPredictionCol="rawPrediction", labelCol="label", @@ -147,7 +148,7 @@ class BinaryClassificationEvaluator(JavaEvaluator, HasLabelCol, HasRawPrediction """ Sets the value of :py:attr:`metricName`. """ - self._paramMap[self.metricName] = value + self._set(metricName=value) return self @since("1.4.0") @@ -194,7 +195,8 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): # when we evaluate a metric that is needed to minimize (e.g., `"rmse"`, `"mse"`, `"mae"`), # we take and output the negative of this metric. 
metricName = Param(Params._dummy(), "metricName", - "metric name in evaluation (mse|rmse|r2|mae)") + "metric name in evaluation (mse|rmse|r2|mae)", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, predictionCol="prediction", labelCol="label", @@ -216,7 +218,7 @@ class RegressionEvaluator(JavaEvaluator, HasLabelCol, HasPredictionCol): """ Sets the value of :py:attr:`metricName`. """ - self._paramMap[self.metricName] = value + self._set(metricName=value) return self @since("1.4.0") @@ -260,7 +262,8 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio """ metricName = Param(Params._dummy(), "metricName", "metric name in evaluation " - "(f1|precision|recall|weightedPrecision|weightedRecall)") + "(f1|precision|recall|weightedPrecision|weightedRecall)", + typeConverter=TypeConverters.toString) @keyword_only def __init__(self, predictionCol="prediction", labelCol="label", @@ -282,7 +285,7 @@ class MulticlassClassificationEvaluator(JavaEvaluator, HasLabelCol, HasPredictio """ Sets the value of :py:attr:`metricName`. """ - self._paramMap[self.metricName] = value + self._set(metricName=value) return self @since("1.5.0") diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 0d8ef1297f..776906eaab 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -19,6 +19,8 @@ import sys if sys.version > '3': basestring = str +from py4j.java_collections import JavaArray + from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * @@ -112,7 +114,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java """ Sets the value of :py:attr:`threshold`. """ - self._paramMap[self.threshold] = value + self._set(threshold=value) return self @since("1.4.0") @@ -188,7 +190,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav """ Sets the value of :py:attr:`splits`. 
""" - self._paramMap[self.splits] = value + self._set(splits=value) return self @since("1.4.0") @@ -293,7 +295,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`minTF`. """ - self._paramMap[self.minTF] = value + self._set(minTF=value) return self @since("1.6.0") @@ -308,7 +310,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`minDF`. """ - self._paramMap[self.minDF] = value + self._set(minDF=value) return self @since("1.6.0") @@ -323,7 +325,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`vocabSize`. """ - self._paramMap[self.vocabSize] = value + self._set(vocabSize=value) return self @since("1.6.0") @@ -431,7 +433,7 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit """ Sets the value of :py:attr:`inverse`. """ - self._paramMap[self.inverse] = value + self._set(inverse=value) return self @since("1.6.0") @@ -498,7 +500,7 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada """ Sets the value of :py:attr:`scalingVec`. """ - self._paramMap[self.scalingVec] = value + self._set(scalingVec=value) return self @since("1.5.0") @@ -641,7 +643,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab """ Sets the value of :py:attr:`minDocFreq`. """ - self._paramMap[self.minDocFreq] = value + self._set(minDocFreq=value) return self @since("1.4.0") @@ -826,7 +828,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav """ Sets the value of :py:attr:`min`. """ - self._paramMap[self.min] = value + self._set(min=value) return self @since("1.6.0") @@ -841,7 +843,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav """ Sets the value of :py:attr:`max`. 
""" - self._paramMap[self.max] = value + self._set(max=value) return self @since("1.6.0") @@ -950,7 +952,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr """ Sets the value of :py:attr:`n`. """ - self._paramMap[self.n] = value + self._set(n=value) return self @since("1.5.0") @@ -1017,7 +1019,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav """ Sets the value of :py:attr:`p`. """ - self._paramMap[self.p] = value + self._set(p=value) return self @since("1.4.0") @@ -1100,7 +1102,7 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`dropLast`. """ - self._paramMap[self.dropLast] = value + self._set(dropLast=value) return self @since("1.4.0") @@ -1169,7 +1171,7 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead """ Sets the value of :py:attr:`degree`. """ - self._paramMap[self.degree] = value + self._set(degree=value) return self @since("1.4.0") @@ -1251,7 +1253,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, Jav """ Sets the value of :py:attr:`numBuckets`. """ - self._paramMap[self.numBuckets] = value + self._set(numBuckets=value) return self @since("2.0.0") @@ -1349,7 +1351,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`minTokenLength`. """ - self._paramMap[self.minTokenLength] = value + self._set(minTokenLength=value) return self @since("1.4.0") @@ -1364,7 +1366,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`gaps`. """ - self._paramMap[self.gaps] = value + self._set(gaps=value) return self @since("1.4.0") @@ -1379,7 +1381,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`pattern`. 
""" - self._paramMap[self.pattern] = value + self._set(pattern=value) return self @since("1.4.0") @@ -1394,7 +1396,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`toLowercase`. """ - self._paramMap[self.toLowercase] = value + self._set(toLowercase=value) return self @since("2.0.0") @@ -1455,7 +1457,7 @@ class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable): """ Sets the value of :py:attr:`statement`. """ - self._paramMap[self.statement] = value + self._set(statement=value) return self @since("1.6.0") @@ -1532,7 +1534,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, J """ Sets the value of :py:attr:`withMean`. """ - self._paramMap[self.withMean] = value + self._set(withMean=value) return self @since("1.4.0") @@ -1547,7 +1549,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, J """ Sets the value of :py:attr:`withStd`. """ - self._paramMap[self.withStd] = value + self._set(withStd=value) return self @since("1.4.0") @@ -1598,7 +1600,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, The indices are in [0, numLabels), ordered by label frequencies. So the most frequent label gets index 0. - >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid='error') >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), @@ -1716,7 +1718,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`labels`. 
""" - self._paramMap[self.labels] = value + self._set(labels=value) return self @since("1.6.0") @@ -1787,7 +1789,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl """ Specify the stopwords to be filtered. """ - self._paramMap[self.stopWords] = value + self._set(stopWords=value) return self @since("1.6.0") @@ -1802,7 +1804,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl """ Set whether to do a case sensitive comparison over the stop words """ - self._paramMap[self.caseSensitive] = value + self._set(caseSensitive=value) return self @since("1.6.0") @@ -2019,7 +2021,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja """ Sets the value of :py:attr:`maxCategories`. """ - self._paramMap[self.maxCategories] = value + self._set(maxCategories=value) return self @since("1.4.0") @@ -2129,7 +2131,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J """ Sets the value of :py:attr:`indices`. """ - self._paramMap[self.indices] = value + self._set(indices=value) return self @since("1.6.0") @@ -2144,7 +2146,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J """ Sets the value of :py:attr:`names`. """ - self._paramMap[self.names] = value + self._set(names=value) return self @since("1.6.0") @@ -2249,7 +2251,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has """ Sets the value of :py:attr:`vectorSize`. """ - self._paramMap[self.vectorSize] = value + self._set(vectorSize=value) return self @since("1.4.0") @@ -2264,7 +2266,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has """ Sets the value of :py:attr:`numPartitions`. 
""" - self._paramMap[self.numPartitions] = value + self._set(numPartitions=value) return self @since("1.4.0") @@ -2279,7 +2281,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has """ Sets the value of :py:attr:`minCount`. """ - self._paramMap[self.minCount] = value + self._set(minCount=value) return self @since("1.4.0") @@ -2385,7 +2387,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab """ Sets the value of :py:attr:`k`. """ - self._paramMap[self.k] = value + self._set(k=value) return self @since("1.5.0") @@ -2517,7 +2519,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM """ Sets the value of :py:attr:`formula`. """ - self._paramMap[self.formula] = value + self._set(formula=value) return self @since("1.5.0") @@ -2609,7 +2611,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja """ Sets the value of :py:attr:`numTopFeatures`. """ - self._paramMap[self.numTopFeatures] = value + self._set(numTopFeatures=value) return self @since("2.0.0") diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py index a1265294a1..9f0b063aac 100644 --- a/python/pyspark/ml/param/__init__.py +++ b/python/pyspark/ml/param/__init__.py @@ -26,6 +26,8 @@ import copy import numpy as np import warnings +from py4j.java_gateway import JavaObject + from pyspark import since from pyspark.ml.util import Identifiable from pyspark.mllib.linalg import DenseVector, Vector @@ -389,8 +391,8 @@ class Params(Identifiable): if extra is None: extra = dict() that = copy.copy(self) - that._paramMap = self.extractParamMap(extra) - return that + that._paramMap = {} + return self._copyValues(that, extra) def _shouldOwn(self, param): """ @@ -439,12 +441,26 @@ class Params(Identifiable): self._paramMap[p] = value return self + def _clear(self, param): + """ + Clears a param from the param map if it has been explicitly set. 
+ """ + if self.isSet(param): + del self._paramMap[param] + def _setDefault(self, **kwargs): """ Sets default params. """ for param, value in kwargs.items(): - self._defaultParamMap[getattr(self, param)] = value + p = getattr(self, param) + if value is not None and not isinstance(value, JavaObject): + try: + value = p.typeConverter(value) + except TypeError as e: + raise TypeError('Invalid default param value given for param "%s". %s' + % (p.name, e)) + self._defaultParamMap[p] = value return self def _copyValues(self, to, extra=None): diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index a7615c43be..a2acf956bc 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -131,7 +131,7 @@ if __name__ == "__main__": "TypeConverters.toFloat"), ("handleInvalid", "how to handle invalid entries. Options are skip (which will filter " + "out rows with bad values), or error (which will throw an errror). More options may be " + - "added later.", None, "TypeConverters.toBoolean"), + "added later.", None, "TypeConverters.toString"), ("elasticNetParam", "the ElasticNet mixing parameter, in range [0, 1]. For alpha = 0, " + "the penalty is an L2 penalty. For alpha = 1, it is an L1 penalty.", "0.0", "TypeConverters.toFloat"), diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index c9e975525c..538c0b718a 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -392,7 +392,7 @@ class HasHandleInvalid(Params): Mixin for param handleInvalid: how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later. """ - handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. 
Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", typeConverter=TypeConverters.toBoolean) + handleInvalid = Param(Params._dummy(), "handleInvalid", "how to handle invalid entries. Options are skip (which will filter out rows with bad values), or error (which will throw an errror). More options may be added later.", typeConverter=TypeConverters.toString) def __init__(self): super(HasHandleInvalid, self).__init__() diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index 9d654e8b0f..6f599b5159 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -90,7 +90,7 @@ class Pipeline(Estimator, MLReadable, MLWritable): :param value: a list of transformers or estimators :return: the pipeline instance """ - self._paramMap[self.stages] = value + self._set(stages=value) return self @since("1.3.0") diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 7c7a1b67a1..9c38f2431b 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -157,7 +157,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ Sets the value of :py:attr:`rank`. """ - self._paramMap[self.rank] = value + self._set(rank=value) return self @since("1.4.0") @@ -172,7 +172,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ Sets the value of :py:attr:`numUserBlocks`. """ - self._paramMap[self.numUserBlocks] = value + self._set(numUserBlocks=value) return self @since("1.4.0") @@ -187,7 +187,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ Sets the value of :py:attr:`numItemBlocks`. 
""" - self._paramMap[self.numItemBlocks] = value + self._set(numItemBlocks=value) return self @since("1.4.0") @@ -202,15 +202,15 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ Sets both :py:attr:`numUserBlocks` and :py:attr:`numItemBlocks` to the specific value. """ - self._paramMap[self.numUserBlocks] = value - self._paramMap[self.numItemBlocks] = value + self._set(numUserBlocks=value) + self._set(numItemBlocks=value) @since("1.4.0") def setImplicitPrefs(self, value): """ Sets the value of :py:attr:`implicitPrefs`. """ - self._paramMap[self.implicitPrefs] = value + self._set(implicitPrefs=value) return self @since("1.4.0") @@ -225,7 +225,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ Sets the value of :py:attr:`alpha`. """ - self._paramMap[self.alpha] = value + self._set(alpha=value) return self @since("1.4.0") @@ -240,7 +240,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ Sets the value of :py:attr:`userCol`. """ - self._paramMap[self.userCol] = value + self._set(userCol=value) return self @since("1.4.0") @@ -255,7 +255,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ Sets the value of :py:attr:`itemCol`. """ - self._paramMap[self.itemCol] = value + self._set(itemCol=value) return self @since("1.4.0") @@ -270,7 +270,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ Sets the value of :py:attr:`ratingCol`. """ - self._paramMap[self.ratingCol] = value + self._set(ratingCol=value) return self @since("1.4.0") @@ -285,7 +285,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha """ Sets the value of :py:attr:`nonnegative`. 
""" - self._paramMap[self.nonnegative] = value + self._set(nonnegative=value) return self @since("1.4.0") diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 3c7852526a..8b68622524 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -478,7 +478,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti """ Sets the value of :py:attr:`isotonic`. """ - self._paramMap[self.isotonic] = value + self._set(isotonic=value) return self def getIsotonic(self): @@ -491,7 +491,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti """ Sets the value of :py:attr:`featureIndex`. """ - self._paramMap[self.featureIndex] = value + self._set(featureIndex=value) return self def getFeatureIndex(self): @@ -541,7 +541,7 @@ class TreeEnsembleParams(DecisionTreeParams): """ Sets the value of :py:attr:`subsamplingRate`. """ - self._paramMap[self.subsamplingRate] = value + self._set(subsamplingRate=value) return self @since("1.4.0") @@ -571,7 +571,7 @@ class TreeRegressorParams(Params): """ Sets the value of :py:attr:`impurity`. """ - self._paramMap[self.impurity] = value + self._set(impurity=value) return self @since("1.4.0") @@ -604,7 +604,7 @@ class RandomForestParams(TreeEnsembleParams): """ Sets the value of :py:attr:`numTrees`. """ - self._paramMap[self.numTrees] = value + self._set(numTrees=value) return self @since("1.4.0") @@ -619,7 +619,7 @@ class RandomForestParams(TreeEnsembleParams): """ Sets the value of :py:attr:`featureSubsetStrategy`. """ - self._paramMap[self.featureSubsetStrategy] = value + self._set(featureSubsetStrategy=value) return self @since("1.4.0") @@ -991,7 +991,7 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, """ Sets the value of :py:attr:`lossType`. 
""" - self._paramMap[self.lossType] = value + self._set(lossType=value) return self @since("1.4.0") @@ -1126,7 +1126,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi """ Sets the value of :py:attr:`censorCol`. """ - self._paramMap[self.censorCol] = value + self._set(censorCol=value) return self @since("1.6.0") @@ -1141,7 +1141,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi """ Sets the value of :py:attr:`quantileProbabilities`. """ - self._paramMap[self.quantileProbabilities] = value + self._set(quantileProbabilities=value) return self @since("1.6.0") @@ -1156,7 +1156,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi """ Sets the value of :py:attr:`quantilesCol`. """ - self._paramMap[self.quantilesCol] = value + self._set(quantilesCol=value) return self @since("1.6.0") @@ -1305,7 +1305,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha """ Sets the value of :py:attr:`family`. """ - self._paramMap[self.family] = value + self._set(family=value) return self @since("2.0.0") @@ -1320,7 +1320,7 @@ class GeneralizedLinearRegression(JavaEstimator, HasLabelCol, HasFeaturesCol, Ha """ Sets the value of :py:attr:`link`. """ - self._paramMap[self.link] = value + self._set(link=value) return self @since("2.0.0") diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 456d79d897..5ac539edde 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -228,7 +228,7 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable): """ Sets the value of :py:attr:`numFolds`. """ - self._paramMap[self.numFolds] = value + self._set(numFolds=value) return self @since("1.4.0") @@ -479,7 +479,7 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable): """ Sets the value of :py:attr:`trainRatio`. 
""" - self._paramMap[self.trainRatio] = value + self._set(trainRatio=value) return self @since("2.0.0") diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index cd0e5b80d5..055a2816f8 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -112,7 +112,7 @@ class JavaParams(JavaWrapper, Params): java_param = self._java_obj.getParam(param.name) if self._java_obj.isDefined(java_param): value = _java2py(sc, self._java_obj.getOrDefault(java_param)) - self._paramMap[param] = value + self._set(**{param.name: value}) def _transfer_param_map_from_java(self, javaParamMap): """ |