diff options
author | Joseph K. Bradley <joseph@databricks.com> | 2016-03-22 12:11:23 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-03-22 12:11:37 -0700 |
commit | 7e3423b9c03c9812d404134c3d204c4cfea87721 (patch) | |
tree | b922610e318774c1db7da6549ee0932b21fe3090 /python/pyspark/ml/feature.py | |
parent | 297c20226d3330309c9165d789749458f8f4ab8e (diff) | |
download | spark-7e3423b9c03c9812d404134c3d204c4cfea87721.tar.gz spark-7e3423b9c03c9812d404134c3d204c4cfea87721.tar.bz2 spark-7e3423b9c03c9812d404134c3d204c4cfea87721.zip |
[SPARK-13951][ML][PYTHON] Nested Pipeline persistence
Adds support for saving and loading nested ML Pipelines from Python. Pipeline and PipelineModel do not extend JavaWrapper, but they are able to utilize the JavaMLWriter and JavaMLReader implementations.
Also:
* Separates out interfaces from Java wrapper implementations for MLWritable, MLReadable, MLWriter, MLReader.
* Moves methods _stages_java2py and _stages_py2java into Pipeline and PipelineModel, renamed to _transfer_stage_from_java and _transfer_stage_to_java respectively.
Adds a new unit test for nested Pipelines. Abstracts the validity check into a helper method shared by the 2 unit tests.
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #11866 from jkbradley/nested-pipeline-io.
Closes #11835
Diffstat (limited to 'python/pyspark/ml/feature.py')
-rw-r--r-- | python/pyspark/ml/feature.py | 89 |
1 files changed, 46 insertions, 43 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3182faac0d..16cb9d1db3 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -22,7 +22,7 @@ if sys.version > '3': from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * -from pyspark.ml.util import keyword_only, MLReadable, MLWritable +from pyspark.ml.util import keyword_only, JavaMLReadable, JavaMLWritable from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector @@ -58,7 +58,7 @@ __all__ = ['Binarizer', @inherit_doc -class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -123,7 +123,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritab @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -198,7 +198,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWrita @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -324,7 +324,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWr return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel, MLReadable, MLWritable): +class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -343,7 +343,7 @@ class CountVectorizerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -415,7 +415,8 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): @inherit_doc -class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -481,7 +482,8 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, @inherit_doc -class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLReadable, MLWritable): +class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -529,7 +531,7 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLRe @inherit_doc -class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -604,7 +606,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): return IDFModel(java_model) -class IDFModel(JavaModel, MLReadable, MLWritable): +class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -615,7 +617,7 @@ class IDFModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -676,7 +678,7 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrita return MaxAbsScalerModel(java_model) -class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable): +class MaxAbsScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -695,7 +697,7 @@ class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -802,7 +804,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrita return MinMaxScalerModel(java_model) -class MinMaxScalerModel(JavaModel, MLReadable, MLWritable): +class MinMaxScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -830,7 +832,7 @@ class MinMaxScalerModel(JavaModel, MLReadable, MLWritable): @inherit_doc @ignore_unicode_prefix -class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -908,7 +910,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): @inherit_doc -class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -974,7 +976,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWrita @inherit_doc -class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -1056,7 +1058,8 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWr @inherit_doc -class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1123,8 +1126,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable @inherit_doc -class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLReadable, - MLWritable): +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1213,7 +1216,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLR @inherit_doc @ignore_unicode_prefix -class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1345,7 +1348,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLW @inherit_doc -class SQLTransformer(JavaTransformer, MLReadable, MLWritable): +class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1406,7 +1409,7 @@ class SQLTransformer(JavaTransformer, MLReadable, MLWritable): @inherit_doc -class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1499,7 +1502,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWri return StandardScalerModel(java_model) -class StandardScalerModel(JavaModel, MLReadable, MLWritable): +class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -1526,8 +1529,8 @@ class StandardScalerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, MLReadable, - MLWritable): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1591,7 +1594,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, return StringIndexerModel(java_model) -class StringIndexerModel(JavaModel, MLReadable, MLWritable): +class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1610,7 +1613,7 @@ class StringIndexerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1664,7 +1667,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWr return self.getOrDefault(self.labels) -class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1751,7 +1754,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, M @inherit_doc @ignore_unicode_prefix -class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -1806,7 +1809,7 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritab @inherit_doc -class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, MLWritable): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1852,7 +1855,7 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, M @inherit_doc -class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1969,7 +1972,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrit return VectorIndexerModel(java_model) -class VectorIndexerModel(JavaModel, MLReadable, MLWritable): +class VectorIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1998,7 +2001,7 @@ class VectorIndexerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2093,7 +2096,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWri @inherit_doc @ignore_unicode_prefix class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol, - MLReadable, MLWritable): + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2226,7 +2229,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has return Word2VecModel(java_model) -class Word2VecModel(JavaModel, MLReadable, MLWritable): +class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -2257,7 +2260,7 @@ class Word2VecModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2331,7 +2334,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): return PCAModel(java_model) -class PCAModel(JavaModel, MLReadable, MLWritable): +class PCAModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2360,7 +2363,7 @@ class PCAModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable): +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2463,7 +2466,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritabl return RFormulaModel(java_model) -class RFormulaModel(JavaModel, MLReadable, MLWritable): +class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2474,8 +2477,8 @@ class RFormulaModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, MLReadable, - MLWritable): +class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -2561,7 +2564,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, ML return ChiSqSelectorModel(java_model) -class ChiSqSelectorModel(JavaModel, MLReadable, MLWritable): +class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental |