-rw-r--r--  python/pyspark/ml/classification.py  |   8
-rw-r--r--  python/pyspark/ml/clustering.py      |   4
-rw-r--r--  python/pyspark/ml/feature.py         |  89
-rw-r--r--  python/pyspark/ml/pipeline.py        | 150
-rw-r--r--  python/pyspark/ml/recommendation.py  |   4
-rw-r--r--  python/pyspark/ml/regression.py      |  12
-rw-r--r--  python/pyspark/ml/tests.py           |  82
-rw-r--r--  python/pyspark/ml/util.py            |  89
-rw-r--r--  python/pyspark/ml/wrapper.py         |  37
9 files changed, 300 insertions(+), 175 deletions(-)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 16ad76483d..8075108114 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -38,7 +38,7 @@ __all__ = ['LogisticRegression', 'LogisticRegressionModel',
class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasProbabilityCol, HasRawPredictionCol,
HasElasticNetParam, HasFitIntercept, HasStandardization, HasThresholds,
- HasWeightCol, MLWritable, MLReadable):
+ HasWeightCol, JavaMLWritable, JavaMLReadable):
"""
Logistic regression.
Currently, this class only supports binary classification.
@@ -198,7 +198,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
" threshold (%g) and thresholds (equivalent to %g)" % (t2, t))
-class LogisticRegressionModel(JavaModel, MLWritable, MLReadable):
+class LogisticRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LogisticRegression.
@@ -601,7 +601,7 @@ class GBTClassificationModel(TreeEnsembleModels):
@inherit_doc
class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasProbabilityCol,
- HasRawPredictionCol, MLWritable, MLReadable):
+ HasRawPredictionCol, JavaMLWritable, JavaMLReadable):
"""
Naive Bayes Classifiers.
It supports both Multinomial and Bernoulli NB. Multinomial NB
@@ -720,7 +720,7 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
return self.getOrDefault(self.modelType)
-class NaiveBayesModel(JavaModel, MLWritable, MLReadable):
+class NaiveBayesModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by NaiveBayes.
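The rename above is mechanical, but it changes what the mixins provide: JavaMLWritable supplies write()/save() backed by the Scala writer, and JavaMLReadable supplies read()/load(). A minimal round-trip sketch, assuming a live SQLContext named sqlContext and writable paths (the paths are illustrative):

    from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
    from pyspark.mllib.linalg import Vectors

    df = sqlContext.createDataFrame(
        [(1.0, Vectors.dense(1.0)), (0.0, Vectors.dense(-1.0))],
        ["label", "features"])
    lr = LogisticRegression(maxIter=5, regParam=0.01)
    model = lr.fit(df)

    lr.save("/tmp/lr")                # estimator params, via JavaMLWritable
    model.save("/tmp/lr_model")       # fitted model, same mixin
    same_model = LogisticRegressionModel.load("/tmp/lr_model")  # via JavaMLReadable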
diff --git a/python/pyspark/ml/clustering.py b/python/pyspark/ml/clustering.py
index 1cea477acb..2db5b82c44 100644
--- a/python/pyspark/ml/clustering.py
+++ b/python/pyspark/ml/clustering.py
@@ -25,7 +25,7 @@ __all__ = ['BisectingKMeans', 'BisectingKMeansModel',
'KMeans', 'KMeansModel']
-class KMeansModel(JavaModel, MLWritable, MLReadable):
+class KMeansModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by KMeans.
@@ -48,7 +48,7 @@ class KMeansModel(JavaModel, MLWritable, MLReadable):
@inherit_doc
class KMeans(JavaEstimator, HasFeaturesCol, HasPredictionCol, HasMaxIter, HasTol, HasSeed,
- MLWritable, MLReadable):
+ JavaMLWritable, JavaMLReadable):
"""
K-means clustering with support for multiple parallel runs and a k-means++ like initialization
mode (the k-means|| algorithm by Bahmani et al). When multiple concurrent runs are requested,
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 3182faac0d..16cb9d1db3 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -22,7 +22,7 @@ if sys.version > '3':
from pyspark import since
from pyspark.rdd import ignore_unicode_prefix
from pyspark.ml.param.shared import *
-from pyspark.ml.util import keyword_only, MLReadable, MLWritable
+from pyspark.ml.util import keyword_only, JavaMLReadable, JavaMLWritable
from pyspark.ml.wrapper import JavaEstimator, JavaModel, JavaTransformer, _jvm
from pyspark.mllib.common import inherit_doc
from pyspark.mllib.linalg import _convert_to_vector
@@ -58,7 +58,7 @@ __all__ = ['Binarizer',
@inherit_doc
-class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -123,7 +123,7 @@ class Binarizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritab
@inherit_doc
-class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -198,7 +198,7 @@ class Bucketizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWrita
@inherit_doc
-class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -324,7 +324,7 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWr
return CountVectorizerModel(java_model)
-class CountVectorizerModel(JavaModel, MLReadable, MLWritable):
+class CountVectorizerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -343,7 +343,7 @@ class CountVectorizerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class DCT(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -415,7 +415,8 @@ class DCT(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
@inherit_doc
-class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -481,7 +482,8 @@ class ElementwiseProduct(JavaTransformer, HasInputCol, HasOutputCol, MLReadable,
@inherit_doc
-class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLReadable, MLWritable):
+class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -529,7 +531,7 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, MLRe
@inherit_doc
-class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -604,7 +606,7 @@ class IDF(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
return IDFModel(java_model)
-class IDFModel(JavaModel, MLReadable, MLWritable):
+class IDFModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -615,7 +617,7 @@ class IDFModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -676,7 +678,7 @@ class MaxAbsScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrita
return MaxAbsScalerModel(java_model)
-class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable):
+class MaxAbsScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -695,7 +697,7 @@ class MaxAbsScalerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -802,7 +804,7 @@ class MinMaxScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrita
return MinMaxScalerModel(java_model)
-class MinMaxScalerModel(JavaModel, MLReadable, MLWritable):
+class MinMaxScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -830,7 +832,7 @@ class MinMaxScalerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
@ignore_unicode_prefix
-class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class NGram(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -908,7 +910,7 @@ class NGram(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
@inherit_doc
-class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -974,7 +976,7 @@ class Normalizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWrita
@inherit_doc
-class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1056,7 +1058,8 @@ class OneHotEncoder(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWr
@inherit_doc
-class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -1123,8 +1126,8 @@ class PolynomialExpansion(JavaTransformer, HasInputCol, HasOutputCol, MLReadable
@inherit_doc
-class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLReadable,
- MLWritable):
+class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -1213,7 +1216,7 @@ class QuantileDiscretizer(JavaEstimator, HasInputCol, HasOutputCol, HasSeed, MLR
@inherit_doc
@ignore_unicode_prefix
-class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1345,7 +1348,7 @@ class RegexTokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLW
@inherit_doc
-class SQLTransformer(JavaTransformer, MLReadable, MLWritable):
+class SQLTransformer(JavaTransformer, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1406,7 +1409,7 @@ class SQLTransformer(JavaTransformer, MLReadable, MLWritable):
@inherit_doc
-class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1499,7 +1502,7 @@ class StandardScaler(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWri
return StandardScalerModel(java_model)
-class StandardScalerModel(JavaModel, MLReadable, MLWritable):
+class StandardScalerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1526,8 +1529,8 @@ class StandardScalerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, MLReadable,
- MLWritable):
+class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -1591,7 +1594,7 @@ class StringIndexer(JavaEstimator, HasInputCol, HasOutputCol, HasHandleInvalid,
return StringIndexerModel(java_model)
-class StringIndexerModel(JavaModel, MLReadable, MLWritable):
+class StringIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1610,7 +1613,7 @@ class StringIndexerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1664,7 +1667,7 @@ class IndexToString(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWr
return self.getOrDefault(self.labels)
-class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1751,7 +1754,7 @@ class StopWordsRemover(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, M
@inherit_doc
@ignore_unicode_prefix
-class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1806,7 +1809,7 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritab
@inherit_doc
-class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, MLWritable):
+class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1852,7 +1855,7 @@ class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol, MLReadable, M
@inherit_doc
-class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1969,7 +1972,7 @@ class VectorIndexer(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWrit
return VectorIndexerModel(java_model)
-class VectorIndexerModel(JavaModel, MLReadable, MLWritable):
+class VectorIndexerModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -1998,7 +2001,7 @@ class VectorIndexerModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2093,7 +2096,7 @@ class VectorSlicer(JavaTransformer, HasInputCol, HasOutputCol, MLReadable, MLWri
@inherit_doc
@ignore_unicode_prefix
class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, HasOutputCol,
- MLReadable, MLWritable):
+ JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2226,7 +2229,7 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
return Word2VecModel(java_model)
-class Word2VecModel(JavaModel, MLReadable, MLWritable):
+class Word2VecModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2257,7 +2260,7 @@ class Word2VecModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
+class PCA(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2331,7 +2334,7 @@ class PCA(JavaEstimator, HasInputCol, HasOutputCol, MLReadable, MLWritable):
return PCAModel(java_model)
-class PCAModel(JavaModel, MLReadable, MLWritable):
+class PCAModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2360,7 +2363,7 @@ class PCAModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritable):
+class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2463,7 +2466,7 @@ class RFormula(JavaEstimator, HasFeaturesCol, HasLabelCol, MLReadable, MLWritabl
return RFormulaModel(java_model)
-class RFormulaModel(JavaModel, MLReadable, MLWritable):
+class RFormulaModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
@@ -2474,8 +2477,8 @@ class RFormulaModel(JavaModel, MLReadable, MLWritable):
@inherit_doc
-class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, MLReadable,
- MLWritable):
+class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, JavaMLReadable,
+ JavaMLWritable):
"""
.. note:: Experimental
@@ -2561,7 +2564,7 @@ class ChiSqSelector(JavaEstimator, HasFeaturesCol, HasOutputCol, HasLabelCol, ML
return ChiSqSelectorModel(java_model)
-class ChiSqSelectorModel(JavaModel, MLReadable, MLWritable):
+class ChiSqSelectorModel(JavaModel, JavaMLReadable, JavaMLWritable):
"""
.. note:: Experimental
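The same pattern applies to every transformer and estimator touched in this file. One detail worth noting: the writer declares overwrite() (see the MLWriter interface in util.py below), so re-saving to an existing path need not fail. A hedged sketch with Binarizer, paths again illustrative:

    from pyspark.ml.feature import Binarizer

    binarizer = Binarizer(threshold=0.5, inputCol="values", outputCol="binarized")
    binarizer.write().overwrite().save("/tmp/binarizer")
    loaded = Binarizer.load("/tmp/binarizer")
    assert loaded.getThreshold() == binarizer.getThreshold()
    assert loaded.uid == binarizer.uid  # the uid survives the round trip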
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index a1658b0a02..2b5504bc29 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -24,72 +24,31 @@ from pyspark import SparkContext
from pyspark import since
from pyspark.ml import Estimator, Model, Transformer
from pyspark.ml.param import Param, Params
-from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.common import inherit_doc
-def _stages_java2py(java_stages):
- """
- Transforms the parameter Python stages from a list of Java stages.
- :param java_stages: An array of Java stages.
- :return: An array of Python stages.
- """
-
- return [JavaWrapper._transfer_stage_from_java(stage) for stage in java_stages]
-
-
-def _stages_py2java(py_stages, cls):
- """
- Transforms the parameter of Python stages to a Java array of Java stages.
- :param py_stages: An array of Python stages.
- :return: A Java array of Java Stages.
- """
-
- for stage in py_stages:
- assert(isinstance(stage, JavaWrapper),
- "Python side implementation is not supported in the meta-PipelineStage currently.")
- gateway = SparkContext._gateway
- java_stages = gateway.new_array(cls, len(py_stages))
- for idx, stage in enumerate(py_stages):
- java_stages[idx] = stage._transfer_stage_to_java()
- return java_stages
-
-
@inherit_doc
-class PipelineMLWriter(JavaMLWriter, JavaWrapper):
+class PipelineMLWriter(JavaMLWriter):
"""
Private Pipeline utility class that can save ML instances through their Scala implementation.
- """
- def __init__(self, instance):
- cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
- self._java_obj = self._new_java_obj("org.apache.spark.ml.Pipeline", instance.uid)
- self._java_obj.setStages(_stages_py2java(instance.getStages(), cls))
- self._jwrite = self._java_obj.write()
+ We can currently use JavaMLWriter, rather than MLWriter, since Pipeline implements _to_java.
+ """
@inherit_doc
class PipelineMLReader(JavaMLReader):
"""
Private utility class that can load Pipeline instances through their Scala implementation.
- """
- def load(self, path):
- """Load the Pipeline instance from the input path."""
- if not isinstance(path, basestring):
- raise TypeError("path should be a basestring, got type %s" % type(path))
-
- java_obj = self._jread.load(path)
- instance = self._clazz()
- instance._resetUid(java_obj.uid())
- instance.setStages(_stages_java2py(java_obj.getStages()))
-
- return instance
+ We can currently use JavaMLReader, rather than MLReader, since Pipeline implements _from_java.
+ """
@inherit_doc
-class Pipeline(Estimator):
+class Pipeline(Estimator, MLReadable, MLWritable):
"""
A simple pipeline, which acts as an estimator. A Pipeline consists
of a sequence of stages, each of which is either an
@@ -206,49 +165,65 @@ class Pipeline(Estimator):
@classmethod
@since("2.0.0")
def read(cls):
- """Returns an JavaMLReader instance for this class."""
+ """Returns an MLReader instance for this class."""
return PipelineMLReader(cls)
@classmethod
- @since("2.0.0")
- def load(cls, path):
- """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
- return cls.read().load(path)
+ def _from_java(cls, java_stage):
+ """
+ Given a Java Pipeline, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ # Create a new instance of this stage.
+ py_stage = cls()
+ # Load information from java_stage to the instance.
+ py_stages = [JavaWrapper._from_java(s) for s in java_stage.getStages()]
+ py_stage.setStages(py_stages)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java Pipeline. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.PipelineStage
+ java_stages = gateway.new_array(cls, len(self.getStages()))
+ for idx, stage in enumerate(self.getStages()):
+ java_stages[idx] = stage._to_java()
+
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.Pipeline", self.uid)
+ _java_obj.setStages(java_stages)
+
+ return _java_obj
@inherit_doc
-class PipelineModelMLWriter(JavaMLWriter, JavaWrapper):
+class PipelineModelMLWriter(JavaMLWriter):
"""
Private PipelineModel utility class that can save ML instances through their Scala
implementation.
- """
- def __init__(self, instance):
- cls = SparkContext._jvm.org.apache.spark.ml.Transformer
- self._java_obj = self._new_java_obj("org.apache.spark.ml.PipelineModel",
- instance.uid,
- _stages_py2java(instance.stages, cls))
- self._jwrite = self._java_obj.write()
+ We can (currently) use JavaMLWriter, rather than MLWriter, since PipelineModel implements
+ _to_java.
+ """
@inherit_doc
class PipelineModelMLReader(JavaMLReader):
"""
Private utility class that can load PipelineModel instances through their Scala implementation.
- """
- def load(self, path):
- """Load the PipelineModel instance from the input path."""
- if not isinstance(path, basestring):
- raise TypeError("path should be a basestring, got type %s" % type(path))
- java_obj = self._jread.load(path)
- instance = self._clazz(_stages_java2py(java_obj.stages()))
- instance._resetUid(java_obj.uid())
- return instance
+ We can currently use JavaMLReader, rather than MLReader, since PipelineModel implements
+ _from_java.
+ """
@inherit_doc
-class PipelineModel(Model):
+class PipelineModel(Model, MLReadable, MLWritable):
"""
Represents a compiled pipeline with transformers and fitted models.
@@ -294,7 +269,32 @@ class PipelineModel(Model):
return PipelineModelMLReader(cls)
@classmethod
- @since("2.0.0")
- def load(cls, path):
- """Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
- return cls.read().load(path)
+ def _from_java(cls, java_stage):
+ """
+ Given a Java PipelineModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+ # Load information from java_stage to the instance.
+ py_stages = [JavaWrapper._from_java(s) for s in java_stage.stages()]
+ # Create a new instance of this stage.
+ py_stage = cls(py_stages)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java PipelineModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.Transformer
+ java_stages = gateway.new_array(cls, len(self.stages))
+ for idx, stage in enumerate(self.stages):
+ java_stages[idx] = stage._to_java()
+
+ _java_obj =\
+ JavaWrapper._new_java_obj("org.apache.spark.ml.PipelineModel", self.uid, java_stages)
+
+ return _java_obj
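With _to_java/_from_java in place, Pipeline and PipelineModel persist through the generic JavaMLWriter/JavaMLReader path even though they are pure-Python classes. A sketch of the round trip this enables, mirroring the tests below (all stages must be Java-backed or implement the same hooks; paths illustrative):

    from pyspark.ml import Pipeline
    from pyspark.ml.feature import HashingTF, PCA

    tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
    pca = PCA(k=2, inputCol="features", outputCol="pca_features")
    pipeline = Pipeline(stages=[tf, pca])

    pipeline.save("/tmp/pipeline")           # Pipeline._to_java builds the Java side
    loaded = Pipeline.load("/tmp/pipeline")  # Pipeline._from_java rebuilds each stage
    assert [s.uid for s in loaded.getStages()] == [tf.uid, pca.uid]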
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 2b605e5c50..de4c2675ed 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -27,7 +27,7 @@ __all__ = ['ALS', 'ALSModel']
@inherit_doc
class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, HasRegParam, HasSeed,
- MLWritable, MLReadable):
+ JavaMLWritable, JavaMLReadable):
"""
Alternating Least Squares (ALS) matrix factorization.
@@ -289,7 +289,7 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
return self.getOrDefault(self.nonnegative)
-class ALSModel(JavaModel, MLWritable, MLReadable):
+class ALSModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by ALS.
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 6e23393f91..664a44bc47 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -35,7 +35,7 @@ __all__ = ['AFTSurvivalRegression', 'AFTSurvivalRegressionModel',
@inherit_doc
class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, HasMaxIter,
HasRegParam, HasTol, HasElasticNetParam, HasFitIntercept,
- HasStandardization, HasSolver, HasWeightCol, MLWritable, MLReadable):
+ HasStandardization, HasSolver, HasWeightCol, JavaMLWritable, JavaMLReadable):
"""
Linear regression.
@@ -118,7 +118,7 @@ class LinearRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPrediction
return LinearRegressionModel(java_model)
-class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
+class LinearRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by LinearRegression.
@@ -154,7 +154,7 @@ class LinearRegressionModel(JavaModel, MLWritable, MLReadable):
@inherit_doc
class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- HasWeightCol, MLWritable, MLReadable):
+ HasWeightCol, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental
@@ -249,7 +249,7 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
return self.getOrDefault(self.featureIndex)
-class IsotonicRegressionModel(JavaModel, MLWritable, MLReadable):
+class IsotonicRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
.. note:: Experimental
@@ -719,7 +719,7 @@ class GBTRegressionModel(TreeEnsembleModels):
@inherit_doc
class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
- HasFitIntercept, HasMaxIter, HasTol, MLWritable, MLReadable):
+ HasFitIntercept, HasMaxIter, HasTol, JavaMLWritable, JavaMLReadable):
"""
Accelerated Failure Time (AFT) Model Survival Regression
@@ -857,7 +857,7 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
return self.getOrDefault(self.quantilesCol)
-class AFTSurvivalRegressionModel(JavaModel, MLWritable, MLReadable):
+class AFTSurvivalRegressionModel(JavaModel, JavaMLWritable, JavaMLReadable):
"""
Model fitted by AFTSurvivalRegression.
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 9783ce7e77..211248e8b2 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -47,6 +47,7 @@ from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
+from pyspark.ml.wrapper import JavaWrapper
from pyspark.mllib.linalg import DenseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
@@ -517,7 +518,39 @@ class PersistenceTest(PySparkTestCase):
except OSError:
pass
+ def _compare_pipelines(self, m1, m2):
+ """
+ Compare 2 ML types, asserting that they are equivalent.
+ This currently supports:
+ - basic types
+ - Pipeline, PipelineModel
+ This checks:
+ - uid
+ - type
+ - Param values and parents
+ """
+ self.assertEqual(m1.uid, m2.uid)
+ self.assertEqual(type(m1), type(m2))
+ if isinstance(m1, JavaWrapper):
+ self.assertEqual(len(m1.params), len(m2.params))
+ for p in m1.params:
+ self.assertEqual(m1.getOrDefault(p), m2.getOrDefault(p))
+ self.assertEqual(p.parent, m2.getParam(p.name).parent)
+ elif isinstance(m1, Pipeline):
+ self.assertEqual(len(m1.getStages()), len(m2.getStages()))
+ for s1, s2 in zip(m1.getStages(), m2.getStages()):
+ self._compare_pipelines(s1, s2)
+ elif isinstance(m1, PipelineModel):
+ self.assertEqual(len(m1.stages), len(m2.stages))
+ for s1, s2 in zip(m1.stages, m2.stages):
+ self._compare_pipelines(s1, s2)
+ else:
+ raise RuntimeError("_compare_pipelines does not yet support type: %s" % type(m1))
+
def test_pipeline_persistence(self):
+ """
+ Pipeline[HashingTF, PCA]
+ """
sqlContext = SQLContext(self.sc)
temp_path = tempfile.mkdtemp()
@@ -527,33 +560,46 @@ class PersistenceTest(PySparkTestCase):
pca = PCA(k=2, inputCol="features", outputCol="pca_features")
pl = Pipeline(stages=[tf, pca])
model = pl.fit(df)
+
pipeline_path = temp_path + "/pipeline"
pl.save(pipeline_path)
loaded_pipeline = Pipeline.load(pipeline_path)
- self.assertEqual(loaded_pipeline.uid, pl.uid)
- self.assertEqual(len(loaded_pipeline.getStages()), 2)
+ self._compare_pipelines(pl, loaded_pipeline)
+
+ model_path = temp_path + "/pipeline-model"
+ model.save(model_path)
+ loaded_model = PipelineModel.load(model_path)
+ self._compare_pipelines(model, loaded_model)
+ finally:
+ try:
+ rmtree(temp_path)
+ except OSError:
+ pass
- [loaded_tf, loaded_pca] = loaded_pipeline.getStages()
- self.assertIsInstance(loaded_tf, HashingTF)
- self.assertEqual(loaded_tf.uid, tf.uid)
- param = loaded_tf.getParam("numFeatures")
- self.assertEqual(loaded_tf.getOrDefault(param), tf.getOrDefault(param))
+ def test_nested_pipeline_persistence(self):
+ """
+ Pipeline[HashingTF, Pipeline[PCA]]
+ """
+ sqlContext = SQLContext(self.sc)
+ temp_path = tempfile.mkdtemp()
+
+ try:
+ df = sqlContext.createDataFrame([(["a", "b", "c"],), (["c", "d", "e"],)], ["words"])
+ tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
+ pca = PCA(k=2, inputCol="features", outputCol="pca_features")
+ p0 = Pipeline(stages=[pca])
+ pl = Pipeline(stages=[tf, p0])
+ model = pl.fit(df)
- self.assertIsInstance(loaded_pca, PCA)
- self.assertEqual(loaded_pca.uid, pca.uid)
- self.assertEqual(loaded_pca.getK(), pca.getK())
+ pipeline_path = temp_path + "/pipeline"
+ pl.save(pipeline_path)
+ loaded_pipeline = Pipeline.load(pipeline_path)
+ self._compare_pipelines(pl, loaded_pipeline)
model_path = temp_path + "/pipeline-model"
model.save(model_path)
loaded_model = PipelineModel.load(model_path)
- [model_tf, model_pca] = model.stages
- [loaded_model_tf, loaded_model_pca] = loaded_model.stages
- self.assertEqual(model_tf.uid, loaded_model_tf.uid)
- self.assertEqual(model_tf.getOrDefault(param), loaded_model_tf.getOrDefault(param))
-
- self.assertEqual(model_pca.uid, loaded_model_pca.uid)
- self.assertEqual(model_pca.pc, loaded_model_pca.pc)
- self.assertEqual(model_pca.explainedVariance, loaded_model_pca.explainedVariance)
+ self._compare_pipelines(model, loaded_model)
finally:
try:
rmtree(temp_path)
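_compare_pipelines checks each Param's parent because a Param is bound to the uid of its owning stage; a correctly loaded stage must re-bind its Params to the preserved uid. A quick illustration of the invariant being asserted:

    from pyspark.ml.feature import HashingTF

    tf = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
    assert tf.getParam("numFeatures").parent == tf.uid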
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 42801c91bb..6703851262 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -74,18 +74,38 @@ class Identifiable(object):
@inherit_doc
-class JavaMLWriter(object):
+class MLWriter(object):
"""
.. note:: Experimental
- Utility class that can save ML instances through their Scala implementation.
+ Utility class that can save ML instances.
.. versionadded:: 2.0.0
"""
+ def save(self, path):
+ """Save the ML instance to the input path."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+ def overwrite(self):
+ """Overwrites if the output path already exists."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for saving."""
+ raise NotImplementedError("MLWriter is not yet implemented for type: %s" % type(self))
+
+
+@inherit_doc
+class JavaMLWriter(MLWriter):
+ """
+ (Private) Specialization of :py:class:`MLWriter` for :py:class:`JavaWrapper` types
+ """
+
def __init__(self, instance):
- instance._transfer_params_to_java()
- self._jwrite = instance._java_obj.write()
+ super(JavaMLWriter, self).__init__()
+ _java_obj = instance._to_java()
+ self._jwrite = _java_obj.write()
def save(self, path):
"""Save the ML instance to the input path."""
@@ -109,14 +129,14 @@ class MLWritable(object):
"""
.. note:: Experimental
- Mixin for ML instances that provide JavaMLWriter.
+ Mixin for ML instances that provide :py:class:`MLWriter`.
.. versionadded:: 2.0.0
"""
def write(self):
"""Returns an JavaMLWriter instance for this ML instance."""
- return JavaMLWriter(self)
+ raise NotImplementedError("MLWritable is not yet implemented for type: %r" % type(self))
def save(self, path):
"""Save this ML instance to the given path, a shortcut of `write().save(path)`."""
@@ -124,15 +144,41 @@ class MLWritable(object):
@inherit_doc
-class JavaMLReader(object):
+class JavaMLWritable(MLWritable):
+ """
+ (Private) Mixin for ML instances that provide :py:class:`JavaMLWriter`.
+ """
+
+ def write(self):
+ """Returns an JavaMLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+
+@inherit_doc
+class MLReader(object):
"""
.. note:: Experimental
- Utility class that can load ML instances through their Scala implementation.
+ Utility class that can load ML instances.
.. versionadded:: 2.0.0
"""
+ def load(self, path):
+ """Load the ML instance from the input path."""
+ raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
+
+ def context(self, sqlContext):
+ """Sets the SQL context to use for loading."""
+ raise NotImplementedError("MLReader is not yet implemented for type: %s" % type(self))
+
+
+@inherit_doc
+class JavaMLReader(MLReader):
+ """
+ (Private) Specialization of :py:class:`MLReader` for :py:class:`JavaWrapper` types
+ """
+
def __init__(self, clazz):
self._clazz = clazz
self._jread = self._load_java_obj(clazz).read()
@@ -142,11 +188,10 @@ class JavaMLReader(object):
if not isinstance(path, basestring):
raise TypeError("path should be a basestring, got type %s" % type(path))
java_obj = self._jread.load(path)
- instance = self._clazz()
- instance._java_obj = java_obj
- instance._resetUid(java_obj.uid())
- instance._transfer_params_from_java()
- return instance
+ if not hasattr(self._clazz, "_from_java"):
+ raise NotImplementedError("This Java ML type cannot be loaded into Python currently: %r"
+ % self._clazz)
+ return self._clazz._from_java(java_obj)
def context(self, sqlContext):
"""Sets the SQL context to use for loading."""
@@ -164,7 +209,7 @@ class JavaMLReader(object):
if clazz.__name__ in ("Pipeline", "PipelineModel"):
# Remove the last package name "pipeline" for Pipeline and PipelineModel.
java_package = ".".join(java_package.split(".")[0:-1])
- return ".".join([java_package, clazz.__name__])
+ return java_package + "." + clazz.__name__
@classmethod
def _load_java_obj(cls, clazz):
@@ -181,7 +226,7 @@ class MLReadable(object):
"""
.. note:: Experimental
- Mixin for instances that provide JavaMLReader.
+ Mixin for instances that provide :py:class:`MLReader`.
.. versionadded:: 2.0.0
"""
@@ -189,9 +234,21 @@ class MLReadable(object):
@classmethod
def read(cls):
"""Returns an JavaMLReader instance for this class."""
- return JavaMLReader(cls)
+ raise NotImplementedError("MLReadable.read() not implemented for type: %r" % cls)
@classmethod
def load(cls, path):
"""Reads an ML instance from the input path, a shortcut of `read().load(path)`."""
return cls.read().load(path)
+
+
+@inherit_doc
+class JavaMLReadable(MLReadable):
+ """
+ (Private) Mixin for instances that provide JavaMLReader.
+ """
+
+ @classmethod
+ def read(cls):
+ """Returns an JavaMLReader instance for this class."""
+ return JavaMLReader(cls)
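The split into MLWriter/MLWritable (abstract) and JavaMLWriter/JavaMLWritable (py4j-backed) leaves room for stages with no Scala counterpart: they would subclass the abstract classes directly. A hypothetical sketch, not part of this patch (JSONParamWriter and PythonOnlyStage are invented names; a real stage would also mix in Params so that extractParamMap exists):

    import json
    from pyspark.ml.util import MLWriter, MLWritable

    class JSONParamWriter(MLWriter):
        # Toy writer that dumps a stage's params to JSON (illustrative only).
        def __init__(self, instance):
            self._instance = instance

        def save(self, path):
            params = {p.name: v for p, v in self._instance.extractParamMap().items()}
            with open(path, "w") as f:
                json.dump(params, f)

    class PythonOnlyStage(MLWritable):
        def write(self):
            return JSONParamWriter(self)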
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 37dcb23b67..35b0eba926 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -95,12 +95,26 @@ class JavaWrapper(Params):
"""
return _jvm().org.apache.spark.ml.param.ParamMap()
- def _transfer_stage_to_java(self):
+ def _to_java(self):
+ """
+ Transfer this instance's Params to the wrapped Java object, and return the Java object.
+ Used for ML persistence.
+
+ Meta-algorithms such as Pipeline should override this method.
+
+ :return: Java object equivalent to this instance.
+ """
self._transfer_params_to_java()
return self._java_obj
@staticmethod
- def _transfer_stage_from_java(java_stage):
+ def _from_java(java_stage):
+ """
+ Given a Java object, create and return a Python wrapper of it.
+ Used for ML persistence.
+
+ Meta-algorithms such as Pipeline should override this method as a classmethod.
+ """
def __get_class(clazz):
"""
Loads Python class from its name.
@@ -113,13 +127,18 @@ class JavaWrapper(Params):
return m
stage_name = java_stage.getClass().getName().replace("org.apache.spark", "pyspark")
# Generate a default new instance from the stage_name class.
- py_stage = __get_class(stage_name)()
- assert(isinstance(py_stage, JavaWrapper),
- "Python side implementation is not supported in the meta-PipelineStage currently.")
- # Load information from java_stage to the instance.
- py_stage._java_obj = java_stage
- py_stage._resetUid(java_stage.uid())
- py_stage._transfer_params_from_java()
+ py_type = __get_class(stage_name)
+ if issubclass(py_type, JavaWrapper):
+ # Load information from java_stage to the instance.
+ py_stage = py_type()
+ py_stage._java_obj = java_stage
+ py_stage._resetUid(java_stage.uid())
+ py_stage._transfer_params_from_java()
+ elif hasattr(py_type, "_from_java"):
+ py_stage = py_type._from_java(java_stage)
+ else:
+ raise NotImplementedError("This Java stage cannot be loaded into Python currently: %r"
+ % stage_name)
return py_stage
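Taken together, _from_java now handles three cases: a plain Java-backed stage, a pure-Python meta-stage that defines its own _from_java (Pipeline, PipelineModel), and anything else, which fails loudly instead of silently mis-loading. A sketch of the recursion on a nested pipeline, assuming java_pm is a py4j handle to a Java org.apache.spark.ml.PipelineModel:

    from pyspark.ml.wrapper import JavaWrapper

    py_model = JavaWrapper._from_java(java_pm)
    # HashingTF, PCAModel, ...: JavaWrapper subclasses, rebuilt in the first
    # branch via _transfer_params_from_java().
    # A nested PipelineModel is not a JavaWrapper, so the hasattr(py_type,
    # "_from_java") branch delegates to PipelineModel._from_java, which
    # recurses into its own stages.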