From 129f2f455da982ec9fab593299fa4021b62827eb Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 15 Apr 2016 12:14:41 -0700 Subject: [SPARK-14104][PYSPARK][ML] All Python param setters should use the `_set` method ## What changes were proposed in this pull request? Param setters in Python previously accessed `_paramMap` directly to update values. The `_set` method now implements type checking, so it should be used to update all parameters. This PR eliminates all direct accesses to `_paramMap` besides the one in the `_set` method to ensure type checking happens. Additional changes: * [SPARK-13068](https://github.com/apache/spark/pull/11663) missed adding type converters in evaluation.py, so those are added here. * An incorrect `toBoolean` type converter was used for StringIndexer `handleInvalid` param in the previous PR. This is fixed here. ## How was this patch tested? Existing unit tests verify that parameters are still set properly. No new functionality is actually added in this PR. Author: sethah Closes #11939 from sethah/SPARK-14104. --- python/pyspark/ml/feature.py | 72 +++++++++++++++++++++++--------------------- 1 file changed, 37 insertions(+), 35 deletions(-) (limited to 'python/pyspark/ml/feature.py') diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 0d8ef1297f..776906eaab 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -19,6 +19,8 @@ import sys if sys.version > '3': basestring = str +from py4j.java_collections import JavaArray + from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * @@ -112,7 +114,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java """ Sets the value of :py:attr:`threshold`. 
""" - self._paramMap[self.threshold] = value + self._set(threshold=value) return self @since("1.4.0") @@ -188,7 +190,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav """ Sets the value of :py:attr:`splits`. """ - self._paramMap[self.splits] = value + self._set(splits=value) return self @since("1.4.0") @@ -293,7 +295,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`minTF`. """ - self._paramMap[self.minTF] = value + self._set(minTF=value) return self @since("1.6.0") @@ -308,7 +310,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`minDF`. """ - self._paramMap[self.minDF] = value + self._set(minDF=value) return self @since("1.6.0") @@ -323,7 +325,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`vocabSize`. """ - self._paramMap[self.vocabSize] = value + self._set(vocabSize=value) return self @since("1.6.0") @@ -431,7 +433,7 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit """ Sets the value of :py:attr:`inverse`. """ - self._paramMap[self.inverse] = value + self._set(inverse=value) return self @since("1.6.0") @@ -498,7 +500,7 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada """ Sets the value of :py:attr:`scalingVec`. """ - self._paramMap[self.scalingVec] = value + self._set(scalingVec=value) return self @since("1.5.0") @@ -641,7 +643,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab """ Sets the value of :py:attr:`minDocFreq`. """ - self._paramMap[self.minDocFreq] = value + self._set(minDocFreq=value) return self @since("1.4.0") @@ -826,7 +828,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav """ Sets the value of :py:attr:`min`. 
""" - self._paramMap[self.min] = value + self._set(min=value) return self @since("1.6.0") @@ -841,7 +843,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav """ Sets the value of :py:attr:`max`. """ - self._paramMap[self.max] = value + self._set(max=value) return self @since("1.6.0") @@ -950,7 +952,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr """ Sets the value of :py:attr:`n`. """ - self._paramMap[self.n] = value + self._set(n=value) return self @since("1.5.0") @@ -1017,7 +1019,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav """ Sets the value of :py:attr:`p`. """ - self._paramMap[self.p] = value + self._set(p=value) return self @since("1.4.0") @@ -1100,7 +1102,7 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`dropLast`. """ - self._paramMap[self.dropLast] = value + self._set(dropLast=value) return self @since("1.4.0") @@ -1169,7 +1171,7 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead """ Sets the value of :py:attr:`degree`. """ - self._paramMap[self.degree] = value + self._set(degree=value) return self @since("1.4.0") @@ -1251,7 +1253,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, Jav """ Sets the value of :py:attr:`numBuckets`. """ - self._paramMap[self.numBuckets] = value + self._set(numBuckets=value) return self @since("2.0.0") @@ -1349,7 +1351,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`minTokenLength`. """ - self._paramMap[self.minTokenLength] = value + self._set(minTokenLength=value) return self @since("1.4.0") @@ -1364,7 +1366,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`gaps`. 
""" - self._paramMap[self.gaps] = value + self._set(gaps=value) return self @since("1.4.0") @@ -1379,7 +1381,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`pattern`. """ - self._paramMap[self.pattern] = value + self._set(pattern=value) return self @since("1.4.0") @@ -1394,7 +1396,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`toLowercase`. """ - self._paramMap[self.toLowercase] = value + self._set(toLowercase=value) return self @since("2.0.0") @@ -1455,7 +1457,7 @@ class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable): """ Sets the value of :py:attr:`statement`. """ - self._paramMap[self.statement] = value + self._set(statement=value) return self @since("1.6.0") @@ -1532,7 +1534,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, J """ Sets the value of :py:attr:`withMean`. """ - self._paramMap[self.withMean] = value + self._set(withMean=value) return self @since("1.4.0") @@ -1547,7 +1549,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, J """ Sets the value of :py:attr:`withStd`. """ - self._paramMap[self.withStd] = value + self._set(withStd=value) return self @since("1.4.0") @@ -1598,7 +1600,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, The indices are in [0, numLabels), ordered by label frequencies. So the most frequent label gets index 0. 
- >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed") + >>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed", handleInvalid='error') >>> model = stringIndexer.fit(stringIndDf) >>> td = model.transform(stringIndDf) >>> sorted(set([(i[0], i[1]) for i in td.select(td.id, td.indexed).collect()]), @@ -1716,7 +1718,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, """ Sets the value of :py:attr:`labels`. """ - self._paramMap[self.labels] = value + self._set(labels=value) return self @since("1.6.0") @@ -1787,7 +1789,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl """ Specify the stopwords to be filtered. """ - self._paramMap[self.stopWords] = value + self._set(stopWords=value) return self @since("1.6.0") @@ -1802,7 +1804,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl """ Set whether to do a case sensitive comparison over the stop words """ - self._paramMap[self.caseSensitive] = value + self._set(caseSensitive=value) return self @since("1.6.0") @@ -2019,7 +2021,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja """ Sets the value of :py:attr:`maxCategories`. """ - self._paramMap[self.maxCategories] = value + self._set(maxCategories=value) return self @since("1.4.0") @@ -2129,7 +2131,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J """ Sets the value of :py:attr:`indices`. """ - self._paramMap[self.indices] = value + self._set(indices=value) return self @since("1.6.0") @@ -2144,7 +2146,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J """ Sets the value of :py:attr:`names`. """ - self._paramMap[self.names] = value + self._set(names=value) return self @since("1.6.0") @@ -2249,7 +2251,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has """ Sets the value of :py:attr:`vectorSize`. 
""" - self._paramMap[self.vectorSize] = value + self._set(vectorSize=value) return self @since("1.4.0") @@ -2264,7 +2266,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has """ Sets the value of :py:attr:`numPartitions`. """ - self._paramMap[self.numPartitions] = value + self._set(numPartitions=value) return self @since("1.4.0") @@ -2279,7 +2281,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has """ Sets the value of :py:attr:`minCount`. """ - self._paramMap[self.minCount] = value + self._set(minCount=value) return self @since("1.4.0") @@ -2385,7 +2387,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab """ Sets the value of :py:attr:`k`. """ - self._paramMap[self.k] = value + self._set(k=value) return self @since("1.5.0") @@ -2517,7 +2519,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM """ Sets the value of :py:attr:`formula`. """ - self._paramMap[self.formula] = value + self._set(formula=value) return self @since("1.5.0") @@ -2609,7 +2611,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja """ Sets the value of :py:attr:`numTopFeatures`. """ - self._paramMap[self.numTopFeatures] = value + self._set(numTopFeatures=value) return self @since("2.0.0") -- cgit v1.2.3