aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/feature.py
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-03-23 11:20:44 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-23 11:20:44 -0700
commit30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2 (patch)
tree4d48b42ebe347fc40d5deeb3a77996db0c30eea1 /python/pyspark/ml/feature.py
parent48ee16d8012602c75d50aa2a85e26b7de3c48944 (diff)
downloadspark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.gz
spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.bz2
spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.zip
[SPARK-13068][PYSPARK][ML] Type conversion for Pyspark params
## What changes were proposed in this pull request? This patch adds type conversion functionality for parameters in Pyspark. A `typeConverter` field is added to the constructor of `Param` class. This argument is a function which converts values passed to this param to the appropriate type if possible. This is beneficial so that the params can fail at set time if they are given inappropriate values, but even more so because coherent error messages are now provided when Py4J cannot cast the python type to the appropriate Java type. This patch also adds a `TypeConverters` class with factory methods for common type conversions. Most of the changes involve adding these factory type converters to existing params. The previous solution to this issue, `expectedType`, is deprecated and can be removed in 2.1.0 as discussed on the Jira. ## How was this patch tested? Unit tests were added in python/pyspark/ml/tests.py to test parameter type conversion. These tests check that values that should be convertible are converted correctly, and that the appropriate errors are thrown when invalid values are provided. Author: sethah <seth.hendrickson16@gmail.com> Closes #11663 from sethah/SPARK-13068-tc.
Diffstat (limited to 'python/pyspark/ml/feature.py')
-rw-r--r--python/pyspark/ml/feature.py95
1 files changed, 55 insertions, 40 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 16cb9d1db3..86b53285b5 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -83,7 +83,8 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Java
"""
threshold = Param(Params._dummy(), "threshold",
- "threshold in binary classification prediction, in range [0, 1]")
+ "threshold in binary classification prediction, in range [0, 1]",
+ typeConverter=TypeConverters.toFloat)
@keyword_only
def __init__(self, threshold=0.0, inputCol=None, outputCol=None):
@@ -159,7 +160,8 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
"range [x,y) except the last bucket, which also includes y. The splits " +
"should be strictly increasing. Values at -inf, inf must be explicitly " +
"provided to cover all Double values; otherwise, values outside the splits " +
- "specified will be treated as errors.")
+ "specified will be treated as errors.",
+ typeConverter=TypeConverters.toListFloat)
@keyword_only
def __init__(self, splits=None, inputCol=None, outputCol=None):
@@ -243,15 +245,17 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
" threshold are ignored. If this is an integer >= 1, then this specifies a count (of" +
" times the term must appear in the document); if this is a double in [0,1), then this " +
"specifies a fraction (out of the document's token count). Note that the parameter is " +
- "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0")
+ "only used in transform of CountVectorizerModel and does not affect fitting. Default 1.0",
+ typeConverter=TypeConverters.toFloat)
minDF = Param(
Params._dummy(), "minDF", "Specifies the minimum number of" +
" different documents a term must appear in to be included in the vocabulary." +
" If this is an integer >= 1, this specifies the number of documents the term must" +
" appear in; if this is a double in [0,1), then this specifies the fraction of documents." +
- " Default 1.0")
+ " Default 1.0", typeConverter=TypeConverters.toFloat)
vocabSize = Param(
- Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.")
+ Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None):
@@ -375,7 +379,7 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWrit
"""
inverse = Param(Params._dummy(), "inverse", "Set transformer to perform inverse DCT, " +
- "default False.")
+ "default False.", typeConverter=TypeConverters.toBoolean)
@keyword_only
def __init__(self, inverse=False, inputCol=None, outputCol=None):
@@ -441,8 +445,8 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReada
.. versionadded:: 1.5.0
"""
- scalingVec = Param(Params._dummy(), "scalingVec", "vector for hadamard product, " +
- "it must be MLlib Vector type.")
+ scalingVec = Param(Params._dummy(), "scalingVec", "Vector for hadamard product.",
+ typeConverter=TypeConverters.toVector)
@keyword_only
def __init__(self, scalingVec=None, inputCol=None, outputCol=None):
@@ -564,7 +568,8 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab
"""
minDocFreq = Param(Params._dummy(), "minDocFreq",
- "minimum of documents in which a term should appear for filtering")
+ "minimum of documents in which a term should appear for filtering",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, minDocFreq=0, inputCol=None, outputCol=None):
@@ -746,8 +751,10 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Jav
.. versionadded:: 1.6.0
"""
- min = Param(Params._dummy(), "min", "Lower bound of the output feature range")
- max = Param(Params._dummy(), "max", "Upper bound of the output feature range")
+ min = Param(Params._dummy(), "min", "Lower bound of the output feature range",
+ typeConverter=TypeConverters.toFloat)
+ max = Param(Params._dummy(), "max", "Upper bound of the output feature range",
+ typeConverter=TypeConverters.toFloat)
@keyword_only
def __init__(self, min=0.0, max=1.0, inputCol=None, outputCol=None):
@@ -870,7 +877,8 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWr
.. versionadded:: 1.5.0
"""
- n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)")
+ n = Param(Params._dummy(), "n", "number of elements per n-gram (>=1)",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, n=2, inputCol=None, outputCol=None):
@@ -936,7 +944,8 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, Jav
.. versionadded:: 1.4.0
"""
- p = Param(Params._dummy(), "p", "the p norm value.")
+ p = Param(Params._dummy(), "p", "the p norm value.",
+ typeConverter=TypeConverters.toFloat)
@keyword_only
def __init__(self, p=2.0, inputCol=None, outputCol=None):
@@ -1018,7 +1027,8 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
.. versionadded:: 1.4.0
"""
- dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category")
+ dropLast = Param(Params._dummy(), "dropLast", "whether to drop the last category",
+ typeConverter=TypeConverters.toBoolean)
@keyword_only
def __init__(self, dropLast=True, inputCol=None, outputCol=None):
@@ -1085,7 +1095,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLRead
.. versionadded:: 1.4.0
"""
- degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)")
+ degree = Param(Params._dummy(), "degree", "the polynomial degree to expand (>= 1)",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, degree=2, inputCol=None, outputCol=None):
@@ -1163,7 +1174,8 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, Jav
# a placeholder to make it appear in the generated doc
numBuckets = Param(Params._dummy(), "numBuckets",
"Maximum number of buckets (quantiles, or " +
- "categories) into which data points are grouped. Must be >= 2. Default 2.")
+ "categories) into which data points are grouped. Must be >= 2. Default 2.",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, numBuckets=2, inputCol=None, outputCol=None, seed=None):
@@ -1255,11 +1267,13 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
.. versionadded:: 1.4.0
"""
- minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)")
+ minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)",
+ typeConverter=TypeConverters.toInt)
gaps = Param(Params._dummy(), "gaps", "whether regex splits on gaps (True) or matches tokens")
- pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing")
+ pattern = Param(Params._dummy(), "pattern", "regex pattern (Java dialect) used for tokenizing",
+ TypeConverters.toString)
toLowercase = Param(Params._dummy(), "toLowercase", "whether to convert all characters to " +
- "lowercase before tokenizing")
+ "lowercase before tokenizing", TypeConverters.toBoolean)
@keyword_only
def __init__(self, minTokenLength=1, gaps=True, pattern="\\s+", inputCol=None,
@@ -1370,7 +1384,7 @@ class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable):
.. versionadded:: 1.6.0
"""
- statement = Param(Params._dummy(), "statement", "SQL statement")
+ statement = Param(Params._dummy(), "statement", "SQL statement", TypeConverters.toString)
@keyword_only
def __init__(self, statement=None):
@@ -1444,8 +1458,9 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, J
.. versionadded:: 1.4.0
"""
- withMean = Param(Params._dummy(), "withMean", "Center data with mean")
- withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation")
+ withMean = Param(Params._dummy(), "withMean", "Center data with mean", TypeConverters.toBoolean)
+ withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation",
+ TypeConverters.toBoolean)
@keyword_only
def __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None):
@@ -1628,7 +1643,8 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
labels = Param(Params._dummy(), "labels",
"Optional array of labels specifying index-string mapping." +
- " If not provided or if empty, then metadata from inputCol is used instead.")
+ " If not provided or if empty, then metadata from inputCol is used instead.",
+ typeConverter=TypeConverters.toListString)
@keyword_only
def __init__(self, inputCol=None, outputCol=None, labels=None):
@@ -1689,9 +1705,10 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadabl
.. versionadded:: 1.6.0
"""
- stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out")
+ stopWords = Param(Params._dummy(), "stopWords", "The words to be filtered out",
+ typeConverter=TypeConverters.toListString)
caseSensitive = Param(Params._dummy(), "caseSensitive", "whether to do a case sensitive " +
- "comparison over the stop words")
+ "comparison over the stop words", TypeConverters.toBoolean)
@keyword_only
def __init__(self, inputCol=None, outputCol=None, stopWords=None,
@@ -1930,7 +1947,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, Ja
maxCategories = Param(Params._dummy(), "maxCategories",
"Threshold for the number of values a categorical feature can take " +
"(>= 2). If a feature is found to have > maxCategories values, then " +
- "it is declared continuous.")
+ "it is declared continuous.", typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, maxCategories=20, inputCol=None, outputCol=None):
@@ -2035,11 +2052,12 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, J
"""
indices = Param(Params._dummy(), "indices", "An array of indices to select features from " +
- "a vector column. There can be no overlap with names.")
+ "a vector column. There can be no overlap with names.",
+ typeConverter=TypeConverters.toListInt)
names = Param(Params._dummy(), "names", "An array of feature names to select features from " +
"a vector column. These names must be specified by ML " +
"org.apache.spark.ml.attribute.Attribute. There can be no overlap with " +
- "indices.")
+ "indices.", typeConverter=TypeConverters.toListString)
@keyword_only
def __init__(self, inputCol=None, outputCol=None, indices=None, names=None):
@@ -2147,12 +2165,14 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
"""
vectorSize = Param(Params._dummy(), "vectorSize",
- "the dimension of codes after transforming from words")
+ "the dimension of codes after transforming from words",
+ typeConverter=TypeConverters.toInt)
numPartitions = Param(Params._dummy(), "numPartitions",
- "number of partitions for sentences of words")
+ "number of partitions for sentences of words",
+ typeConverter=TypeConverters.toInt)
minCount = Param(Params._dummy(), "minCount",
"the minimum number of times a token must appear to be included in the " +
- "word2vec model's vocabulary")
+ "word2vec model's vocabulary", typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
@@ -2293,7 +2313,8 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritab
.. versionadded:: 1.5.0
"""
- k = Param(Params._dummy(), "k", "the number of principal components")
+ k = Param(Params._dummy(), "k", "the number of principal components",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, k=None, inputCol=None, outputCol=None):
@@ -2425,7 +2446,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaM
.. versionadded:: 1.5.0
"""
- formula = Param(Params._dummy(), "formula", "R model formula")
+ formula = Param(Params._dummy(), "formula", "R model formula", TypeConverters.toString)
@keyword_only
def __init__(self, formula=None, featuresCol="features", labelCol="label"):
@@ -2511,12 +2532,11 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
.. versionadded:: 2.0.0
"""
- # a placeholder to make it appear in the generated doc
numTopFeatures = \
Param(Params._dummy(), "numTopFeatures",
"Number of features that selector will select, ordered by statistics value " +
"descending. If the number of features is < numTopFeatures, then this will select " +
- "all features.")
+ "all features.", typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, numTopFeatures=50, featuresCol="features", outputCol=None, labelCol="label"):
@@ -2525,11 +2545,6 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, Ja
"""
super(ChiSqSelector, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.ChiSqSelector", self.uid)
- self.numTopFeatures = \
- Param(self, "numTopFeatures",
- "Number of features that selector will select, ordered by statistics value " +
- "descending. If the number of features is < numTopFeatures, then this will " +
- "select all features.")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)