From 7e3423b9c03c9812d404134c3d204c4cfea87721 Mon Sep 17 00:00:00 2001
From: "Joseph K. Bradley"
Date: Tue, 22 Mar 2016 12:11:23 -0700
Subject: [SPARK-13951][ML][PYTHON] Nested Pipeline persistence

Adds support for saving and loading nested ML Pipelines from Python. Pipeline and
PipelineModel do not extend JavaWrapper, but they can still use the JavaMLWriter and
JavaMLReader implementations.

Also:
* Separates out the MLWritable, MLReadable, MLWriter, MLReader interfaces from their
  Java wrapper implementations.
* Moves the logic of _stages_java2py, _stages_py2java into Pipeline, PipelineModel as
  _from_java, _to_java.

Added a new unit test for nested Pipelines and abstracted the validity check into a
helper method shared by the 2 unit tests.

Author: Joseph K. Bradley

Closes #11866 from jkbradley/nested-pipeline-io.
Closes #11835
---
 python/pyspark/ml/classification.py |   8 +-
 python/pyspark/ml/clustering.py     |   4 +-
 python/pyspark/ml/feature.py        |  89 ++++++++++-----------
 python/pyspark/ml/pipeline.py       | 150 ++++++++++++++++++------------------
 python/pyspark/ml/recommendation.py |   4 +-
 python/pyspark/ml/regression.py     |  12 +--
 python/pyspark/ml/tests.py          |  82 +++++++++++++++-----
 python/pyspark/ml/util.py           |  89 +++++++++++++++++----
 python/pyspark/ml/wrapper.py        |  37 ++++++---
 9 files changed, 300 insertions(+), 175 deletions(-)

(limited to 'python/pyspark')

diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 16ad76483d..8075108114 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -38,7 +38,7 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, - HasWeightCol, MLWritable, MLReadable): + HasWeightCol, JavaMLWritable, JavaMLReadable): """ Logistic regression. Currently, this class only supports binary classification. @@ -198,7 +198,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) -class LogisticRegressionModel(JavaModel, MLWritable, MLReadable): +class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by LogisticRegression. @@ -601,7 +601,7 @@ class GBTClassificationModel(TreeEnsembleModels): @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol, MLWritable, MLReadable): + HasRawPredictionCol, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. Multinomial NB @@ -720,7 +720,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H return self.getOrDefault(self.modelType) -class NaiveBayesModel(JavaModel, MLWritable, MLReadable): +class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by NaiveBayes. diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 1cea477acb..2db5b82c44 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -25,7 +25,7 @@ __all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'KMeans', 'KMeansModel'] -class KMeansModel(JavaModel, MLWritable, MLReadable): +class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by KMeans.
@@ -48,7 +48,7 @@ class KMeansModel(JavaModel, MLWritable, MLReadable): @inherit_doc class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, - MLWritable, MLReadable): + JavaMLWritable, JavaMLReadable): """ K-means clustering with support for multiple parallel runs and a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3182faac0d..16cb9d1db3 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -22,7 +22,7 @@ if sys.version > '3': from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * -from pyspark.ml.util import keyword_only, MLReadable, MLWritable +from pyspark.ml.util import keyword_only, JavaMLReadable, JavaMLWritable from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector @@ -58,7 +58,7 @@ __all__ = ['Binarizer', @inherit_doc -class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -123,7 +123,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritab @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -198,7 +198,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWrita @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -324,7 +324,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWr return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel, MLReadable, MLWritable): +class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -343,7 +343,7 @@ class CountVectorizerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -415,7 +415,8 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): @inherit_doc -class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -481,7 +482,8 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, @inherit_doc -class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLReadable, MLWritable): +class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, + JavaMLWritable): """ .. 
note:: Experimental @@ -529,7 +531,7 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLRe @inherit_doc -class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -604,7 +606,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): return IDFModel(java_model) -class IDFModel(JavaModel, MLReadable, MLWritable): +class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -615,7 +617,7 @@ class IDFModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -676,7 +678,7 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrita return MaxAbsScalerModel(java_model) -class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable): +class MaxAbsScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -695,7 +697,7 @@ class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -802,7 +804,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrita return MinMaxScalerModel(java_model) -class MinMaxScalerModel(JavaModel, MLReadable, MLWritable): +class MinMaxScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -830,7 +832,7 @@ class MinMaxScalerModel(JavaModel, MLReadable, MLWritable): @inherit_doc @ignore_unicode_prefix -class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -908,7 +910,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): @inherit_doc -class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -974,7 +976,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWrita @inherit_doc -class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1056,7 +1058,8 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWr @inherit_doc -class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1123,8 +1126,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable @inherit_doc -class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLReadable, - MLWritable): +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, + JavaMLWritable): """ .. 
note:: Experimental @@ -1213,7 +1216,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLR @inherit_doc @ignore_unicode_prefix -class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1345,7 +1348,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLW @inherit_doc -class SQLTransformer(JavaTransformer, MLReadable, MLWritable): +class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1406,7 +1409,7 @@ class SQLTransformer(JavaTransformer, MLReadable, MLWritable): @inherit_doc -class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1499,7 +1502,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWri return StandardScalerModel(java_model) -class StandardScalerModel(JavaModel, MLReadable, MLWritable): +class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1526,8 +1529,8 @@ class StandardScalerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, MLReadable, - MLWritable): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1591,7 +1594,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, return StringIndexerModel(java_model) -class StringIndexerModel(JavaModel, MLReadable, MLWritable): +class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1610,7 +1613,7 @@ class StringIndexerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1664,7 +1667,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWr return self.getOrDefault(self.labels) -class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1751,7 +1754,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, M @inherit_doc @ignore_unicode_prefix -class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1806,7 +1809,7 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritab @inherit_doc -class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, MLWritable): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental @@ -1852,7 +1855,7 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, M @inherit_doc -class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1969,7 +1972,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrit return VectorIndexerModel(java_model) -class VectorIndexerModel(JavaModel, MLReadable, MLWritable): +class VectorIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1998,7 +2001,7 @@ class VectorIndexerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2093,7 +2096,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWri @inherit_doc @ignore_unicode_prefix class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol, - MLReadable, MLWritable): + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2226,7 +2229,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has return Word2VecModel(java_model) -class Word2VecModel(JavaModel, MLReadable, MLWritable): +class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2257,7 +2260,7 @@ class Word2VecModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2331,7 +2334,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): return PCAModel(java_model) -class PCAModel(JavaModel, MLReadable, MLWritable): +class PCAModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2360,7 +2363,7 @@ class PCAModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable): +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2463,7 +2466,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritabl return RFormulaModel(java_model) -class RFormulaModel(JavaModel, MLReadable, MLWritable): +class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2474,8 +2477,8 @@ class RFormulaModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, MLReadable, - MLWritable): +class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -2561,7 +2564,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, ML return ChiSqSelectorModel(java_model) -class ChiSqSelectorModel(JavaModel, MLReadable, MLWritable): +class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. 
note:: Experimental diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a1658b0a02..2b5504bc29 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -24,72 +24,31 @@ from pyspark import SparkContext from pyspark import since from pyspark.ml import Estimator, Model, Transformer from pyspark.ml.param import Param, Params -from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader +from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.common import inherit_doc -def _stages_java2py(java_stages): - """ - Transforms the parameter Python stages from a list of Java stages. - :param java_stages: An array of Java stages. - :return: An array of Python stages. - """ - - return [JavaWrapper._transfer_stage_from_java(stage) for stage in java_stages] - - -def _stages_py2java(py_stages, cls): - """ - Transforms the parameter of Python stages to a Java array of Java stages. - :param py_stages: An array of Python stages. - :return: A Java array of Java Stages. - """ - - for stage in py_stages: - assert(isinstance(stage, JavaWrapper), - "Python side implementation is not supported in the meta-PipelineStage currently.") - gateway = SparkContext._gateway - java_stages = gateway.new_array(cls, len(py_stages)) - for idx, stage in enumerate(py_stages): - java_stages[idx] = stage._transfer_stage_to_java() - return java_stages - - @inherit_doc -class PipelineMLWriter(JavaMLWriter, JavaWrapper): +class PipelineMLWriter(JavaMLWriter): """ Private Pipeline utility class that can save ML instances through their Scala implementation. - """ - def __init__(self, instance): - cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage - self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", instance.uid) - self._java_obj.setStages(_stages_py2java(instance.getStages(), cls)) - self._jwrite = self._java_obj.write() + We can currently use JavaMLWriter, rather than MLWriter, since Pipeline implements _to_java. + """ @inherit_doc class PipelineMLReader(JavaMLReader): """ Private utility class that can load Pipeline instances through their Scala implementation. - """ - def load(self, path): - """Load the Pipeline instance from the input path.""" - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - - java_obj = self._jread.load(path) - instance = self._clazz() - instance._resetUid(java_obj.uid()) - instance.setStages(_stages_java2py(java_obj.getStages())) - - return instance + We can currently use JavaMLReader, rather than MLReader, since Pipeline implements _from_java. + """ @inherit_doc -class Pipeline(Estimator): +class Pipeline(Estimator, MLReadable, MLWritable): """ A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each of which is either an @@ -206,49 +165,65 @@ class Pipeline(Estimator): @classmethod @since("2.0.0") def read(cls): - """Returns an JavaMLReader instance for this class.""" + """Returns an MLReader instance for this class.""" return PipelineMLReader(cls) @classmethod - @since("2.0.0") - def load(cls, path): - """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" - return cls.read().load(path) + def _from_java(cls, java_stage): + """ + Given a Java Pipeline, create and return a Python wrapper of it. + Used for ML persistence. + """ + # Create a new instance of this stage. 
+ py_stage = cls() + # Load information from java_stage to the instance. + py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()] + py_stage.setStages(py_stages) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java Pipeline. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage + java_stages = gateway.new_array(cls, len(self.getStages())) + for idx, stage in enumerate(self.getStages()): + java_stages[idx] = stage._to_java() + + _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) + _java_obj.setStages(java_stages) + + return _java_obj @inherit_doc -class PipelineModelMLWriter(JavaMLWriter, JavaWrapper): +class PipelineModelMLWriter(JavaMLWriter): """ Private PipelineModel utility class that can save ML instances through their Scala implementation. - """ - def __init__(self, instance): - cls = SparkContext._jvm.org.apache.spark.ml.Transformer - self._java_obj = self._new_java_obj("org.apache.spark.ml.PipelineModel", - instance.uid, - _stages_py2java(instance.stages, cls)) - self._jwrite = self._java_obj.write() + We can (currently) use JavaMLWriter, rather than MLWriter, since PipelineModel implements + _to_java. + """ @inherit_doc class PipelineModelMLReader(JavaMLReader): """ Private utility class that can load PipelineModel instances through their Scala implementation. - """ - def load(self, path): - """Load the PipelineModel instance from the input path.""" - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - java_obj = self._jread.load(path) - instance = self._clazz(_stages_java2py(java_obj.stages())) - instance._resetUid(java_obj.uid()) - return instance + We can currently use JavaMLReader, rather than MLReader, since PipelineModel implements + _from_java. + """ @inherit_doc -class PipelineModel(Model): +class PipelineModel(Model, MLReadable, MLWritable): """ Represents a compiled pipeline with transformers and fitted models. @@ -294,7 +269,32 @@ class PipelineModel(Model): return PipelineModelMLReader(cls) @classmethod - @since("2.0.0") - def load(cls, path): - """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" - return cls.read().load(path) + def _from_java(cls, java_stage): + """ + Given a Java PipelineModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + # Load information from java_stage to the instance. + py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()] + # Create a new instance of this stage. + py_stage = cls(py_stages) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java PipelineModel. Used for ML persistence. + + :return: Java object equivalent to this instance. 
+ """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.Transformer + java_stages = gateway.new_array(cls, len(self.stages)) + for idx, stage in enumerate(self.stages): + java_stages[idx] = stage._to_java() + + _java_obj =\ + JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) + + return _java_obj diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 2b605e5c50..de4c2675ed 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -27,7 +27,7 @@ __all__ = ['ALS', 'ALSModel'] @inherit_doc class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed, - MLWritable, MLReadable): + JavaMLWritable, JavaMLReadable): """ Alternating Least Squares (ALS) matrix factorization. @@ -289,7 +289,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha return self.getOrDefault(self.nonnegative) -class ALSModel(JavaModel, MLWritable, MLReadable): +class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by ALS. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 6e23393f91..664a44bc47 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable): + HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable): """ Linear regression. @@ -118,7 +118,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction return LinearRegressionModel(java_model) -class LinearRegressionModel(JavaModel, MLWritable, MLReadable): +class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by LinearRegression. @@ -154,7 +154,7 @@ class LinearRegressionModel(JavaModel, MLWritable, MLReadable): @inherit_doc class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasWeightCol, MLWritable, MLReadable): + HasWeightCol, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -249,7 +249,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti return self.getOrDefault(self.featureIndex) -class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable): +class IsotonicRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -719,7 +719,7 @@ class GBTRegressionModel(TreeEnsembleModels): @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable): + HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -857,7 +857,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi return self.getOrDefault(self.quantilesCol) -class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable): +class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by AFTSurvivalRegression. 
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9783ce7e77..211248e8b2 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -47,6 +47,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.linalg import DenseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand @@ -517,7 +518,39 @@ class PersistenceTest(PySparkTestCase): except OSError: pass + def _compare_pipelines(self, m1, m2): + """ + Compare 2 ML types, asserting that they are equivalent. + This currently supports: + - basic types + - Pipeline, PipelineModel + This checks: + - uid + - type + - Param values and parents + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + if isinstance(m1, JavaWrapper): + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + elif isinstance(m1, Pipeline): + self.assertEqual(len(m1.getStages()), len(m2.getStages())) + for s1, s2 in zip(m1.getStages(), m2.getStages()): + self._compare_pipelines(s1, s2) + elif isinstance(m1, PipelineModel): + self.assertEqual(len(m1.stages), len(m2.stages)) + for s1, s2 in zip(m1.stages, m2.stages): + self._compare_pipelines(s1, s2) + else: + raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) + def test_pipeline_persistence(self): + """ + Pipeline[HashingTF, PCA] + """ sqlContext = SQLContext(self.sc) temp_path = tempfile.mkdtemp() @@ -527,33 +560,46 @@ class PersistenceTest(PySparkTestCase): pca = PCA(k=2, inputCol="features", outputCol="pca_features") pl = Pipeline(stages=[tf, pca]) model = pl.fit(df) + pipeline_path = temp_path + "/pipeline" pl.save(pipeline_path) loaded_pipeline = Pipeline.load(pipeline_path) - self.assertEqual(loaded_pipeline.uid, pl.uid) - self.assertEqual(len(loaded_pipeline.getStages()), 2) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass - [loaded_tf, loaded_pca] = loaded_pipeline.getStages() - self.assertIsInstance(loaded_tf, HashingTF) - self.assertEqual(loaded_tf.uid, tf.uid) - param = loaded_tf.getParam("numFeatures") - self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param)) + def test_nested_pipeline_persistence(self): + """ + Pipeline[HashingTF, Pipeline[PCA]] + """ + sqlContext = SQLContext(self.sc) + temp_path = tempfile.mkdtemp() + + try: + df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + p0 = Pipeline(stages=[pca]) + pl = Pipeline(stages=[tf, p0]) + model = pl.fit(df) - self.assertIsInstance(loaded_pca, PCA) - self.assertEqual(loaded_pca.uid, pca.uid) - self.assertEqual(loaded_pca.getK(), pca.getK()) + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) model_path = temp_path + "/pipeline-model" model.save(model_path) loaded_model = 
PipelineModel.load(model_path) - [model_tf, model_pca] = model.stages - [loaded_model_tf, loaded_model_pca] = loaded_model.stages - self.assertEqual(model_tf.uid, loaded_model_tf.uid) - self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param)) - - self.assertEqual(model_pca.uid, loaded_model_pca.uid) - self.assertEqual(model_pca.pc, loaded_model_pca.pc) - self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance) + self._compare_pipelines(model, loaded_model) finally: try: rmtree(temp_path) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 42801c91bb..6703851262 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -74,18 +74,38 @@ class Identifiable(object): @inherit_doc -class JavaMLWriter(object): +class MLWriter(object): """ .. note:: Experimental - Utility class that can save ML instances through their Scala implementation. + Utility class that can save ML instances. .. versionadded:: 2.0.0 """ + def save(self, path): + """Save the ML instance to the input path.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def overwrite(self): + """Overwrites if the output path already exists.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for saving.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLWriter(MLWriter): + """ + (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types + """ + def __init__(self, instance): - instance._transfer_params_to_java() - self._jwrite = instance._java_obj.write() + super(JavaMLWriter, self).__init__() + _java_obj = instance._to_java() + self._jwrite = _java_obj.write() def save(self, path): """Save the ML instance to the input path.""" @@ -109,14 +129,14 @@ class MLWritable(object): """ .. note:: Experimental - Mixin for ML instances that provide JavaMLWriter. + Mixin for ML instances that provide :py:class:`MLWriter`. .. versionadded:: 2.0.0 """ def write(self): """Returns a JavaMLWriter instance for this ML instance.""" - return JavaMLWriter(self) + raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self)) def save(self, path): """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" @@ -124,15 +144,41 @@ -class JavaMLReader(object): +class JavaMLWritable(MLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`. + """ + + def write(self): + """Returns a JavaMLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + +@inherit_doc +class MLReader(object): """ .. note:: Experimental - Utility class that can load ML instances through their Scala implementation. + Utility class that can load ML instances. ..
versionadded:: 2.0.0 """ + def load(self, path): + """Load the ML instance from the input path.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for loading.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLReader(MLReader): + """ + (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types + """ + def __init__(self, clazz): self._clazz = clazz self._jread = self._load_java_obj(clazz).read() @@ -142,11 +188,10 @@ class JavaMLReader(object): if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) java_obj = self._jread.load(path) - instance = self._clazz() - instance._java_obj = java_obj - instance._resetUid(java_obj.uid()) - instance._transfer_params_from_java() - return instance + if not hasattr(self._clazz, "_from_java"): + raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r" + % self._clazz) + return self._clazz._from_java(java_obj) def context(self, sqlContext): """Sets the SQL context to use for loading.""" @@ -164,7 +209,7 @@ if clazz.__name__ in ("Pipeline", "PipelineModel"): # Remove the last package name "pipeline" for Pipeline and PipelineModel. java_package = ".".join(java_package.split(".")[0:-1]) - return ".".join([java_package, clazz.__name__]) + return java_package + "." + clazz.__name__ @classmethod def _load_java_obj(cls, clazz): @@ -181,7 +226,7 @@ class MLReadable(object): """ .. note:: Experimental - Mixin for instances that provide JavaMLReader. + Mixin for instances that provide :py:class:`MLReader`. .. versionadded:: 2.0.0 """ @classmethod def read(cls): """Returns a JavaMLReader instance for this class.""" - return JavaMLReader(cls) + raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls) @classmethod def load(cls, path): """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" return cls.read().load(path) + + +@inherit_doc +class JavaMLReadable(MLReadable): + """ + (Private) Mixin for instances that provide JavaMLReader. + """ + + @classmethod + def read(cls): + """Returns a JavaMLReader instance for this class.""" + return JavaMLReader(cls) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 37dcb23b67..35b0eba926 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -95,12 +95,26 @@ class JavaWrapper(Params): """ return _jvm().org.apache.spark.ml.param.ParamMap() - def _transfer_stage_to_java(self): + def _to_java(self): + """ + Transfer this instance's Params to the wrapped Java object, and return the Java object. + Used for ML persistence. + + Meta-algorithms such as Pipeline should override this method. + + :return: Java object equivalent to this instance. + """ self._transfer_params_to_java() return self._java_obj @staticmethod - def _transfer_stage_from_java(java_stage): + def _from_java(java_stage): + """ + Given a Java object, create and return a Python wrapper of it. + Used for ML persistence. + + Meta-algorithms such as Pipeline should override this method as a classmethod. + """ def __get_class(clazz): """ Loads Python class from its name.
@@ -113,13 +127,18 @@ class JavaWrapper(Params): return m stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark") # Generate a default new instance from the stage_name class. - py_stage = __get_class(stage_name)() - assert(isinstance(py_stage, JavaWrapper), - "Python side implementation is not supported in the meta-PipelineStage currently.") - # Load information from java_stage to the instance. - py_stage._java_obj = java_stage - py_stage._resetUid(java_stage.uid()) - py_stage._transfer_params_from_java() + py_type = __get_class(stage_name) + if issubclass(py_type, JavaWrapper): + # Load information from java_stage to the instance. + py_stage = py_type() + py_stage._java_obj = java_stage + py_stage._resetUid(java_stage.uid()) + py_stage._transfer_params_from_java() + elif hasattr(py_type, "_from_java"): + py_stage = py_type._from_java(java_stage) + else: + raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r" + % stage_name) return py_stage -- cgit v1.2.3
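Taken together, these changes make nested pipeline persistence work end to end from Python. A minimal usage sketch modeled on the new test_nested_pipeline_persistence unit test; the SparkContext/SQLContext setup, the /tmp paths, and the toy DataFrame are illustrative rather than part of the patch:

    from pyspark import SparkContext
    from pyspark.sql import SQLContext
    from pyspark.ml.feature import HashingTF, PCA
    from pyspark.ml.pipeline import Pipeline, PipelineModel

    sc = SparkContext(appName="nested-pipeline-io")
    sqlContext = SQLContext(sc)
    df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])

    tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
    pca = PCA(k=2, inputCol="features", outputCol="pca_features")
    inner = Pipeline(stages=[pca])
    outer = Pipeline(stages=[tf, inner])  # Pipeline[HashingTF, Pipeline[PCA]]
    model = outer.fit(df)

    # Estimator side: save() goes through PipelineMLWriter via Pipeline._to_java.
    outer.save("/tmp/nested-pipeline")
    loaded_pipeline = Pipeline.load("/tmp/nested-pipeline")
    assert loaded_pipeline.uid == outer.uid

    # Transformer side: the fitted PipelineModel (including its nested
    # PipelineModel stage) round-trips the same way through
    # PipelineModelMLWriter and PipelineModelMLReader.
    model.save("/tmp/nested-pipeline-model")
    loaded_model = PipelineModel.load("/tmp/nested-pipeline-model")
    assert [s.uid for s in loaded_model.stages] == [s.uid for s in model.stages]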