author     Xusen Yin <yinxusen@gmail.com>       2016-03-04 08:32:24 -0800
committer  Xiangrui Meng <meng@databricks.com>  2016-03-04 08:32:24 -0800
commit     83302c3bff13bd7734426c81d9c83bf4beb211c9 (patch)
tree       44f447454524f6ca3c3e0c8e1b3ec75877222620 /python
parent     c8f25459ed4ad6b51a5f11665364cfe0b84f7b3c (diff)
[SPARK-13036][SPARK-13318][SPARK-13319] Add save/load for feature.py
Add save/load for feature.py. Also add save/load for `ElementwiseProduct` on the Scala side, and fix a bug where `setDefault` was missing in `VectorSlicer` and `StopWordsRemover`. In this PR I skip `RFormula` and `RFormulaModel` because their Scala implementation is pending in https://github.com/apache/spark/pull/9884; I'll add them in this PR if that one gets merged first, or file a follow-up JIRA for `RFormula`.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #11203 from yinxusen/SPARK-13036.
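Every class touched here gains the same round trip: an instance-level `save(path)` and a class-level `load(path)`, supplied by mixing in `MLReadable`/`MLWritable` from `pyspark.ml.util`. A minimal sketch of the pattern, assuming a live SparkContext and a writable `temp_path` as set up by the doctest harness at the bottom of the diff (the directory name is illustrative):

>>> from pyspark.ml.feature import Binarizer
>>> binarizer = Binarizer(threshold=0.5, inputCol="values", outputCol="features")
>>> examplePath = temp_path + "/example-binarizer"  # hypothetical subdirectory of the temp dir
>>> binarizer.save(examplePath)                     # persists the params via the Scala writer
>>> loaded = Binarizer.load(examplePath)            # rebuilds an equivalent instance
>>> loaded.getThreshold() == binarizer.getThreshold()
True

Fitted models round-trip the same way through their model classes, e.g. `CountVectorizerModel.load(path)`, as exercised by the doctests added below.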
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/feature.py | 341
1 file changed, 296 insertions(+), 45 deletions(-)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index fb31c7310c..5025493c42 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -22,7 +22,7 @@ if sys.version > '3':
from pyspark import since
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import *
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import keyword_only, MLReadable, MLWritable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
from pyspark.mllib.common import inherit_doc
from pyspark.mllib.linalg import _convert_to_vector
@@ -58,7 +58,7 @@ __all__ = ['Binarizer',
@inherit_doc
-class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
+class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -73,6 +73,11 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
>>> params = {binarizer.threshold: -0.5, binarizer.outputCol: "vector"}
>>> binarizer.transform(df, params).head().vector
1.0
+ >>> binarizerPath = temp_path + "/binarizer"
+ >>> binarizer.save(binarizerPath)
+ >>> loadedBinarizer = Binarizer.load(binarizerPath)
+ >>> loadedBinarizer.getThreshold() == binarizer.getThreshold()
+ True
.. versionadded:: 1.4.0
"""
@@ -118,7 +123,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
-class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
+class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -138,6 +143,11 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
2.0
>>> bucketizer.setParams(outputCol="b").transform(df).head().b
0.0
+ >>> bucketizerPath = temp_path + "/bucketizer"
+ >>> bucketizer.save(bucketizerPath)
+ >>> loadedBucketizer = Bucketizer.load(bucketizerPath)
+ >>> loadedBucketizer.getSplits() == bucketizer.getSplits()
+ True
.. versionadded:: 1.3.0
"""
@@ -188,7 +198,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
-class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol):
+class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -207,8 +217,22 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol):
|1    |[a, b, b, c, a]|(3,[0,1,2],[2.0,2.0,1.0])|
+-----+---------------+-------------------------+
...
- >>> sorted(map(str, model.vocabulary))
- ['a', 'b', 'c']
+ >>> sorted(model.vocabulary) == ['a', 'b', 'c']
+ True
+ >>> countVectorizerPath = temp_path + "/count-vectorizer"
+ >>> cv.save(countVectorizerPath)
+ >>> loadedCv = CountVectorizer.load(countVectorizerPath)
+ >>> loadedCv.getMinDF() == cv.getMinDF()
+ True
+ >>> loadedCv.getMinTF() == cv.getMinTF()
+ True
+ >>> loadedCv.getVocabSize() == cv.getVocabSize()
+ True
+ >>> modelPath = temp_path + "/count-vectorizer-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = CountVectorizerModel.load(modelPath)
+ >>> loadedModel.vocabulary == model.vocabulary
+ True
.. versionadded:: 1.6.0
"""
@@ -300,7 +324,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol):
return CountVectorizerModel(java_model)
-class CountVectorizerModel(JavaModel):
+class CountVectorizerModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -319,7 +343,7 @@ class CountVectorizerModel(JavaModel):
@inherit_doc
-class DCT(JavaTransformer, HasInputCol, HasOutputCol):
+class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -341,6 +365,11 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol):
>>> df3 = DCT(inverse=True, inputCol="resultVec", outputCol="origVec").transform(df2)
>>> df3.head().origVec
DenseVector([5.0, 8.0, 6.0])
+ >>> dctPath = temp_path + "/dct"
+ >>> dct.save(dctPath)
+ >>> loadedDtc = DCT.load(dctPath)
+ >>> loadedDtc.getInverse()
+ False
.. versionadded:: 1.6.0
"""
@@ -386,7 +415,7 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
-class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol):
+class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -402,6 +431,11 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol):
DenseVector([2.0, 2.0, 9.0])
>>> ep.setParams(scalingVec=Vectors.dense([2.0, 3.0, 5.0])).transform(df).head().eprod
DenseVector([4.0, 3.0, 15.0])
+ >>> elementwiseProductPath = temp_path + "/elementwise-product"
+ >>> ep.save(elementwiseProductPath)
+ >>> loadedEp = ElementwiseProduct.load(elementwiseProductPath)
+ >>> loadedEp.getScalingVec() == ep.getScalingVec()
+ True
.. versionadded:: 1.5.0
"""
@@ -447,7 +481,7 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
-class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
+class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -463,6 +497,11 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
>>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
>>> hashingTF.transform(df, params).head().vector
SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0})
+ >>> hashingTFPath = temp_path + "/hashing-tf"
+ >>> hashingTF.save(hashingTFPath)
+ >>> loadedHashingTF = HashingTF.load(hashingTFPath)
+ >>> loadedHashingTF.getNumFeatures() == hashingTF.getNumFeatures()
+ True
.. versionadded:: 1.3.0
"""
@@ -490,7 +529,7 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
@inherit_doc
-class IDF(JavaEstimator, HasInputCol, HasOutputCol):
+class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -500,13 +539,24 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
>>> df = sqlContext.createDataFrame([(DenseVector([1.0, 2.0]),),
... (DenseVector([0.0, 1.0]),), (DenseVector([3.0, 0.2]),)], ["tf"])
>>> idf = IDF(minDocFreq=3, inputCol="tf", outputCol="idf")
- >>> idf.fit(df).transform(df).head().idf
+ >>> model = idf.fit(df)
+ >>> model.transform(df).head().idf
DenseVector([0.0, 0.0])
>>> idf.setParams(outputCol="freqs").fit(df).transform(df).collect()[1].freqs
DenseVector([0.0, 0.0])
>>> params = {idf.minDocFreq: 1, idf.outputCol: "vector"}
>>> idf.fit(df, params).transform(df).head().vector
DenseVector([0.2877, 0.0])
+ >>> idfPath = temp_path + "/idf"
+ >>> idf.save(idfPath)
+ >>> loadedIdf = IDF.load(idfPath)
+ >>> loadedIdf.getMinDocFreq() == idf.getMinDocFreq()
+ True
+ >>> modelPath = temp_path + "/idf-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = IDFModel.load(modelPath)
+ >>> loadedModel.transform(df).head().idf == model.transform(df).head().idf
+ True
.. versionadded:: 1.4.0
"""
@@ -554,7 +604,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol):
return IDFModel(java_model)
-class IDFModel(JavaModel):
+class IDFModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -565,7 +615,7 @@ class IDFModel(JavaModel):
@inherit_doc
-class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol):
+class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -585,6 +635,18 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol):
|[2.0]| [1.0]|
+-----+------+
...
+ >>> scalerPath = temp_path + "/max-abs-scaler"
+ >>> maScaler.save(scalerPath)
+ >>> loadedMAScaler = MaxAbsScaler.load(scalerPath)
+ >>> loadedMAScaler.getInputCol() == maScaler.getInputCol()
+ True
+ >>> loadedMAScaler.getOutputCol() == maScaler.getOutputCol()
+ True
+ >>> modelPath = temp_path + "/max-abs-scaler-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = MaxAbsScalerModel.load(modelPath)
+ >>> loadedModel.maxAbs == model.maxAbs
+ True
.. versionadded:: 2.0.0
"""
@@ -614,7 +676,7 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol):
return MaxAbsScalerModel(java_model)
-class MaxAbsScalerModel(JavaModel):
+class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -623,9 +685,17 @@ class MaxAbsScalerModel(JavaModel):
.. versionadded:: 2.0.0
"""
+ @property
+ @since("2.0.0")
+ def maxAbs(self):
+ """
+ Max Abs vector.
+ """
+ return self._call_java("maxAbs")
+
@inherit_doc
-class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol):
+class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -656,6 +726,20 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol):
|[2.0]| [1.0]|
+-----+------+
...
+ >>> minMaxScalerPath = temp_path + "/min-max-scaler"
+ >>> mmScaler.save(minMaxScalerPath)
+ >>> loadedMMScaler = MinMaxScaler.load(minMaxScalerPath)
+ >>> loadedMMScaler.getMin() == mmScaler.getMin()
+ True
+ >>> loadedMMScaler.getMax() == mmScaler.getMax()
+ True
+ >>> modelPath = temp_path + "/min-max-scaler-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = MinMaxScalerModel.load(modelPath)
+ >>> loadedModel.originalMin == model.originalMin
+ True
+ >>> loadedModel.originalMax == model.originalMax
+ True
.. versionadded:: 1.6.0
"""
@@ -718,7 +802,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol):
return MinMaxScalerModel(java_model)
-class MinMaxScalerModel(JavaModel):
+class MinMaxScalerModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -746,7 +830,7 @@ class MinMaxScalerModel(JavaModel):
@inherit_doc
@ignore_unicode_prefix
-class NGram(JavaTransformer, HasInputCol, HasOutputCol):
+class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -775,6 +859,11 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol):
Traceback (most recent call last):
...
TypeError: Method setParams forces keyword arguments.
+ >>> ngramPath = temp_path + "/ngram"
+ >>> ngram.save(ngramPath)
+ >>> loadedNGram = NGram.load(ngramPath)
+ >>> loadedNGram.getN() == ngram.getN()
+ True
.. versionadded:: 1.5.0
"""
@@ -819,7 +908,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
-class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
+class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -836,6 +925,11 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
>>> params = {normalizer.p: 1.0, normalizer.inputCol: "dense", normalizer.outputCol: "vector"}
>>> normalizer.transform(df, params).head().vector
DenseVector([0.4286, -0.5714])
+ >>> normalizerPath = temp_path + "/normalizer"
+ >>> normalizer.save(normalizerPath)
+ >>> loadedNormalizer = Normalizer.load(normalizerPath)
+ >>> loadedNormalizer.getP() == normalizer.getP()
+ True
.. versionadded:: 1.4.0
"""
@@ -880,7 +974,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
-class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
+class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -913,6 +1007,11 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
>>> params = {encoder.dropLast: False, encoder.outputCol: "test"}
>>> encoder.transform(td, params).head().test
SparseVector(3, {0: 1.0})
+ >>> onehotEncoderPath = temp_path + "/onehot-encoder"
+ >>> encoder.save(onehotEncoderPath)
+ >>> loadedEncoder = OneHotEncoder.load(onehotEncoderPath)
+ >>> loadedEncoder.getDropLast() == encoder.getDropLast()
+ True
.. versionadded:: 1.4.0
"""
@@ -957,7 +1056,7 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
-class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
+class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -974,6 +1073,11 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
>>> px.setParams(outputCol="test").transform(df).head().test
DenseVector([0.5, 0.25, 2.0, 1.0, 4.0])
+ >>> polyExpansionPath = temp_path + "/poly-expansion"
+ >>> px.save(polyExpansionPath)
+ >>> loadedPx = PolynomialExpansion.load(polyExpansionPath)
+ >>> loadedPx.getDegree() == px.getDegree()
+ True
.. versionadded:: 1.4.0
"""
@@ -1019,7 +1123,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
-class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed):
+class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLReadable,
+                          MLWritable):
"""
.. note:: Experimental
@@ -1043,6 +1148,11 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed):
>>> bucketed = bucketizer.transform(df).head()
>>> bucketed.buckets
0.0
+ >>> quantileDiscretizerPath = temp_path + "/quantile-discretizer"
+ >>> qds.save(quantileDiscretizerPath)
+ >>> loadedQds = QuantileDiscretizer.load(quantileDiscretizerPath)
+ >>> loadedQds.getNumBuckets() == qds.getNumBuckets()
+ True
.. versionadded:: 2.0.0
"""
@@ -1103,7 +1213,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed):
@inherit_doc
@ignore_unicode_prefix
-class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
+class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1131,6 +1241,13 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
Traceback (most recent call last):
...
TypeError: Method setParams forces keyword arguments.
+ >>> regexTokenizerPath = temp_path + "/regex-tokenizer"
+ >>> reTokenizer.save(regexTokenizerPath)
+ >>> loadedReTokenizer = RegexTokenizer.load(regexTokenizerPath)
+ >>> loadedReTokenizer.getMinTokenLength() == reTokenizer.getMinTokenLength()
+ True
+ >>> loadedReTokenizer.getGaps() == reTokenizer.getGaps()
+ True
.. versionadded:: 1.4.0
"""
@@ -1228,7 +1345,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
-class SQLTransformer(JavaTransformer):
+class SQLTransformer(JavaTransformer, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1241,6 +1358,11 @@ class SQLTransformer(JavaTransformer):
... statement="SELECT *, (v1 + v2) AS v3, (v1 * v2) AS v4 FROM __THIS__")
>>> sqlTrans.transform(df).head()
Row(id=0, v1=1.0, v2=3.0, v3=4.0, v4=3.0)
+ >>> sqlTransformerPath = temp_path + "/sql-transformer"
+ >>> sqlTrans.save(sqlTransformerPath)
+ >>> loadedSqlTrans = SQLTransformer.load(sqlTransformerPath)
+ >>> loadedSqlTrans.getStatement() == sqlTrans.getStatement()
+ True
.. versionadded:: 1.6.0
"""
@@ -1284,7 +1406,7 @@ class SQLTransformer(JavaTransformer):
@inherit_doc
-class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
+class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1301,6 +1423,20 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
DenseVector([1.4142])
>>> model.transform(df).collect()[1].scaled
DenseVector([1.4142])
+ >>> standardScalerPath = temp_path + "/standard-scaler"
+ >>> standardScaler.save(standardScalerPath)
+ >>> loadedStandardScaler = StandardScaler.load(standardScalerPath)
+ >>> loadedStandardScaler.getWithMean() == standardScaler.getWithMean()
+ True
+ >>> loadedStandardScaler.getWithStd() == standardScaler.getWithStd()
+ True
+ >>> modelPath = temp_path + "/standard-scaler-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = StandardScalerModel.load(modelPath)
+ >>> loadedModel.std == model.std
+ True
+ >>> loadedModel.mean == model.mean
+ True
.. versionadded:: 1.4.0
"""
@@ -1363,7 +1499,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol):
return StandardScalerModel(java_model)
-class StandardScalerModel(JavaModel):
+class StandardScalerModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1390,7 +1526,8 @@ class StandardScalerModel(JavaModel):
@inherit_doc
-class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid):
+class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, MLReadable,
+                    MLWritable):
"""
.. note:: Experimental
@@ -1410,6 +1547,21 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid):
>>> sorted(set([(i[0], str(i[1])) for i in itd.select(itd.id, itd.label2).collect()]),
... key=lambda x: x[0])
[(0, 'a'), (1, 'b'), (2, 'c'), (3, 'a'), (4, 'a'), (5, 'c')]
+ >>> stringIndexerPath = temp_path + "/string-indexer"
+ >>> stringIndexer.save(stringIndexerPath)
+ >>> loadedIndexer = StringIndexer.load(stringIndexerPath)
+ >>> loadedIndexer.getHandleInvalid() == stringIndexer.getHandleInvalid()
+ True
+ >>> modelPath = temp_path + "/string-indexer-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = StringIndexerModel.load(modelPath)
+ >>> loadedModel.labels == model.labels
+ True
+ >>> indexToStringPath = temp_path + "/index-to-string"
+ >>> inverter.save(indexToStringPath)
+ >>> loadedInverter = IndexToString.load(indexToStringPath)
+ >>> loadedInverter.getLabels() == inverter.getLabels()
+ True
.. versionadded:: 1.4.0
"""
@@ -1439,7 +1591,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid):
return StringIndexerModel(java_model)
-class StringIndexerModel(JavaModel):
+class StringIndexerModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1458,7 +1610,7 @@ class StringIndexerModel(JavaModel):
@inherit_doc
-class IndexToString(JavaTransformer, HasInputCol, HasOutputCol):
+class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1512,13 +1664,25 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol):
return self.getOrDefault(self.labels)
-class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
A feature transformer that filters out stop words from input.
Note: null values from the input array are preserved unless null is explicitly added to stopWords.
+ >>> df = sqlContext.createDataFrame([(["a", "b", "c"],)], ["text"])
+ >>> remover = StopWordsRemover(inputCol="text", outputCol="words", stopWords=["b"])
+ >>> remover.transform(df).head().words == ['a', 'c']
+ True
+ >>> stopWordsRemoverPath = temp_path + "/stopwords-remover"
+ >>> remover.save(stopWordsRemoverPath)
+ >>> loadedRemover = StopWordsRemover.load(stopWordsRemoverPath)
+ >>> loadedRemover.getStopWords() == remover.getStopWords()
+ True
+ >>> loadedRemover.getCaseSensitive() == remover.getCaseSensitive()
+ True
+
.. versionadded:: 1.6.0
"""
@@ -1538,7 +1702,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
self.uid)
stopWordsObj = _jvm().org.apache.spark.ml.feature.StopWords
defaultStopWords = stopWordsObj.English()
- self._setDefault(stopWords=defaultStopWords)
+ self._setDefault(stopWords=defaultStopWords, caseSensitive=False)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -1587,7 +1751,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
@ignore_unicode_prefix
-class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
+class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1611,6 +1775,11 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
Traceback (most recent call last):
...
TypeError: Method setParams forces keyword arguments.
+ >>> tokenizerPath = temp_path + "/tokenizer"
+ >>> tokenizer.save(tokenizerPath)
+ >>> loadedTokenizer = Tokenizer.load(tokenizerPath)
+ >>> loadedTokenizer.transform(df).head().tokens == tokenizer.transform(df).head().tokens
+ True
.. versionadded:: 1.3.0
"""
@@ -1637,7 +1806,7 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
-class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
+class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1652,6 +1821,11 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
>>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
>>> vecAssembler.transform(df, params).head().vector
DenseVector([0.0, 1.0])
+ >>> vectorAssemblerPath = temp_path + "/vector-assembler"
+ >>> vecAssembler.save(vectorAssemblerPath)
+ >>> loadedAssembler = VectorAssembler.load(vectorAssemblerPath)
+ >>> loadedAssembler.transform(df).head().freqs == vecAssembler.transform(df).head().freqs
+ True
.. versionadded:: 1.4.0
"""
@@ -1678,7 +1852,7 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
@inherit_doc
-class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
+class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1734,6 +1908,18 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
>>> model2 = indexer.fit(df, params)
>>> model2.transform(df).head().vector
DenseVector([1.0, 0.0])
+ >>> vectorIndexerPath = temp_path + "/vector-indexer"
+ >>> indexer.save(vectorIndexerPath)
+ >>> loadedIndexer = VectorIndexer.load(vectorIndexerPath)
+ >>> loadedIndexer.getMaxCategories() == indexer.getMaxCategories()
+ True
+ >>> modelPath = temp_path + "/vector-indexer-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = VectorIndexerModel.load(modelPath)
+ >>> loadedModel.numFeatures == model.numFeatures
+ True
+ >>> loadedModel.categoryMaps == model.categoryMaps
+ True
.. versionadded:: 1.4.0
"""
@@ -1783,7 +1969,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol):
return VectorIndexerModel(java_model)
-class VectorIndexerModel(JavaModel):
+class VectorIndexerModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1812,7 +1998,7 @@ class VectorIndexerModel(JavaModel):
@inherit_doc
-class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol):
+class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1834,6 +2020,13 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol):
>>> vs = VectorSlicer(inputCol="features", outputCol="sliced", indices=[1, 4])
>>> vs.transform(df).head().sliced
DenseVector([2.3, 1.0])
+ >>> vectorSlicerPath = temp_path + "/vector-slicer"
+ >>> vs.save(vectorSlicerPath)
+ >>> loadedVs = VectorSlicer.load(vectorSlicerPath)
+ >>> loadedVs.getIndices() == vs.getIndices()
+ True
+ >>> loadedVs.getNames() == vs.getNames()
+ True
.. versionadded:: 1.6.0
"""
@@ -1852,6 +2045,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol):
"""
super(VectorSlicer, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.VectorSlicer", self.uid)
+ self._setDefault(indices=[], names=[])
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -1898,7 +2092,8 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol):
@inherit_doc
@ignore_unicode_prefix
-class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol):
+class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol,
+               MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -1907,7 +2102,8 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
>>> sent = ("a b " * 100 + "a c " * 10).split(" ")
>>> doc = sqlContext.createDataFrame([(sent,), (sent,)], ["sentence"])
- >>> model = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model").fit(doc)
+ >>> word2Vec = Word2Vec(vectorSize=5, seed=42, inputCol="sentence", outputCol="model")
+ >>> model = word2Vec.fit(doc)
>>> model.getVectors().show()
+----+--------------------+
|word|              vector|
@@ -1927,6 +2123,22 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
...
>>> model.transform(doc).head().model
DenseVector([0.5524, -0.4995, -0.3599, 0.0241, 0.3461])
+ >>> word2vecPath = temp_path + "/word2vec"
+ >>> word2Vec.save(word2vecPath)
+ >>> loadedWord2Vec = Word2Vec.load(word2vecPath)
+ >>> loadedWord2Vec.getVectorSize() == word2Vec.getVectorSize()
+ True
+ >>> loadedWord2Vec.getNumPartitions() == word2Vec.getNumPartitions()
+ True
+ >>> loadedWord2Vec.getMinCount() == word2Vec.getMinCount()
+ True
+ >>> modelPath = temp_path + "/word2vec-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = Word2VecModel.load(modelPath)
+ >>> loadedModel.getVectors().first().word == model.getVectors().first().word
+ True
+ >>> loadedModel.getVectors().first().vector == model.getVectors().first().vector
+ True
.. versionadded:: 1.4.0
"""
@@ -2014,7 +2226,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
return Word2VecModel(java_model)
-class Word2VecModel(JavaModel):
+class Word2VecModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -2045,7 +2257,7 @@ class Word2VecModel(JavaModel):
@inherit_doc
-class PCA(JavaEstimator, HasInputCol, HasOutputCol):
+class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -2062,6 +2274,18 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol):
DenseVector([1.648..., -4.013...])
>>> model.explainedVariance
DenseVector([0.794..., 0.205...])
+ >>> pcaPath = temp_path + "/pca"
+ >>> pca.save(pcaPath)
+ >>> loadedPca = PCA.load(pcaPath)
+ >>> loadedPca.getK() == pca.getK()
+ True
+ >>> modelPath = temp_path + "/pca-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = PCAModel.load(modelPath)
+ >>> loadedModel.pc == model.pc
+ True
+ >>> loadedModel.explainedVariance == model.explainedVariance
+ True
.. versionadded:: 1.5.0
"""
@@ -2107,7 +2331,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol):
return PCAModel(java_model)
-class PCAModel(JavaModel):
+class PCAModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -2226,7 +2450,8 @@ class RFormulaModel(JavaModel):
@inherit_doc
-class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol):
+class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, MLReadable,
+                    MLWritable):
"""
.. note:: Experimental
@@ -2245,6 +2470,16 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol):
DenseVector([1.0])
>>> model.selectedFeatures
[3]
+ >>> chiSqSelectorPath = temp_path + "/chi-sq-selector"
+ >>> selector.save(chiSqSelectorPath)
+ >>> loadedSelector = ChiSqSelector.load(chiSqSelectorPath)
+ >>> loadedSelector.getNumTopFeatures() == selector.getNumTopFeatures()
+ True
+ >>> modelPath = temp_path + "/chi-sq-selector-model"
+ >>> model.save(modelPath)
+ >>> loadedModel = ChiSqSelectorModel.load(modelPath)
+ >>> loadedModel.selectedFeatures == model.selectedFeatures
+ True
.. versionadded:: 2.0.0
"""
@@ -2302,7 +2537,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol):
return ChiSqSelectorModel(java_model)
-class ChiSqSelectorModel(JavaModel):
+class ChiSqSelectorModel(JavaModel, MLReadable, MLWritable):
"""
.. note:: Experimental
@@ -2322,9 +2557,16 @@ class ChiSqSelectorModel(JavaModel):
if __name__ == "__main__":
import doctest
+ import tempfile
+
+ import pyspark.ml.feature
from pyspark.context import SparkContext
from pyspark.sql import Row, SQLContext
+
globs = globals().copy()
+ features = pyspark.ml.feature.__dict__.copy()
+ globs.update(features)
+
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.feature tests")
@@ -2335,7 +2577,16 @@ if __name__ == "__main__":
Row(id=2, label="c"), Row(id=3, label="a"),
Row(id=4, label="a"), Row(id=5, label="c")], 2)
globs['stringIndDf'] = sqlContext.createDataFrame(testData)
- (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
- sc.stop()
+ temp_path = tempfile.mkdtemp()
+ globs['temp_path'] = temp_path
+ try:
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
+ sc.stop()
+ finally:
+ from shutil import rmtree
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
if failure_count:
exit(-1)