author     Joseph K. Bradley <joseph@databricks.com>    2016-03-22 12:11:23 -0700
committer  Xiangrui Meng <meng@databricks.com>          2016-03-22 12:11:37 -0700
commit     7e3423b9c03c9812d404134c3d204c4cfea87721 (patch)
tree       b922610e318774c1db7da6549ee0932b21fe3090
parent     297c20226d3330309c9165d789749458f8f4ab8e (diff)
[SPARK-13951][ML][PYTHON] Nested Pipeline persistence

Adds support for saving and loading nested ML Pipelines from Python. Pipeline and
PipelineModel do not extend JavaWrapper, but they are still able to use the
JavaMLWriter and JavaMLReader implementations.

Also:
* Separates the interfaces from the Java wrapper implementations for MLWritable,
  MLReadable, MLWriter, and MLReader.
* Moves the methods _stages_java2py and _stages_py2java into Pipeline and
  PipelineModel as _transfer_stage_from_java and _transfer_stage_to_java.

Added a new unit test for nested Pipelines, and abstracted the validity check into
a helper method shared by the two unit tests.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #11866 from jkbradley/nested-pipeline-io.

Closes #11835
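As a usage sketch of what this enables from Python (mirroring the new unit test
below; sqlContext and temp_path are assumed to exist):

    from pyspark.ml import Pipeline, PipelineModel
    from pyspark.ml.feature import HashingTF, PCA

    df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
    tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
    pca = PCA(k=2, inputCol="features", outputCol="pca_features")
    inner = Pipeline(stages=[pca])          # a Pipeline nested inside another Pipeline
    outer = Pipeline(stages=[tf, inner])
    model = outer.fit(df)

    outer.save(temp_path + "/pipeline")     # persisted via the Scala implementation
    loaded_pipeline = Pipeline.load(temp_path + "/pipeline")

    model.save(temp_path + "/pipeline-model")
    loaded_model = PipelineModel.load(temp_path + "/pipeline-model")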
-rw-r--r--python/pyspark/ml/classification.py8
-rw-r--r--python/pyspark/ml/clustering.py4
-rw-r--r--python/pyspark/ml/feature.py89
-rw-r--r--python/pyspark/ml/pipeline.py150
-rw-r--r--python/pyspark/ml/recommendation.py4
-rw-r--r--python/pyspark/ml/regression.py12
-rw-r--r--python/pyspark/ml/tests.py82
-rw-r--r--python/pyspark/ml/util.py89
-rw-r--r--python/pyspark/ml/wrapper.py37
9 files changed, 300 insertions, 175 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 16ad76483d..8075108114 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -38,7 +38,7 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel',
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
- HasWeightCol, MLWritable, MLReadable):
+ HasWeightCol, JavaMLWritable, JavaMLReadable):
"""
Logistic regression.
Currently, this class only supports binary classification.
@@ -198,7 +198,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
-class LogisticRegressionModel(JavaModel, MLWritable, MLReadable):
+class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LogisticRegression.
@@ -601,7 +601,7 @@ class GBTClassificationModel(TreeEnsembleModels):
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
- HasRawPredictionCol, MLWritable, MLReadable):
+ HasRawPredictionCol, JavaMLWritable, JavaMLReadable):
"""
Naive Bayes Classifiers.
It supports both Multinomial and Bernoulli NB. Multinomial NB
@@ -720,7 +720,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
return self.getOrDefault(self.modelType)
-class NaiveBayesModel(JavaModel, MLWritable, MLReadable):
+class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by NaiveBayes.
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 1cea477acb..2db5b82c44 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -25,7 +25,7 @@ __all__ = ['BisectingKMeans', 'BisectingKMeansModel',
'KMeans', 'KMeansModel']
-class KMeansModel(JavaModel, MLWritable, MLReadable):
+class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by KMeans.
@@ -48,7 +48,7 @@ class KMeansModel(JavaModel, MLWritable, MLReadable):
@inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
- MLWritable, MLReadable):
+ JavaMLWritable, JavaMLReadable):
"""
K-means clustering with support for multiple parallel runs and a k-means++ like initialization
mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 3182faac0d..16cb9d1db3 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -22,7 +22,7 @@ if sys.version > '3':
from pyspark import since
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import *
-from pyspark.ml.util import keyword_only, MLReadable, MLWritable
+from pyspark.ml.util import keyword_only, JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
from pyspark.mllib.common import inherit_doc
from pyspark.mllib.linalg import _convert_to_vector
@@ -58,7 +58,7 @@ __all__ = ['Binarizer',
@inherit_doc
-class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -123,7 +123,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritab
@inherit_doc
-class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -198,7 +198,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWrita
@inherit_doc
-class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -324,7 +324,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWr
return CountVectorizerModel(java_model)
-class CountVectorizerModel(JavaModel, MLReadable, MLWritable):
+class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -343,7 +343,7 @@ class CountVectorizerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -415,7 +415,8 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
@inherit_doc
-class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -481,7 +482,8 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable,
@inherit_doc
-class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLReadable, MLWritable):
+class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -529,7 +531,7 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLRe
@inherit_doc
-class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -604,7 +606,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
return IDFModel(java_model)
-class IDFModel(JavaModel, MLReadable, MLWritable):
+class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -615,7 +617,7 @@ class IDFModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -676,7 +678,7 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrita
return MaxAbsScalerModel(java_model)
-class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable):
+class MaxAbsScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -695,7 +697,7 @@ class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -802,7 +804,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrita
return MinMaxScalerModel(java_model)
-class MinMaxScalerModel(JavaModel, MLReadable, MLWritable):
+class MinMaxScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -830,7 +832,7 @@ class MinMaxScalerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
@ignore_unicode_prefix
-class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -908,7 +910,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
@inherit_doc
-class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -974,7 +976,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWrita
@inherit_doc
-class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1056,7 +1058,8 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWr
@inherit_doc
-class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -1123,8 +1126,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable
@inherit_doc
-class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLReadable,
- MLWritable):
+class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -1213,7 +1216,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLR
@inherit_doc
@ignore_unicode_prefix
-class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1345,7 +1348,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLW
@inherit_doc
-class SQLTransformer(JavaTransformer, MLReadable, MLWritable):
+class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1406,7 +1409,7 @@ class SQLTransformer(JavaTransformer, MLReadable, MLWritable):
@inherit_doc
-class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1499,7 +1502,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWri
return StandardScalerModel(java_model)
-class StandardScalerModel(JavaModel, MLReadable, MLWritable):
+class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1526,8 +1529,8 @@ class StandardScalerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, MLReadable,
- MLWritable):
+class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -1591,7 +1594,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
return StringIndexerModel(java_model)
-class StringIndexerModel(JavaModel, MLReadable, MLWritable):
+class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1610,7 +1613,7 @@ class StringIndexerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1664,7 +1667,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWr
return self.getOrDefault(self.labels)
-class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1751,7 +1754,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, M
@inherit_doc
@ignore_unicode_prefix
-class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1806,7 +1809,7 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritab
@inherit_doc
-class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, MLWritable):
+class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1852,7 +1855,7 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, M
@inherit_doc
-class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1969,7 +1972,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrit
return VectorIndexerModel(java_model)
-class VectorIndexerModel(JavaModel, MLReadable, MLWritable):
+class VectorIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1998,7 +2001,7 @@ class VectorIndexerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2093,7 +2096,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWri
@inherit_doc
@ignore_unicode_prefix
class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol,
- MLReadable, MLWritable):
+ JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2226,7 +2229,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
return Word2VecModel(java_model)
-class Word2VecModel(JavaModel, MLReadable, MLWritable):
+class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2257,7 +2260,7 @@ class Word2VecModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2331,7 +2334,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
return PCAModel(java_model)
-class PCAModel(JavaModel, MLReadable, MLWritable):
+class PCAModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2360,7 +2363,7 @@ class PCAModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable):
+class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2463,7 +2466,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritabl
return RFormulaModel(java_model)
-class RFormulaModel(JavaModel, MLReadable, MLWritable):
+class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2474,8 +2477,8 @@ class RFormulaModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, MLReadable,
- MLWritable):
+class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -2561,7 +2564,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, ML
return ChiSqSelectorModel(java_model)
-class ChiSqSelectorModel(JavaModel, MLReadable, MLWritable):
+class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a1658b0a02..2b5504bc29 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -24,72 +24,31 @@ from pyspark import SparkContext
from pyspark import since
from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
-from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.common import inherit_doc
-def _stages_java2py(java_stages):
- """
- Transforms the parameter Python stages from a list of Java stages.
- :param java_stages: An array of Java stages.
- :return: An array of Python stages.
- """
-
- return [JavaWrapper._transfer_stage_from_java(stage) for stage in java_stages]
-
-
-def _stages_py2java(py_stages, cls):
- """
- Transforms the parameter of Python stages to a Java array of Java stages.
- :param py_stages: An array of Python stages.
- :return: A Java array of Java Stages.
- """
-
- for stage in py_stages:
- assert(isinstance(stage, JavaWrapper),
- "Python side implementation is not supported in the meta-PipelineStage currently.")
- gateway = SparkContext._gateway
- java_stages = gateway.new_array(cls, len(py_stages))
- for idx, stage in enumerate(py_stages):
- java_stages[idx] = stage._transfer_stage_to_java()
- return java_stages
-
-
@inherit_doc
-class PipelineMLWriter(JavaMLWriter, JavaWrapper):
+class PipelineMLWriter(JavaMLWriter):
"""
Private Pipeline utility class that can save ML instances through their Scala implementation.
- """
- def __init__(self, instance):
- cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
- self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", instance.uid)
- self._java_obj.setStages(_stages_py2java(instance.getStages(), cls))
- self._jwrite = self._java_obj.write()
+ We can currently use JavaMLWriter, rather than MLWriter, since Pipeline implements _to_java.
+ """
@inherit_doc
class PipelineMLReader(JavaMLReader):
"""
Private utility class that can load Pipeline instances through their Scala implementation.
- """
- def load(self, path):
- """Load the Pipeline instance from the input path."""
- if not isinstance(path, basestring):
- raise TypeError("path should be a basestring, got type %s" % type(path))
-
- java_obj = self._jread.load(path)
- instance = self._clazz()
- instance._resetUid(java_obj.uid())
- instance.setStages(_stages_java2py(java_obj.getStages()))
-
- return instance
+ We can currently use JavaMLReader, rather than MLReader, since Pipeline implements _from_java.
+ """
@inherit_doc
-class Pipeline(Estimator):
+class Pipeline(Estimator, MLReadable, MLWritable):
"""
A simple pipeline, which acts as an estimator. A Pipeline consists
of a sequence of stages, each of which is either an
@@ -206,49 +165,65 @@ class Pipeline(Estimator):
@classmethod
@since("2.0.0")
def read(cls):
- """Returns an JavaMLReader instance for this class."""
+ """Returns an MLReader instance for this class."""
return PipelineMLReader(cls)
@classmethod
- @since("2.0.0")
- def load(cls, path):
- """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
- return cls.read().load(path)
+ def _from_java(cls, java_stage):
+ """
+ Given a Java Pipeline, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ # Create a new instance of this stage.
+ py_stage = cls()
+ # Load information from java_stage to the instance.
+ py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+ py_stage.setStages(py_stages)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java Pipeline. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
+ java_stages = gateway.new_array(cls, len(self.getStages()))
+ for idx, stage in enumerate(self.getStages()):
+ java_stages[idx] = stage._to_java()
+
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+ _java_obj.setStages(java_stages)
+
+ return _java_obj
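(Aside: a minimal sketch of the py4j pattern that _to_java uses above, assuming an
active SparkContext and two already-constructed stages, tf and inner:)

    from pyspark import SparkContext

    gateway = SparkContext._gateway
    jcls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
    java_stages = gateway.new_array(jcls, 2)   # a typed Java array: PipelineStage[2]
    java_stages[0] = tf._to_java()             # ordinary Java-backed stage
    java_stages[1] = inner._to_java()          # a nested Pipeline recurses here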
@inherit_doc
-class PipelineModelMLWriter(JavaMLWriter, JavaWrapper):
+class PipelineModelMLWriter(JavaMLWriter):
"""
Private PipelineModel utility class that can save ML instances through their Scala
implementation.
- """
- def __init__(self, instance):
- cls = SparkContext._jvm.org.apache.spark.ml.Transformer
- self._java_obj = self._new_java_obj("org.apache.spark.ml.PipelineModel",
- instance.uid,
- _stages_py2java(instance.stages, cls))
- self._jwrite = self._java_obj.write()
+ We can (currently) use JavaMLWriter, rather than MLWriter, since PipelineModel implements
+ _to_java.
+ """
@inherit_doc
class PipelineModelMLReader(JavaMLReader):
"""
Private utility class that can load PipelineModel instances through their Scala implementation.
- """
- def load(self, path):
- """Load the PipelineModel instance from the input path."""
- if not isinstance(path, basestring):
- raise TypeError("path should be a basestring, got type %s" % type(path))
- java_obj = self._jread.load(path)
- instance = self._clazz(_stages_java2py(java_obj.stages()))
- instance._resetUid(java_obj.uid())
- return instance
+ We can currently use JavaMLReader, rather than MLReader, since PipelineModel implements
+ _from_java.
+ """
@inherit_doc
-class PipelineModel(Model):
+class PipelineModel(Model, MLReadable, MLWritable):
"""
Represents a compiled pipeline with transformers and fitted models.
@@ -294,7 +269,32 @@ class PipelineModel(Model):
return PipelineModelMLReader(cls)
@classmethod
- @since("2.0.0")
- def load(cls, path):
- """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
- return cls.read().load(path)
+ def _from_java(cls, java_stage):
+ """
+ Given a Java PipelineModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ # Load information from java_stage to the instance.
+ py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+ # Create a new instance of this stage.
+ py_stage = cls(py_stages)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java PipelineModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.Transformer
+ java_stages = gateway.new_array(cls, len(self.stages))
+ for idx, stage in enumerate(self.stages):
+ java_stages[idx] = stage._to_java()
+
+ _java_obj =\
+ JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+
+ return _java_obj
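(For orientation, the load path that these private reader classes enable is roughly
the following sketch; the names match this diff and the path is illustrative:)

    # Call chain when loading a PipelineModel after this change:
    #   PipelineModel.load(path)                 # MLReadable.load shortcut
    #     -> PipelineModel.read()                # PipelineModelMLReader(cls)
    #       -> JavaMLReader.load(path)           # Scala-side read().load(path)
    #         -> PipelineModel._from_java(jobj)  # wraps each Java stage recursively
    loaded_model = PipelineModel.load("/tmp/pipeline-model")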
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 2b605e5c50..de4c2675ed 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -27,7 +27,7 @@ __all__ = ['ALS', 'ALSModel']
@inherit_doc
class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed,
- MLWritable, MLReadable):
+ JavaMLWritable, JavaMLReadable):
"""
Alternating Least Squares (ALS) matrix factorization.
@@ -289,7 +289,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
return self.getOrDefault(self.nonnegative)
-class ALSModel(JavaModel, MLWritable, MLReadable):
+class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by ALS.
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 6e23393f91..664a44bc47 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
- HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable):
+ HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable):
"""
Linear regression.
@@ -118,7 +118,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
return LinearRegressionModel(java_model)
-class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
+class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LinearRegression.
@@ -154,7 +154,7 @@ class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
@inherit_doc
class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- HasWeightCol, MLWritable, MLReadable):
+ HasWeightCol, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental
@@ -249,7 +249,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
return self.getOrDefault(self.featureIndex)
-class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable):
+class IsotonicRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental
@@ -719,7 +719,7 @@ class GBTRegressionModel(TreeEnsembleModels):
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable):
+ HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable):
"""
Accelerated Failure Time (AFT) Model Survival Regression
@@ -857,7 +857,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return self.getOrDefault(self.quantilesCol)
-class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable):
+class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by AFTSurvivalRegression.
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 9783ce7e77..211248e8b2 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -47,6 +47,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
+from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.linalg import DenseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
@@ -517,7 +518,39 @@ class PersistenceTest(PySparkTestCase):
except OSError:
pass
+ def _compare_pipelines(self, m1, m2):
+ """
+ Compare 2 ML types, asserting that they are equivalent.
+ This currently supports:
+ - basic types
+ - Pipeline, PipelineModel
+ This checks:
+ - uid
+ - type
+ - Param values and parents
+ """
+ self.assertEqual(m1.uid, m2.uid)
+ self.assertEqual(type(m1), type(m2))
+ if isinstance(m1, JavaWrapper):
+ self.assertEqual(len(m1.params), len(m2.params))
+ for p in m1.params:
+ self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
+ self.assertEqual(p.parent, m2.getParam(p.name).parent)
+ elif isinstance(m1, Pipeline):
+ self.assertEqual(len(m1.getStages()), len(m2.getStages()))
+ for s1, s2 in zip(m1.getStages(), m2.getStages()):
+ self._compare_pipelines(s1, s2)
+ elif isinstance(m1, PipelineModel):
+ self.assertEqual(len(m1.stages), len(m2.stages))
+ for s1, s2 in zip(m1.stages, m2.stages):
+ self._compare_pipelines(s1, s2)
+ else:
+ raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
+
def test_pipeline_persistence(self):
+ """
+ Pipeline[HashingTF, PCA]
+ """
sqlContext = SQLContext(self.sc)
temp_path = tempfile.mkdtemp()
@@ -527,33 +560,46 @@ class PersistenceTest(PySparkTestCase):
pca = PCA(k=2, inputCol="features", outputCol="pca_features")
pl = Pipeline(stages=[tf, pca])
model = pl.fit(df)
+
pipeline_path = temp_path + "/pipeline"
pl.save(pipeline_path)
loaded_pipeline = Pipeline.load(pipeline_path)
- self.assertEqual(loaded_pipeline.uid, pl.uid)
- self.assertEqual(len(loaded_pipeline.getStages()), 2)
+ self._compare_pipelines(pl, loaded_pipeline)
+
+ model_path = temp_path + "/pipeline-model"
+ model.save(model_path)
+ loaded_model = PipelineModel.load(model_path)
+ self._compare_pipelines(model, loaded_model)
+ finally:
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
- [loaded_tf, loaded_pca] = loaded_pipeline.getStages()
- self.assertIsInstance(loaded_tf, HashingTF)
- self.assertEqual(loaded_tf.uid, tf.uid)
- param = loaded_tf.getParam("numFeatures")
- self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param))
+ def test_nested_pipeline_persistence(self):
+ """
+ Pipeline[HashingTF, Pipeline[PCA]]
+ """
+ sqlContext = SQLContext(self.sc)
+ temp_path = tempfile.mkdtemp()
+
+ try:
+ df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+ tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+ pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+ p0 = Pipeline(stages=[pca])
+ pl = Pipeline(stages=[tf, p0])
+ model = pl.fit(df)
- self.assertIsInstance(loaded_pca, PCA)
- self.assertEqual(loaded_pca.uid, pca.uid)
- self.assertEqual(loaded_pca.getK(), pca.getK())
+ pipeline_path = temp_path + "/pipeline"
+ pl.save(pipeline_path)
+ loaded_pipeline = Pipeline.load(pipeline_path)
+ self._compare_pipelines(pl, loaded_pipeline)
model_path = temp_path + "/pipeline-model"
model.save(model_path)
loaded_model = PipelineModel.load(model_path)
- [model_tf, model_pca] = model.stages
- [loaded_model_tf, loaded_model_pca] = loaded_model.stages
- self.assertEqual(model_tf.uid, loaded_model_tf.uid)
- self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param))
-
- self.assertEqual(model_pca.uid, loaded_model_pca.uid)
- self.assertEqual(model_pca.pc, loaded_model_pca.pc)
- self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance)
+ self._compare_pipelines(model, loaded_model)
finally:
try:
rmtree(temp_path)
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 42801c91bb..6703851262 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -74,18 +74,38 @@ class Identifiable(object):
@inherit_doc
-class JavaMLWriter(object):
+class MLWriter(object):
"""
.. note:: Experimental
- Utility class that can save ML instances through their Scala implementation.
+ Utility class that can save ML instances.
.. versionadded:: 2.0.0
"""
+ def save(self, path):
+ """Save the ML instance to the input path."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+ def overwrite(self):
+ """Overwrites if the output path already exists."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for saving."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+
+@inherit_doc
+class JavaMLWriter(MLWriter):
+ """
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+ """
+
def __init__(self, instance):
- instance._transfer_params_to_java()
- self._jwrite = instance._java_obj.write()
+ super(JavaMLWriter, self).__init__()
+ _java_obj = instance._to_java()
+ self._jwrite = _java_obj.write()
def save(self, path):
"""Save the ML instance to the input path."""
@@ -109,14 +129,14 @@ class MLWritable(object):
"""
.. note:: Experimental
- Mixin for ML instances that provide JavaMLWriter.
+ Mixin for ML instances that provide :py:class:`MLWriter`.
.. versionadded:: 2.0.0
"""
def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
- return JavaMLWriter(self)
+ raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))
def save(self, path):
"""Save this ML instance to the given path, a shortcut of `write().save(path)`."""
@@ -124,15 +144,41 @@ class MLWritable(object):
@inherit_doc
-class JavaMLReader(object):
+class JavaMLWritable(MLWritable):
+ """
+ (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
+ """
+
+ def write(self):
+ """Returns an JavaMLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+
+@inherit_doc
+class MLReader(object):
"""
.. note:: Experimental
- Utility class that can load ML instances through their Scala implementation.
+ Utility class that can load ML instances.
.. versionadded:: 2.0.0
"""
+ def load(self, path):
+ """Load the ML instance from the input path."""
+ raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for loading."""
+ raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
+
+
+@inherit_doc
+class JavaMLReader(MLReader):
+ """
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+ """
+
def __init__(self, clazz):
self._clazz = clazz
self._jread = self._load_java_obj(clazz).read()
@@ -142,11 +188,10 @@ class JavaMLReader(object):
if not isinstance(path, basestring):
raise TypeError("path should be a basestring, got type %s" % type(path))
java_obj = self._jread.load(path)
- instance = self._clazz()
- instance._java_obj = java_obj
- instance._resetUid(java_obj.uid())
- instance._transfer_params_from_java()
- return instance
+ if not hasattr(self._clazz, "_from_java"):
+ raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"
+ % self._clazz)
+ return self._clazz._from_java(java_obj)
def context(self, sqlContext):
"""Sets the SQL context to use for loading."""
@@ -164,7 +209,7 @@ class JavaMLReader(object):
if clazz.__name__ in ("Pipeline", "PipelineModel"):
# Remove the last package name "pipeline" for Pipeline and PipelineModel.
java_package = ".".join(java_package.split(".")[0:-1])
- return ".".join([java_package, clazz.__name__])
+ return java_package + "." + clazz.__name__
@classmethod
def _load_java_obj(cls, clazz):
@@ -181,7 +226,7 @@ class MLReadable(object):
"""
.. note:: Experimental
- Mixin for instances that provide JavaMLReader.
+ Mixin for instances that provide :py:class:`MLReader`.
.. versionadded:: 2.0.0
"""
@@ -189,9 +234,21 @@ class MLReadable(object):
@classmethod
def read(cls):
"""Returns an JavaMLReader instance for this class."""
- return JavaMLReader(cls)
+ raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)
@classmethod
def load(cls, path):
"""Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
return cls.read().load(path)
+
+
+@inherit_doc
+class JavaMLReadable(MLReadable):
+ """
+ (Private) Mixin for instances that provide JavaMLReader.
+ """
+
+ @classmethod
+ def read(cls):
+ """Returns an JavaMLReader instance for this class."""
+ return JavaMLReader(cls)
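(Taken together, a Java-backed stage now only needs the two Java* mixins: the plain
MLWritable/MLReadable declare the interface, while the Java* variants supply the
py4j-backed implementation. A minimal sketch, where MyJavaStage is hypothetical and
assumes the usual JavaTransformer plumbing that sets self._java_obj:)

    class MyJavaStage(JavaTransformer, JavaMLReadable, JavaMLWritable):
        pass  # real stages construct their Java peer in __init__

    stage = MyJavaStage()
    stage.save("/tmp/my-stage")               # MLWritable.save -> JavaMLWriter.save
    same = MyJavaStage.load("/tmp/my-stage")  # MLReadable.load -> JavaMLReader.load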
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 37dcb23b67..35b0eba926 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -95,12 +95,26 @@ class JavaWrapper(Params):
"""
return _jvm().org.apache.spark.ml.param.ParamMap()
- def _transfer_stage_to_java(self):
+ def _to_java(self):
+ """
+ Transfer this instance's Params to the wrapped Java object, and return the Java object.
+ Used for ML persistence.
+
+ Meta-algorithms such as Pipeline should override this method.
+
+ :return: Java object equivalent to this instance.
+ """
self._transfer_params_to_java()
return self._java_obj
@staticmethod
- def _transfer_stage_from_java(java_stage):
+ def _from_java(java_stage):
+ """
+ Given a Java object, create and return a Python wrapper of it.
+ Used for ML persistence.
+
+ Meta-algorithms such as Pipeline should override this method as a classmethod.
+ """
def __get_class(clazz):
"""
Loads Python class from its name.
@@ -113,13 +127,18 @@ class JavaWrapper(Params):
return m
stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
# Generate a default new instance from the stage_name class.
- py_stage = __get_class(stage_name)()
- assert(isinstance(py_stage, JavaWrapper),
- "Python side implementation is not supported in the meta-PipelineStage currently.")
- # Load information from java_stage to the instance.
- py_stage._java_obj = java_stage
- py_stage._resetUid(java_stage.uid())
- py_stage._transfer_params_from_java()
+ py_type = __get_class(stage_name)
+ if issubclass(py_type, JavaWrapper):
+ # Load information from java_stage to the instance.
+ py_stage = py_type()
+ py_stage._java_obj = java_stage
+ py_stage._resetUid(java_stage.uid())
+ py_stage._transfer_params_from_java()
+ elif hasattr(py_type, "_from_java"):
+ py_stage = py_type._from_java(java_stage)
+ else:
+ raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r"
+ % stage_name)
return py_stage
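(The string substitution above is the entire Java-to-Python class-mapping
convention; a sketch with illustrative values:)

    java_name = "org.apache.spark.ml.feature.Binarizer"
    stage_name = java_name.replace("org.apache.spark", "pyspark")
    # -> "pyspark.ml.feature.Binarizer"; __get_class imports pyspark.ml.feature and
    # returns Binarizer, which subclasses JavaWrapper, so its Params are copied from
    # the Java side. Pipeline and PipelineModel instead take the
    # hasattr(py_type, "_from_java") branch and convert themselves recursively.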