author     Xiangrui Meng <meng@databricks.com>    2015-05-18 12:02:18 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-05-18 12:02:18 -0700
commit     9c7e802a5a2b8cd3eb77642f84c54a8e976fc996 (patch)
tree       2e3b7e367f57b64ef46733ee8b64aa258e58cca8 /python/pyspark/ml/feature.py
parent     56ede88485cfca90974425fcb603b257be47229b (diff)
[SPARK-7380] [MLLIB] pipeline stages should be copyable in Python
This PR makes pipeline stages in Python copyable, which simplifies several of their implementations. It also includes the following changes (a usage sketch follows the list):
1. Rename `paramMap` and `defaultParamMap` to `_paramMap` and `_defaultParamMap`, respectively.
2. Accept a list of param maps in `fit`.
3. Use the parent's uid and the param name to identify a param.
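A hedged sketch of how the three changes fit together (not part of the patch itself; the Binarizer column names and threshold values are invented for illustration, and an active SparkContext is assumed, since stage constructors now create their Java peer eagerly):

    from pyspark.ml.feature import Binarizer

    bin1 = Binarizer(threshold=0.5, inputCol="values", outputCol="features")

    # Stages are now copyable; copy() takes an optional extra param map.
    bin2 = bin1.copy({bin1.threshold: 1.0})

    assert bin1.getThreshold() == 0.5  # the original's _paramMap is untouched
    assert bin2.getThreshold() == 1.0

    # fit() now also accepts a list of param maps, returning one model per map:
    #   models = estimator.fit(df, [{e.someParam: v1}, {e.someParam: v2}])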
jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #6088 from mengxr/SPARK-7380 and squashes the following commits:
413c463 [Xiangrui Meng] remove unnecessary doc
4159f35 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
611c719 [Xiangrui Meng] fix python style
68862b8 [Xiangrui Meng] update _java_obj initialization
927ad19 [Xiangrui Meng] fix ml/tests.py
0138fc3 [Xiangrui Meng] update feature transformers and fix a bug in RegexTokenizer
9ca44fb [Xiangrui Meng] simplify Java wrappers and add tests
c7d84ef [Xiangrui Meng] update ml/tests.py to test copy params
7e0d27f [Xiangrui Meng] merge master
46840fb [Xiangrui Meng] update wrappers
b6db1ed [Xiangrui Meng] update all self.paramMap to self._paramMap
46cb6ed [Xiangrui Meng] merge master
a163413 [Xiangrui Meng] fix style
1042e80 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into SPARK-7380
9630eae [Xiangrui Meng] fix Identifiable._randomUID
13bd70a [Xiangrui Meng] update ml/tests.py
64a536c [Xiangrui Meng] use _fit/_transform/_evaluate to simplify the impl
02abf13 [Xiangrui Meng] Merge remote-tracking branch 'apache/master' into copyable-python
66ce18c [Joseph K. Bradley] some cleanups before sending to Xiangrui
7431272 [Joseph K. Bradley] Rebased with master
Diffstat (limited to 'python/pyspark/ml/feature.py')
-rw-r--r--  python/pyspark/ml/feature.py | 91
1 file changed, 52 insertions(+), 39 deletions(-)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 58e22190c7..c8115cb5bc 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -43,7 +43,6 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
     1.0
     """

-    _java_class = "org.apache.spark.ml.feature.Binarizer"
     # a placeholder to make it appear in the generated doc
     threshold = Param(Params._dummy(), "threshold",
                       "threshold in binary classification prediction, in range [0, 1]")
@@ -54,6 +53,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
         __init__(self, threshold=0.0, inputCol=None, outputCol=None)
         """
         super(Binarizer, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Binarizer", self.uid)
         self.threshold = Param(self, "threshold",
                                "threshold in binary classification prediction, in range [0, 1]")
         self._setDefault(threshold=0.0)
@@ -73,7 +73,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`threshold`.
         """
-        self.paramMap[self.threshold] = value
+        self._paramMap[self.threshold] = value
         return self

     def getThreshold(self):
@@ -104,7 +104,6 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
     0.0
     """

-    _java_class = "org.apache.spark.ml.feature.Bucketizer"
     # a placeholder to make it appear in the generated doc
     splits = \
         Param(Params._dummy(), "splits",
@@ -121,6 +120,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
         __init__(self, splits=None, inputCol=None, outputCol=None)
         """
         super(Bucketizer, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Bucketizer", self.uid)
         #: param for Splitting points for mapping continuous features into buckets. With n+1 splits,
         #  there are n buckets. A bucket defined by splits x,y holds values in the range [x,y)
         #  except the last bucket, which also includes y. The splits should be strictly increasing.
@@ -150,7 +150,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`splits`.
         """
-        self.paramMap[self.splits] = value
+        self._paramMap[self.splits] = value
         return self

     def getSplits(self):
@@ -177,14 +177,13 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
     SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0})
     """

-    _java_class = "org.apache.spark.ml.feature.HashingTF"
-
     @keyword_only
     def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
         """
         __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
         """
         super(HashingTF, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid)
         self._setDefault(numFeatures=1 << 18)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)
@@ -217,8 +216,6 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
     DenseVector([0.2877, 0.0])
     """

-    _java_class = "org.apache.spark.ml.feature.IDF"
-
     # a placeholder to make it appear in the generated doc
     minDocFreq = Param(Params._dummy(), "minDocFreq",
                        "minimum of documents in which a term should appear for filtering")
@@ -229,6 +226,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
         __init__(self, minDocFreq=0, inputCol=None, outputCol=None)
         """
         super(IDF, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.IDF", self.uid)
         self.minDocFreq = Param(self, "minDocFreq",
                                 "minimum of documents in which a term should appear for filtering")
         self._setDefault(minDocFreq=0)
@@ -248,7 +246,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`minDocFreq`.
         """
-        self.paramMap[self.minDocFreq] = value
+        self._paramMap[self.minDocFreq] = value
         return self

     def getMinDocFreq(self):
@@ -257,6 +255,9 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
         """
         return self.getOrDefault(self.minDocFreq)

+    def _create_model(self, java_model):
+        return IDFModel(java_model)
+

 class IDFModel(JavaModel):
     """
@@ -285,14 +286,13 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
     # a placeholder to make it appear in the generated doc
     p = Param(Params._dummy(), "p", "the p norm value.")

-    _java_class = "org.apache.spark.ml.feature.Normalizer"
-
     @keyword_only
     def __init__(self, p=2.0, inputCol=None, outputCol=None):
         """
         __init__(self, p=2.0, inputCol=None, outputCol=None)
         """
         super(Normalizer, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Normalizer", self.uid)
         self.p = Param(self, "p", "the p norm value.")
         self._setDefault(p=2.0)
         kwargs = self.__init__._input_kwargs
@@ -311,7 +311,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`p`.
         """
-        self.paramMap[self.p] = value
+        self._paramMap[self.p] = value
         return self

     def getP(self):
@@ -347,8 +347,6 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
     SparseVector(3, {0: 1.0})
     """

-    _java_class = "org.apache.spark.ml.feature.OneHotEncoder"
-
     # a placeholder to make it appear in the generated doc
     includeFirst = Param(Params._dummy(), "includeFirst",
                          "include first category")
@@ -358,6 +356,7 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
         __init__(self, includeFirst=True, inputCol=None, outputCol=None)
         """
         super(OneHotEncoder, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.OneHotEncoder", self.uid)
         self.includeFirst = Param(self, "includeFirst", "include first category")
         self._setDefault(includeFirst=True)
         kwargs = self.__init__._input_kwargs
@@ -376,7 +375,7 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`includeFirst`.
         """
-        self.paramMap[self.includeFirst] = value
+        self._paramMap[self.includeFirst] = value
         return self

     def getIncludeFirst(self):
@@ -404,8 +403,6 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
     DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
     """

-    _java_class = "org.apache.spark.ml.feature.PolynomialExpansion"
-
     # a placeholder to make it appear in the generated doc
     degree = Param(Params._dummy(), "degree",
                    "the polynomial degree to expand (>= 1)")
@@ -415,6 +412,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
         __init__(self, degree=2, inputCol=None, outputCol=None)
         """
         super(PolynomialExpansion, self).__init__()
+        self._java_obj = self._new_java_obj(
+            "org.apache.spark.ml.feature.PolynomialExpansion", self.uid)
         self.degree = Param(self, "degree", "the polynomial degree to expand (>= 1)")
         self._setDefault(degree=2)
         kwargs = self.__init__._input_kwargs
@@ -433,7 +432,7 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`degree`.
         """
-        self.paramMap[self.degree] = value
+        self._paramMap[self.degree] = value
         return self

     def getDegree(self):
@@ -471,7 +470,6 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
     TypeError: Method setParams forces keyword arguments.
     """

-    _java_class = "org.apache.spark.ml.feature.RegexTokenizer"
     # a placeholder to make it appear in the generated doc
     minTokenLength = Param(Params._dummy(), "minTokenLength", "minimum token length (>= 0)")
     gaps = Param(Params._dummy(), "gaps", "Set regex to match gaps or tokens")
@@ -485,7 +483,8 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
                  inputCol=None, outputCol=None)
         """
         super(RegexTokenizer, self).__init__()
-        self.minTokenLength = Param(self, "minLength", "minimum token length (>= 0)")
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.RegexTokenizer", self.uid)
+        self.minTokenLength = Param(self, "minTokenLength", "minimum token length (>= 0)")
         self.gaps = Param(self, "gaps", "Set regex to match gaps or tokens")
         self.pattern = Param(self, "pattern", "regex pattern used for tokenizing")
         self._setDefault(minTokenLength=1, gaps=False, pattern="\\p{L}+|[^\\p{L}\\s]+")
@@ -507,7 +506,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`minTokenLength`.
         """
-        self.paramMap[self.minTokenLength] = value
+        self._paramMap[self.minTokenLength] = value
         return self

     def getMinTokenLength(self):
@@ -520,7 +519,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`gaps`.
         """
-        self.paramMap[self.gaps] = value
+        self._paramMap[self.gaps] = value
         return self

     def getGaps(self):
@@ -533,7 +532,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`pattern`.
         """
-        self.paramMap[self.pattern] = value
+        self._paramMap[self.pattern] = value
         return self

     def getPattern(self):
@@ -557,8 +556,6 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
     DenseVector([1.4142])
     """

-    _java_class = "org.apache.spark.ml.feature.StandardScaler"
-
     # a placeholder to make it appear in the generated doc
     withMean = Param(Params._dummy(), "withMean", "Center data with mean")
     withStd = Param(Params._dummy(), "withStd", "Scale to unit standard deviation")
@@ -569,6 +566,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
         __init__(self, withMean=False, withStd=True, inputCol=None, outputCol=None)
         """
         super(StandardScaler, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StandardScaler", self.uid)
         self.withMean = Param(self, "withMean", "Center data with mean")
         self.withStd = Param(self, "withStd", "Scale to unit standard deviation")
         self._setDefault(withMean=False, withStd=True)
@@ -588,7 +586,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`withMean`.
         """
-        self.paramMap[self.withMean] = value
+        self._paramMap[self.withMean] = value
         return self

     def getWithMean(self):
@@ -601,7 +599,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`withStd`.
         """
-        self.paramMap[self.withStd] = value
+        self._paramMap[self.withStd] = value
         return self

     def getWithStd(self):
@@ -610,6 +608,9 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
         """
         return self.getOrDefault(self.withStd)

+    def _create_model(self, java_model):
+        return StandardScalerModel(java_model)
+

 class StandardScalerModel(JavaModel):
     """
@@ -633,14 +634,13 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
     [(0, 0.0), (1, 2.0), (2, 1.0), (3, 0.0), (4, 0.0), (5, 1.0)]
     """

-    _java_class = "org.apache.spark.ml.feature.StringIndexer"
-
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None):
         """
         __init__(self, inputCol=None, outputCol=None)
         """
         super(StringIndexer, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.StringIndexer", self.uid)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)

@@ -653,6 +653,9 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol):
         kwargs = self.setParams._input_kwargs
         return self._set(**kwargs)

+    def _create_model(self, java_model):
+        return StringIndexerModel(java_model)
+

 class StringIndexerModel(JavaModel):
     """
@@ -686,14 +689,13 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
     TypeError: Method setParams forces keyword arguments.
     """

-    _java_class = "org.apache.spark.ml.feature.Tokenizer"
-
     @keyword_only
     def __init__(self, inputCol=None, outputCol=None):
         """
         __init__(self, inputCol=None, outputCol=None)
         """
         super(Tokenizer, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Tokenizer", self.uid)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)

@@ -723,14 +725,13 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
     DenseVector([0.0, 1.0])
     """

-    _java_class = "org.apache.spark.ml.feature.VectorAssembler"
-
     @keyword_only
     def __init__(self, inputCols=None, outputCol=None):
         """
         __init__(self, inputCols=None, outputCol=None)
         """
         super(VectorAssembler, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorAssembler", self.uid)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)

@@ -797,7 +798,6 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
     DenseVector([1.0, 0.0])
     """

-    _java_class = "org.apache.spark.ml.feature.VectorIndexer"
     # a placeholder to make it appear in the generated doc
     maxCategories = Param(Params._dummy(), "maxCategories",
                           "Threshold for the number of values a categorical feature can take " +
@@ -810,6 +810,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
         __init__(self, maxCategories=20, inputCol=None, outputCol=None)
         """
         super(VectorIndexer, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorIndexer", self.uid)
         self.maxCategories = Param(self, "maxCategories",
                                    "Threshold for the number of values a categorical feature " +
                                    "can take (>= 2). If a feature is found to have " +
@@ -831,7 +832,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
         """
         Sets the value of :py:attr:`maxCategories`.
         """
-        self.paramMap[self.maxCategories] = value
+        self._paramMap[self.maxCategories] = value
         return self

     def getMaxCategories(self):
@@ -840,6 +841,15 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
         """
         return self.getOrDefault(self.maxCategories)

+    def _create_model(self, java_model):
+        return VectorIndexerModel(java_model)
+
+
+class VectorIndexerModel(JavaModel):
+    """
+    Model fitted by VectorIndexer.
+    """
+

 @inherit_doc
 @ignore_unicode_prefix
@@ -855,7 +865,6 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
     DenseVector([-0.0422, -0.5138, -0.2546, 0.6885, 0.276])
     """

-    _java_class = "org.apache.spark.ml.feature.Word2Vec"
     # a placeholder to make it appear in the generated doc
     vectorSize = Param(Params._dummy(), "vectorSize",
                        "the dimension of codes after transforming from words")
@@ -873,6 +882,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
                  seed=42, inputCol=None, outputCol=None)
         """
         super(Word2Vec, self).__init__()
+        self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid)
         self.vectorSize = Param(self, "vectorSize",
                                 "the dimension of codes after transforming from words")
         self.numPartitions = Param(self, "numPartitions",
@@ -900,7 +910,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
         """
         Sets the value of :py:attr:`vectorSize`.
         """
-        self.paramMap[self.vectorSize] = value
+        self._paramMap[self.vectorSize] = value
         return self

     def getVectorSize(self):
@@ -913,7 +923,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
         """
         Sets the value of :py:attr:`numPartitions`.
         """
-        self.paramMap[self.numPartitions] = value
+        self._paramMap[self.numPartitions] = value
         return self

     def getNumPartitions(self):
@@ -926,7 +936,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
         """
         Sets the value of :py:attr:`minCount`.
         """
-        self.paramMap[self.minCount] = value
+        self._paramMap[self.minCount] = value
         return self

     def getMinCount(self):
@@ -935,6 +945,9 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
         """
         return self.getOrDefault(self.minCount)

+    def _create_model(self, java_model):
+        return Word2VecModel(java_model)
+

 class Word2VecModel(JavaModel):
     """