aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tuning.py
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-04-06 11:24:11 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-06 11:24:11 -0700
commitdb0b06c6ea7412266158b1c710bdc8ca30e26430 (patch)
tree58c218ecdbe61927b7f9c3addf11b0bf245ffb2a /python/pyspark/ml/tuning.py
parent3c8d8821654e3d82ef927c55272348e1bcc34a79 (diff)
downloadspark-db0b06c6ea7412266158b1c710bdc8ca30e26430.tar.gz
spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.tar.bz2
spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.zip
[SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-13786 Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model. ## How was this patch tested? Test with Python doctest. Author: Xusen Yin <yinxusen@gmail.com> Closes #12020 from yinxusen/SPARK-13786.
Diffstat (limited to 'python/pyspark/ml/tuning.py')
-rw-r--r--python/pyspark/ml/tuning.py407
1 file changed, 298 insertions, 109 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index a528d22e18..da00f317b3 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -18,12 +18,15 @@
import itertools
import numpy as np
+from pyspark import SparkContext
from pyspark import since
from pyspark.ml import Estimator, Model
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
+from pyspark.ml.wrapper import JavaWrapper
from pyspark.sql.functions import rand
+from pyspark.mllib.common import inherit_doc, _py2java
__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
'TrainValidationSplitModel']
@@ -91,7 +94,84 @@ class ParamGridBuilder(object):
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
-class CrossValidator(Estimator, HasSeed):
+class ValidatorParams(HasSeed):
+ """
+ Common params for TrainValidationSplit and CrossValidator.
+ """
+
+ estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
+ estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
+ evaluator = Param(
+ Params._dummy(), "evaluator",
+ "evaluator used to select hyper-parameters that maximize the validator metric")
+
+ def setEstimator(self, value):
+ """
+ Sets the value of :py:attr:`estimator`.
+ """
+ return self._set(estimator=value)
+
+ def getEstimator(self):
+ """
+ Gets the value of estimator or its default value.
+ """
+ return self.getOrDefault(self.estimator)
+
+ def setEstimatorParamMaps(self, value):
+ """
+ Sets the value of :py:attr:`estimatorParamMaps`.
+ """
+ return self._set(estimatorParamMaps=value)
+
+ def getEstimatorParamMaps(self):
+ """
+ Gets the value of estimatorParamMaps or its default value.
+ """
+ return self.getOrDefault(self.estimatorParamMaps)
+
+ def setEvaluator(self, value):
+ """
+ Sets the value of :py:attr:`evaluator`.
+ """
+ return self._set(evaluator=value)
+
+ def getEvaluator(self):
+ """
+ Gets the value of evaluator or its default value.
+ """
+ return self.getOrDefault(self.evaluator)
+
+ @classmethod
+ def _from_java_impl(cls, java_stage):
+ """
+ Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.
+ """
+
+ # Load information from java_stage to the instance.
+ estimator = JavaWrapper._from_java(java_stage.getEstimator())
+ evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
+ epms = [estimator._transfer_param_map_from_java(epm)
+ for epm in java_stage.getEstimatorParamMaps()]
+ return estimator, epms, evaluator
+
+ def _to_java_impl(self):
+ """
+ Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
+
+ java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
+ for idx, epm in enumerate(self.getEstimatorParamMaps()):
+ java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
+
+ java_estimator = self.getEstimator()._to_java()
+ java_evaluator = self.getEvaluator()._to_java()
+ return java_estimator, java_epms, java_evaluator
+
+
+class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
"""
K-fold cross validation.
@@ -116,11 +196,6 @@ class CrossValidator(Estimator, HasSeed):
.. versionadded:: 1.4.0
"""
- estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
- estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
- evaluator = Param(
- Params._dummy(), "evaluator",
- "evaluator used to select hyper-parameters that maximize the cross-validated metric")
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation",
typeConverter=TypeConverters.toInt)
@@ -149,51 +224,6 @@ class CrossValidator(Estimator, HasSeed):
return self._set(**kwargs)
@since("1.4.0")
- def setEstimator(self, value):
- """
- Sets the value of :py:attr:`estimator`.
- """
- self._paramMap[self.estimator] = value
- return self
-
- @since("1.4.0")
- def getEstimator(self):
- """
- Gets the value of estimator or its default value.
- """
- return self.getOrDefault(self.estimator)
-
- @since("1.4.0")
- def setEstimatorParamMaps(self, value):
- """
- Sets the value of :py:attr:`estimatorParamMaps`.
- """
- self._paramMap[self.estimatorParamMaps] = value
- return self
-
- @since("1.4.0")
- def getEstimatorParamMaps(self):
- """
- Gets the value of estimatorParamMaps or its default value.
- """
- return self.getOrDefault(self.estimatorParamMaps)
-
- @since("1.4.0")
- def setEvaluator(self, value):
- """
- Sets the value of :py:attr:`evaluator`.
- """
- self._paramMap[self.evaluator] = value
- return self
-
- @since("1.4.0")
- def getEvaluator(self):
- """
- Gets the value of evaluator or its default value.
- """
- return self.getOrDefault(self.evaluator)
-
- @since("1.4.0")
def setNumFolds(self, value):
"""
Sets the value of :py:attr:`numFolds`.
@@ -236,7 +266,7 @@ class CrossValidator(Estimator, HasSeed):
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return CrossValidatorModel(bestModel)
+ return self._copyValues(CrossValidatorModel(bestModel))
@since("1.4.0")
def copy(self, extra=None):
@@ -258,8 +288,58 @@ class CrossValidator(Estimator, HasSeed):
newCV.setEvaluator(self.getEvaluator().copy(extra))
return newCV
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java CrossValidator, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
+ numFolds = java_stage.getNumFolds()
+ seed = java_stage.getSeed()
+ # Create a new instance of this stage.
+ py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+ numFolds=numFolds, seed=seed)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
-class CrossValidatorModel(Model):
+ def _to_java(self):
+ """
+ Transfer this instance to a Java CrossValidator. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
+
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+ _java_obj.setEstimatorParamMaps(epms)
+ _java_obj.setEvaluator(evaluator)
+ _java_obj.setEstimator(estimator)
+ _java_obj.setSeed(self.getSeed())
+ _java_obj.setNumFolds(self.getNumFolds())
+
+ return _java_obj
+
+
+class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
Model from k-fold cross validation.
@@ -289,8 +369,60 @@ class CrossValidatorModel(Model):
extra = dict()
return CrossValidatorModel(self.bestModel.copy(extra))
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
-class TrainValidationSplit(Estimator, HasSeed):
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java CrossValidatorModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ # Load information from java_stage to the instance.
+ bestModel = JavaWrapper._from_java(java_stage.bestModel())
+ estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
+ # Create a new instance of this stage.
+ py_stage = cls(bestModel=bestModel)\
+ .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ sc = SparkContext._active_spark_context
+
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
+ estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
+
+ _java_obj.set("evaluator", evaluator)
+ _java_obj.set("estimator", estimator)
+ _java_obj.set("estimatorParamMaps", epms)
+ return _java_obj
+
+
+class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
"""
Train-Validation-Split.
@@ -315,11 +447,6 @@ class TrainValidationSplit(Estimator, HasSeed):
.. versionadded:: 2.0.0
"""
- estimator = Param(Params._dummy(), "estimator", "estimator to be tested")
- estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
- evaluator = Param(
- Params._dummy(), "evaluator",
- "evaluator used to select hyper-parameters that maximize the validated metric")
trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\
validation data. Must be between 0 and 1.")
@@ -348,51 +475,6 @@ class TrainValidationSplit(Estimator, HasSeed):
return self._set(**kwargs)
@since("2.0.0")
- def setEstimator(self, value):
- """
- Sets the value of :py:attr:`estimator`.
- """
- self._paramMap[self.estimator] = value
- return self
-
- @since("2.0.0")
- def getEstimator(self):
- """
- Gets the value of estimator or its default value.
- """
- return self.getOrDefault(self.estimator)
-
- @since("2.0.0")
- def setEstimatorParamMaps(self, value):
- """
- Sets the value of :py:attr:`estimatorParamMaps`.
- """
- self._paramMap[self.estimatorParamMaps] = value
- return self
-
- @since("2.0.0")
- def getEstimatorParamMaps(self):
- """
- Gets the value of estimatorParamMaps or its default value.
- """
- return self.getOrDefault(self.estimatorParamMaps)
-
- @since("2.0.0")
- def setEvaluator(self, value):
- """
- Sets the value of :py:attr:`evaluator`.
- """
- self._paramMap[self.evaluator] = value
- return self
-
- @since("2.0.0")
- def getEvaluator(self):
- """
- Gets the value of evaluator or its default value.
- """
- return self.getOrDefault(self.evaluator)
-
- @since("2.0.0")
def setTrainRatio(self, value):
"""
Sets the value of :py:attr:`trainRatio`.
@@ -429,7 +511,7 @@ class TrainValidationSplit(Estimator, HasSeed):
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return TrainValidationSplitModel(bestModel)
+ return self._copyValues(TrainValidationSplitModel(bestModel))
@since("2.0.0")
def copy(self, extra=None):
@@ -451,8 +533,59 @@ class TrainValidationSplit(Estimator, HasSeed):
newTVS.setEvaluator(self.getEvaluator().copy(extra))
return newTVS
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java TrainValidationSplit, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
+ trainRatio = java_stage.getTrainRatio()
+ seed = java_stage.getSeed()
+ # Create a new instance of this stage.
+ py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+ trainRatio=trainRatio, seed=seed)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.
-class TrainValidationSplitModel(Model):
+ :return: Java object equivalent to this instance.
+ """
+
+ estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
+
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+ self.uid)
+ _java_obj.setEstimatorParamMaps(epms)
+ _java_obj.setEvaluator(evaluator)
+ _java_obj.setEstimator(estimator)
+ _java_obj.setTrainRatio(self.getTrainRatio())
+ _java_obj.setSeed(self.getSeed())
+
+ return _java_obj
+
+
+class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
Model from train validation split.
"""
@@ -480,19 +613,75 @@ class TrainValidationSplitModel(Model):
extra = dict()
return TrainValidationSplitModel(self.bestModel.copy(extra))
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ # Load information from java_stage to the instance.
+ bestModel = JavaWrapper._from_java(java_stage.bestModel())
+ estimator, epms, evaluator = \
+ super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
+ # Create a new instance of this stage.
+ py_stage = cls(bestModel=bestModel)\
+ .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ sc = SparkContext._active_spark_context
+
+ _java_obj = JavaWrapper._new_java_obj(
+ "org.apache.spark.ml.tuning.TrainValidationSplitModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
+ estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
+
+ _java_obj.set("evaluator", evaluator)
+ _java_obj.set("estimator", estimator)
+ _java_obj.set("estimatorParamMaps", epms)
+ return _java_obj
+
+
if __name__ == "__main__":
import doctest
+
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
globs = globals().copy()
+
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.tuning tests")
sqlContext = SQLContext(sc)
globs['sc'] = sc
globs['sqlContext'] = sqlContext
- (failure_count, test_count) = doctest.testmod(
- globs=globs, optionflags=doctest.ELLIPSIS)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
if failure_count:
exit(-1)