diff options
-rw-r--r-- | python/pyspark/ml/classification.py | 8 | ||||
-rw-r--r-- | python/pyspark/ml/clustering.py | 4 | ||||
-rw-r--r-- | python/pyspark/ml/feature.py | 89 | ||||
-rw-r--r-- | python/pyspark/ml/pipeline.py | 150 | ||||
-rw-r--r-- | python/pyspark/ml/recommendation.py | 4 | ||||
-rw-r--r-- | python/pyspark/ml/regression.py | 12 | ||||
-rw-r--r-- | python/pyspark/ml/tests.py | 82 | ||||
-rw-r--r-- | python/pyspark/ml/util.py | 89 | ||||
-rw-r--r-- | python/pyspark/ml/wrapper.py | 37 |
9 files changed, 300 insertions, 175 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py index 16ad76483d..8075108114 100644 --- a/python/pyspark/ml/classification.py +++ b/python/pyspark/ml/classification.py @@ -38,7 +38,7 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel', class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol, HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds, - HasWeightCol, MLWritable, MLReadable): + HasWeightCol, JavaMLWritable, JavaMLReadable): """ Logistic regression. Currently, this class only supports binary classification. @@ -198,7 +198,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti " threshold (%g) and thresholds (equivalent to %g)" % (t2, t)) -class LogisticRegressionModel(JavaModel, MLWritable, MLReadable): +class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by LogisticRegression. @@ -601,7 +601,7 @@ class GBTClassificationModel(TreeEnsembleModels): @inherit_doc class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol, - HasRawPredictionCol, MLWritable, MLReadable): + HasRawPredictionCol, JavaMLWritable, JavaMLReadable): """ Naive Bayes Classifiers. It supports both Multinomial and Bernoulli NB. Multinomial NB @@ -720,7 +720,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H return self.getOrDefault(self.modelType) -class NaiveBayesModel(JavaModel, MLWritable, MLReadable): +class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by NaiveBayes. diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py index 1cea477acb..2db5b82c44 100644 --- a/python/pyspark/ml/clustering.py +++ b/python/pyspark/ml/clustering.py @@ -25,7 +25,7 @@ __all__ = ['BisectingKMeans', 'BisectingKMeansModel', 'KMeans', 'KMeansModel'] -class KMeansModel(JavaModel, MLWritable, MLReadable): +class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by KMeans. @@ -48,7 +48,7 @@ class KMeansModel(JavaModel, MLWritable, MLReadable): @inherit_doc class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed, - MLWritable, MLReadable): + JavaMLWritable, JavaMLReadable): """ K-means clustering with support for multiple parallel runs and a k-means++ like initialization mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested, diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 3182faac0d..16cb9d1db3 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -22,7 +22,7 @@ if sys.version > '3': from pyspark import since from pyspark.rdd import ignore_unicode_prefix from pyspark.ml.param.shared import * -from pyspark.ml.util import keyword_only, MLReadable, MLWritable +from pyspark.ml.util import keyword_only, JavaMLReadable, JavaMLWritable from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm from pyspark.mllib.common import inherit_doc from pyspark.mllib.linalg import _convert_to_vector @@ -58,7 +58,7 @@ __all__ = ['Binarizer', @inherit_doc -class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -123,7 +123,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritab @inherit_doc -class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -198,7 +198,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWrita @inherit_doc -class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -324,7 +324,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWr return CountVectorizerModel(java_model) -class CountVectorizerModel(JavaModel, MLReadable, MLWritable): +class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -343,7 +343,7 @@ class CountVectorizerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -415,7 +415,8 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): @inherit_doc -class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -481,7 +482,8 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, @inherit_doc -class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLReadable, MLWritable): +class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -529,7 +531,7 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLRe @inherit_doc -class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -604,7 +606,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): return IDFModel(java_model) -class IDFModel(JavaModel, MLReadable, MLWritable): +class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -615,7 +617,7 @@ class IDFModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -676,7 +678,7 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrita return MaxAbsScalerModel(java_model) -class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable): +class MaxAbsScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -695,7 +697,7 @@ class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -802,7 +804,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrita return MinMaxScalerModel(java_model) -class MinMaxScalerModel(JavaModel, MLReadable, MLWritable): +class MinMaxScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -830,7 +832,7 @@ class MinMaxScalerModel(JavaModel, MLReadable, MLWritable): @inherit_doc @ignore_unicode_prefix -class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -908,7 +910,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): @inherit_doc -class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -974,7 +976,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWrita @inherit_doc -class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1056,7 +1058,8 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWr @inherit_doc -class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1123,8 +1126,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable @inherit_doc -class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLReadable, - MLWritable): +class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1213,7 +1216,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLR @inherit_doc @ignore_unicode_prefix -class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1345,7 +1348,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLW @inherit_doc -class SQLTransformer(JavaTransformer, MLReadable, MLWritable): +class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1406,7 +1409,7 @@ class SQLTransformer(JavaTransformer, MLReadable, MLWritable): @inherit_doc -class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1499,7 +1502,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWri return StandardScalerModel(java_model) -class StandardScalerModel(JavaModel, MLReadable, MLWritable): +class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1526,8 +1529,8 @@ class StandardScalerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, MLReadable, - MLWritable): +class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -1591,7 +1594,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, return StringIndexerModel(java_model) -class StringIndexerModel(JavaModel, MLReadable, MLWritable): +class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1610,7 +1613,7 @@ class StringIndexerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1664,7 +1667,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWr return self.getOrDefault(self.labels) -class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1751,7 +1754,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, M @inherit_doc @ignore_unicode_prefix -class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1806,7 +1809,7 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritab @inherit_doc -class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, MLWritable): +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1852,7 +1855,7 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, M @inherit_doc -class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1969,7 +1972,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrit return VectorIndexerModel(java_model) -class VectorIndexerModel(JavaModel, MLReadable, MLWritable): +class VectorIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -1998,7 +2001,7 @@ class VectorIndexerModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2093,7 +2096,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWri @inherit_doc @ignore_unicode_prefix class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol, - MLReadable, MLWritable): + JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2226,7 +2229,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has return Word2VecModel(java_model) -class Word2VecModel(JavaModel, MLReadable, MLWritable): +class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2257,7 +2260,7 @@ class Word2VecModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): +class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2331,7 +2334,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable): return PCAModel(java_model) -class PCAModel(JavaModel, MLReadable, MLWritable): +class PCAModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2360,7 +2363,7 @@ class PCAModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable): +class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2463,7 +2466,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritabl return RFormulaModel(java_model) -class RFormulaModel(JavaModel, MLReadable, MLWritable): +class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental @@ -2474,8 +2477,8 @@ class RFormulaModel(JavaModel, MLReadable, MLWritable): @inherit_doc -class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, MLReadable, - MLWritable): +class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable, + JavaMLWritable): """ .. note:: Experimental @@ -2561,7 +2564,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, ML return ChiSqSelectorModel(java_model) -class ChiSqSelectorModel(JavaModel, MLReadable, MLWritable): +class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable): """ .. note:: Experimental diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py index a1658b0a02..2b5504bc29 100644 --- a/python/pyspark/ml/pipeline.py +++ b/python/pyspark/ml/pipeline.py @@ -24,72 +24,31 @@ from pyspark import SparkContext from pyspark import since from pyspark.ml import Estimator, Model, Transformer from pyspark.ml.param import Param, Params -from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader +from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.common import inherit_doc -def _stages_java2py(java_stages): - """ - Transforms the parameter Python stages from a list of Java stages. - :param java_stages: An array of Java stages. - :return: An array of Python stages. - """ - - return [JavaWrapper._transfer_stage_from_java(stage) for stage in java_stages] - - -def _stages_py2java(py_stages, cls): - """ - Transforms the parameter of Python stages to a Java array of Java stages. - :param py_stages: An array of Python stages. - :return: A Java array of Java Stages. - """ - - for stage in py_stages: - assert(isinstance(stage, JavaWrapper), - "Python side implementation is not supported in the meta-PipelineStage currently.") - gateway = SparkContext._gateway - java_stages = gateway.new_array(cls, len(py_stages)) - for idx, stage in enumerate(py_stages): - java_stages[idx] = stage._transfer_stage_to_java() - return java_stages - - @inherit_doc -class PipelineMLWriter(JavaMLWriter, JavaWrapper): +class PipelineMLWriter(JavaMLWriter): """ Private Pipeline utility class that can save ML instances through their Scala implementation. - """ - def __init__(self, instance): - cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage - self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", instance.uid) - self._java_obj.setStages(_stages_py2java(instance.getStages(), cls)) - self._jwrite = self._java_obj.write() + We can currently use JavaMLWriter, rather than MLWriter, since Pipeline implements _to_java. + """ @inherit_doc class PipelineMLReader(JavaMLReader): """ Private utility class that can load Pipeline instances through their Scala implementation. - """ - def load(self, path): - """Load the Pipeline instance from the input path.""" - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - - java_obj = self._jread.load(path) - instance = self._clazz() - instance._resetUid(java_obj.uid()) - instance.setStages(_stages_java2py(java_obj.getStages())) - - return instance + We can currently use JavaMLReader, rather than MLReader, since Pipeline implements _from_java. + """ @inherit_doc -class Pipeline(Estimator): +class Pipeline(Estimator, MLReadable, MLWritable): """ A simple pipeline, which acts as an estimator. A Pipeline consists of a sequence of stages, each of which is either an @@ -206,49 +165,65 @@ class Pipeline(Estimator): @classmethod @since("2.0.0") def read(cls): - """Returns an JavaMLReader instance for this class.""" + """Returns an MLReader instance for this class.""" return PipelineMLReader(cls) @classmethod - @since("2.0.0") - def load(cls, path): - """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" - return cls.read().load(path) + def _from_java(cls, java_stage): + """ + Given a Java Pipeline, create and return a Python wrapper of it. + Used for ML persistence. + """ + # Create a new instance of this stage. + py_stage = cls() + # Load information from java_stage to the instance. + py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()] + py_stage.setStages(py_stages) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java Pipeline. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage + java_stages = gateway.new_array(cls, len(self.getStages())) + for idx, stage in enumerate(self.getStages()): + java_stages[idx] = stage._to_java() + + _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid) + _java_obj.setStages(java_stages) + + return _java_obj @inherit_doc -class PipelineModelMLWriter(JavaMLWriter, JavaWrapper): +class PipelineModelMLWriter(JavaMLWriter): """ Private PipelineModel utility class that can save ML instances through their Scala implementation. - """ - def __init__(self, instance): - cls = SparkContext._jvm.org.apache.spark.ml.Transformer - self._java_obj = self._new_java_obj("org.apache.spark.ml.PipelineModel", - instance.uid, - _stages_py2java(instance.stages, cls)) - self._jwrite = self._java_obj.write() + We can (currently) use JavaMLWriter, rather than MLWriter, since PipelineModel implements + _to_java. + """ @inherit_doc class PipelineModelMLReader(JavaMLReader): """ Private utility class that can load PipelineModel instances through their Scala implementation. - """ - def load(self, path): - """Load the PipelineModel instance from the input path.""" - if not isinstance(path, basestring): - raise TypeError("path should be a basestring, got type %s" % type(path)) - java_obj = self._jread.load(path) - instance = self._clazz(_stages_java2py(java_obj.stages())) - instance._resetUid(java_obj.uid()) - return instance + We can currently use JavaMLReader, rather than MLReader, since PipelineModel implements + _from_java. + """ @inherit_doc -class PipelineModel(Model): +class PipelineModel(Model, MLReadable, MLWritable): """ Represents a compiled pipeline with transformers and fitted models. @@ -294,7 +269,32 @@ class PipelineModel(Model): return PipelineModelMLReader(cls) @classmethod - @since("2.0.0") - def load(cls, path): - """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" - return cls.read().load(path) + def _from_java(cls, java_stage): + """ + Given a Java PipelineModel, create and return a Python wrapper of it. + Used for ML persistence. + """ + # Load information from java_stage to the instance. + py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()] + # Create a new instance of this stage. + py_stage = cls(py_stages) + py_stage._resetUid(java_stage.uid()) + return py_stage + + def _to_java(self): + """ + Transfer this instance to a Java PipelineModel. Used for ML persistence. + + :return: Java object equivalent to this instance. + """ + + gateway = SparkContext._gateway + cls = SparkContext._jvm.org.apache.spark.ml.Transformer + java_stages = gateway.new_array(cls, len(self.stages)) + for idx, stage in enumerate(self.stages): + java_stages[idx] = stage._to_java() + + _java_obj =\ + JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages) + + return _java_obj diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py index 2b605e5c50..de4c2675ed 100644 --- a/python/pyspark/ml/recommendation.py +++ b/python/pyspark/ml/recommendation.py @@ -27,7 +27,7 @@ __all__ = ['ALS', 'ALSModel'] @inherit_doc class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed, - MLWritable, MLReadable): + JavaMLWritable, JavaMLReadable): """ Alternating Least Squares (ALS) matrix factorization. @@ -289,7 +289,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha return self.getOrDefault(self.nonnegative) -class ALSModel(JavaModel, MLWritable, MLReadable): +class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by ALS. diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py index 6e23393f91..664a44bc47 100644 --- a/python/pyspark/ml/regression.py +++ b/python/pyspark/ml/regression.py @@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel', @inherit_doc class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter, HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept, - HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable): + HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable): """ Linear regression. @@ -118,7 +118,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction return LinearRegressionModel(java_model) -class LinearRegressionModel(JavaModel, MLWritable, MLReadable): +class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by LinearRegression. @@ -154,7 +154,7 @@ class LinearRegressionModel(JavaModel, MLWritable, MLReadable): @inherit_doc class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasWeightCol, MLWritable, MLReadable): + HasWeightCol, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -249,7 +249,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti return self.getOrDefault(self.featureIndex) -class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable): +class IsotonicRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ .. note:: Experimental @@ -719,7 +719,7 @@ class GBTRegressionModel(TreeEnsembleModels): @inherit_doc class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, - HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable): + HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable): """ Accelerated Failure Time (AFT) Model Survival Regression @@ -857,7 +857,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi return self.getOrDefault(self.quantilesCol) -class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable): +class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable): """ Model fitted by AFTSurvivalRegression. diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 9783ce7e77..211248e8b2 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -47,6 +47,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed from pyspark.ml.regression import LinearRegression from pyspark.ml.tuning import * from pyspark.ml.util import keyword_only +from pyspark.ml.wrapper import JavaWrapper from pyspark.mllib.linalg import DenseVector from pyspark.sql import DataFrame, SQLContext, Row from pyspark.sql.functions import rand @@ -517,7 +518,39 @@ class PersistenceTest(PySparkTestCase): except OSError: pass + def _compare_pipelines(self, m1, m2): + """ + Compare 2 ML types, asserting that they are equivalent. + This currently supports: + - basic types + - Pipeline, PipelineModel + This checks: + - uid + - type + - Param values and parents + """ + self.assertEqual(m1.uid, m2.uid) + self.assertEqual(type(m1), type(m2)) + if isinstance(m1, JavaWrapper): + self.assertEqual(len(m1.params), len(m2.params)) + for p in m1.params: + self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p)) + self.assertEqual(p.parent, m2.getParam(p.name).parent) + elif isinstance(m1, Pipeline): + self.assertEqual(len(m1.getStages()), len(m2.getStages())) + for s1, s2 in zip(m1.getStages(), m2.getStages()): + self._compare_pipelines(s1, s2) + elif isinstance(m1, PipelineModel): + self.assertEqual(len(m1.stages), len(m2.stages)) + for s1, s2 in zip(m1.stages, m2.stages): + self._compare_pipelines(s1, s2) + else: + raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1)) + def test_pipeline_persistence(self): + """ + Pipeline[HashingTF, PCA] + """ sqlContext = SQLContext(self.sc) temp_path = tempfile.mkdtemp() @@ -527,33 +560,46 @@ class PersistenceTest(PySparkTestCase): pca = PCA(k=2, inputCol="features", outputCol="pca_features") pl = Pipeline(stages=[tf, pca]) model = pl.fit(df) + pipeline_path = temp_path + "/pipeline" pl.save(pipeline_path) loaded_pipeline = Pipeline.load(pipeline_path) - self.assertEqual(loaded_pipeline.uid, pl.uid) - self.assertEqual(len(loaded_pipeline.getStages()), 2) + self._compare_pipelines(pl, loaded_pipeline) + + model_path = temp_path + "/pipeline-model" + model.save(model_path) + loaded_model = PipelineModel.load(model_path) + self._compare_pipelines(model, loaded_model) + finally: + try: + rmtree(temp_path) + except OSError: + pass - [loaded_tf, loaded_pca] = loaded_pipeline.getStages() - self.assertIsInstance(loaded_tf, HashingTF) - self.assertEqual(loaded_tf.uid, tf.uid) - param = loaded_tf.getParam("numFeatures") - self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param)) + def test_nested_pipeline_persistence(self): + """ + Pipeline[HashingTF, Pipeline[PCA]] + """ + sqlContext = SQLContext(self.sc) + temp_path = tempfile.mkdtemp() + + try: + df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"]) + tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features") + pca = PCA(k=2, inputCol="features", outputCol="pca_features") + p0 = Pipeline(stages=[pca]) + pl = Pipeline(stages=[tf, p0]) + model = pl.fit(df) - self.assertIsInstance(loaded_pca, PCA) - self.assertEqual(loaded_pca.uid, pca.uid) - self.assertEqual(loaded_pca.getK(), pca.getK()) + pipeline_path = temp_path + "/pipeline" + pl.save(pipeline_path) + loaded_pipeline = Pipeline.load(pipeline_path) + self._compare_pipelines(pl, loaded_pipeline) model_path = temp_path + "/pipeline-model" model.save(model_path) loaded_model = PipelineModel.load(model_path) - [model_tf, model_pca] = model.stages - [loaded_model_tf, loaded_model_pca] = loaded_model.stages - self.assertEqual(model_tf.uid, loaded_model_tf.uid) - self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param)) - - self.assertEqual(model_pca.uid, loaded_model_pca.uid) - self.assertEqual(model_pca.pc, loaded_model_pca.pc) - self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance) + self._compare_pipelines(model, loaded_model) finally: try: rmtree(temp_path) diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py index 42801c91bb..6703851262 100644 --- a/python/pyspark/ml/util.py +++ b/python/pyspark/ml/util.py @@ -74,18 +74,38 @@ class Identifiable(object): @inherit_doc -class JavaMLWriter(object): +class MLWriter(object): """ .. note:: Experimental - Utility class that can save ML instances through their Scala implementation. + Utility class that can save ML instances. .. versionadded:: 2.0.0 """ + def save(self, path): + """Save the ML instance to the input path.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def overwrite(self): + """Overwrites if the output path already exists.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for saving.""" + raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLWriter(MLWriter): + """ + (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types + """ + def __init__(self, instance): - instance._transfer_params_to_java() - self._jwrite = instance._java_obj.write() + super(JavaMLWriter, self).__init__() + _java_obj = instance._to_java() + self._jwrite = _java_obj.write() def save(self, path): """Save the ML instance to the input path.""" @@ -109,14 +129,14 @@ class MLWritable(object): """ .. note:: Experimental - Mixin for ML instances that provide JavaMLWriter. + Mixin for ML instances that provide :py:class:`MLWriter`. .. versionadded:: 2.0.0 """ def write(self): """Returns an JavaMLWriter instance for this ML instance.""" - return JavaMLWriter(self) + raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self)) def save(self, path): """Save this ML instance to the given path, a shortcut of `write().save(path)`.""" @@ -124,15 +144,41 @@ class MLWritable(object): @inherit_doc -class JavaMLReader(object): +class JavaMLWritable(MLWritable): + """ + (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`. + """ + + def write(self): + """Returns an JavaMLWriter instance for this ML instance.""" + return JavaMLWriter(self) + + +@inherit_doc +class MLReader(object): """ .. note:: Experimental - Utility class that can load ML instances through their Scala implementation. + Utility class that can load ML instances. .. versionadded:: 2.0.0 """ + def load(self, path): + """Load the ML instance from the input path.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + def context(self, sqlContext): + """Sets the SQL context to use for loading.""" + raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self)) + + +@inherit_doc +class JavaMLReader(MLReader): + """ + (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types + """ + def __init__(self, clazz): self._clazz = clazz self._jread = self._load_java_obj(clazz).read() @@ -142,11 +188,10 @@ class JavaMLReader(object): if not isinstance(path, basestring): raise TypeError("path should be a basestring, got type %s" % type(path)) java_obj = self._jread.load(path) - instance = self._clazz() - instance._java_obj = java_obj - instance._resetUid(java_obj.uid()) - instance._transfer_params_from_java() - return instance + if not hasattr(self._clazz, "_from_java"): + raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r" + % self._clazz) + return self._clazz._from_java(java_obj) def context(self, sqlContext): """Sets the SQL context to use for loading.""" @@ -164,7 +209,7 @@ class JavaMLReader(object): if clazz.__name__ in ("Pipeline", "PipelineModel"): # Remove the last package name "pipeline" for Pipeline and PipelineModel. java_package = ".".join(java_package.split(".")[0:-1]) - return ".".join([java_package, clazz.__name__]) + return java_package + "." + clazz.__name__ @classmethod def _load_java_obj(cls, clazz): @@ -181,7 +226,7 @@ class MLReadable(object): """ .. note:: Experimental - Mixin for instances that provide JavaMLReader. + Mixin for instances that provide :py:class:`MLReader`. .. versionadded:: 2.0.0 """ @@ -189,9 +234,21 @@ class MLReadable(object): @classmethod def read(cls): """Returns an JavaMLReader instance for this class.""" - return JavaMLReader(cls) + raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls) @classmethod def load(cls, path): """Reads an ML instance from the input path, a shortcut of `read().load(path)`.""" return cls.read().load(path) + + +@inherit_doc +class JavaMLReadable(MLReadable): + """ + (Private) Mixin for instances that provide JavaMLReader. + """ + + @classmethod + def read(cls): + """Returns an JavaMLReader instance for this class.""" + return JavaMLReader(cls) diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 37dcb23b67..35b0eba926 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -95,12 +95,26 @@ class JavaWrapper(Params): """ return _jvm().org.apache.spark.ml.param.ParamMap() - def _transfer_stage_to_java(self): + def _to_java(self): + """ + Transfer this instance's Params to the wrapped Java object, and return the Java object. + Used for ML persistence. + + Meta-algorithms such as Pipeline should override this method. + + :return: Java object equivalent to this instance. + """ self._transfer_params_to_java() return self._java_obj @staticmethod - def _transfer_stage_from_java(java_stage): + def _from_java(java_stage): + """ + Given a Java object, create and return a Python wrapper of it. + Used for ML persistence. + + Meta-algorithms such as Pipeline should override this method as a classmethod. + """ def __get_class(clazz): """ Loads Python class from its name. @@ -113,13 +127,18 @@ class JavaWrapper(Params): return m stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark") # Generate a default new instance from the stage_name class. - py_stage = __get_class(stage_name)() - assert(isinstance(py_stage, JavaWrapper), - "Python side implementation is not supported in the meta-PipelineStage currently.") - # Load information from java_stage to the instance. - py_stage._java_obj = java_stage - py_stage._resetUid(java_stage.uid()) - py_stage._transfer_params_from_java() + py_type = __get_class(stage_name) + if issubclass(py_type, JavaWrapper): + # Load information from java_stage to the instance. + py_stage = py_type() + py_stage._java_obj = java_stage + py_stage._resetUid(java_stage.uid()) + py_stage._transfer_params_from_java() + elif hasattr(py_type, "_from_java"): + py_stage = py_type._from_java(java_stage) + else: + raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r" + % stage_name) return py_stage |