author     Joseph K. Bradley <joseph@databricks.com>  2016-04-29 20:51:24 -0700
committer  Xiangrui Meng <meng@databricks.com>        2016-04-29 20:51:24 -0700
commit     09da43d514dc4487af88056404953a1f8fd8bee1 (patch)
tree       ad58fa93d6113089c5aaf772c0d37d0414715b0a /python
parent     66773eb8a55bfe6437dd4096c2c55685aca29dcd (diff)
[SPARK-13786][ML][PYTHON] Removed save/load for python tuning
## What changes were proposed in this pull request?

Per discussion on https://github.com/apache/spark/pull/12604, this removes ML persistence for Python tuning (TrainValidationSplit, CrossValidator, and their Models), since those classes do not handle nesting easily. This support should be re-designed and added in the next release.

## How was this patch tested?

Removed the unit test elements that save and load the tuning algorithms, but kept the tests that save and load their bestModel fields.

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #12782 from jkbradley/remove-python-tuning-saveload.
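The workaround the kept tests exercise is to persist the tuned `bestModel` directly instead of the validator. A minimal, self-contained sketch of that pattern (the toy data, app name, and temp path are illustrative; at this commit `Vectors` lives in `pyspark.mllib.linalg`, moving to `pyspark.ml.linalg` in later releases):

```python
import tempfile

from pyspark import SparkContext
from pyspark.ml.classification import LogisticRegression, LogisticRegressionModel
from pyspark.ml.evaluation import BinaryClassificationEvaluator
from pyspark.ml.tuning import CrossValidator, ParamGridBuilder
from pyspark.mllib.linalg import Vectors  # pyspark.ml.linalg in later releases
from pyspark.sql import SQLContext

sc = SparkContext(appName="cv-bestmodel-sketch")  # illustrative app name
sqlContext = SQLContext(sc)

# Toy, nearly separable data; illustrative only.
dataset = sqlContext.createDataFrame(
    [(Vectors.dense([0.0]), 0.0),
     (Vectors.dense([0.4]), 1.0),
     (Vectors.dense([0.5]), 0.0),
     (Vectors.dense([0.6]), 1.0),
     (Vectors.dense([1.0]), 1.0)] * 10,
    ["features", "label"])

lr = LogisticRegression()
grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
cvModel = cv.fit(dataset)

# Persist only the best submodel; the validator and its model no longer
# support save/load after this change (SPARK-13786).
path = tempfile.mkdtemp() + "/cvBestModel"
cvModel.bestModel.save(path)
loadedLrModel = LogisticRegressionModel.load(path)

sc.stop()
```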
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/tests.py   |  39
-rw-r--r--  python/pyspark/ml/tuning.py  | 244
2 files changed, 21 insertions(+), 262 deletions(-)
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index e7d4c0af45..faca148218 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -44,8 +44,7 @@ import numpy as np
from pyspark import keyword_only
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
-from pyspark.ml.classification import (
- LogisticRegression, DecisionTreeClassifier, OneVsRest, OneVsRestModel)
+from pyspark.ml.classification import *
from pyspark.ml.clustering import *
from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
@@ -540,6 +539,8 @@ class CrossValidatorTests(PySparkTestCase):
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
def test_save_load(self):
+ # This tests saving and loading the trained model only.
+ # Save/load for CrossValidator will be added later: SPARK-13786
temp_path = tempfile.mkdtemp()
sqlContext = SQLContext(self.sc)
dataset = sqlContext.createDataFrame(
@@ -554,18 +555,13 @@ class CrossValidatorTests(PySparkTestCase):
evaluator = BinaryClassificationEvaluator()
cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
cvModel = cv.fit(dataset)
- cvPath = temp_path + "/cv"
- cv.save(cvPath)
- loadedCV = CrossValidator.load(cvPath)
- self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
- self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
- self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+ lrModel = cvModel.bestModel
+
cvModelPath = temp_path + "/cvModel"
- cvModel.save(cvModelPath)
- loadedModel = CrossValidatorModel.load(cvModelPath)
- self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
- for index in range(len(loadedModel.avgMetrics)):
- self.assertTrue(abs(loadedModel.avgMetrics[index] - cvModel.avgMetrics[index]) < 0.0001)
+ lrModel.save(cvModelPath)
+ loadedLrModel = LogisticRegressionModel.load(cvModelPath)
+ self.assertEqual(loadedLrModel.uid, lrModel.uid)
+ self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
class TrainValidationSplitTests(PySparkTestCase):
@@ -619,6 +615,8 @@ class TrainValidationSplitTests(PySparkTestCase):
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
def test_save_load(self):
+ # This tests saving and loading the trained model only.
+ # Save/load for TrainValidationSplit will be added later: SPARK-13786
temp_path = tempfile.mkdtemp()
sqlContext = SQLContext(self.sc)
dataset = sqlContext.createDataFrame(
@@ -633,16 +631,13 @@ class TrainValidationSplitTests(PySparkTestCase):
evaluator = BinaryClassificationEvaluator()
tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
tvsModel = tvs.fit(dataset)
- tvsPath = temp_path + "/tvs"
- tvs.save(tvsPath)
- loadedTvs = TrainValidationSplit.load(tvsPath)
- self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
- self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
- self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+ lrModel = tvsModel.bestModel
+
tvsModelPath = temp_path + "/tvsModel"
- tvsModel.save(tvsModelPath)
- loadedModel = TrainValidationSplitModel.load(tvsModelPath)
- self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+ lrModel.save(tvsModelPath)
+ loadedLrModel = LogisticRegressionModel.load(tvsModelPath)
+ self.assertEqual(loadedLrModel.uid, lrModel.uid)
+ self.assertEqual(loadedLrModel.intercept, lrModel.intercept)
class PersistenceTest(PySparkTestCase):
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 22f9680cab..eb1f029ebb 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -23,7 +23,6 @@ from pyspark import since, keyword_only
from pyspark.ml import Estimator, Model
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
-from pyspark.ml.util import JavaMLWriter, JavaMLReader, MLReadable, MLWritable
from pyspark.ml.wrapper import JavaParams
from pyspark.sql.functions import rand
from pyspark.mllib.common import inherit_doc, _py2java
@@ -141,37 +140,8 @@ class ValidatorParams(HasSeed):
"""
return self.getOrDefault(self.evaluator)
- @classmethod
- def _from_java_impl(cls, java_stage):
- """
- Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.
- """
-
- # Load information from java_stage to the instance.
- estimator = JavaParams._from_java(java_stage.getEstimator())
- evaluator = JavaParams._from_java(java_stage.getEvaluator())
- epms = [estimator._transfer_param_map_from_java(epm)
- for epm in java_stage.getEstimatorParamMaps()]
- return estimator, epms, evaluator
-
- def _to_java_impl(self):
- """
- Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
- """
-
- gateway = SparkContext._gateway
- cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
-
- java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
- for idx, epm in enumerate(self.getEstimatorParamMaps()):
- java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
-
- java_estimator = self.getEstimator()._to_java()
- java_evaluator = self.getEvaluator()._to_java()
- return java_estimator, java_epms, java_evaluator
-
-class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
+class CrossValidator(Estimator, ValidatorParams):
"""
K-fold cross validation.
@@ -288,58 +258,8 @@ class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
newCV.setEvaluator(self.getEvaluator().copy(extra))
return newCV
- @since("2.0.0")
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
- @since("2.0.0")
- def save(self, path):
- """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
- self.write().save(path)
-
- @classmethod
- @since("2.0.0")
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
- @classmethod
- def _from_java(cls, java_stage):
- """
- Given a Java CrossValidator, create and return a Python wrapper of it.
- Used for ML persistence.
- """
-
- estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
- numFolds = java_stage.getNumFolds()
- seed = java_stage.getSeed()
- # Create a new instance of this stage.
- py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
- numFolds=numFolds, seed=seed)
- py_stage._resetUid(java_stage.uid())
- return py_stage
-
- def _to_java(self):
- """
- Transfer this instance to a Java CrossValidator. Used for ML persistence.
-
- :return: Java object equivalent to this instance.
- """
-
- estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
-
- _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
- _java_obj.setEstimatorParamMaps(epms)
- _java_obj.setEvaluator(evaluator)
- _java_obj.setEstimator(estimator)
- _java_obj.setSeed(self.getSeed())
- _java_obj.setNumFolds(self.getNumFolds())
-
- return _java_obj
-
-
-class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
+class CrossValidatorModel(Model, ValidatorParams):
"""
Model from k-fold cross validation.
@@ -372,59 +292,8 @@ class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
avgMetrics = self.avgMetrics
return CrossValidatorModel(bestModel, avgMetrics)
- @since("2.0.0")
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
- @since("2.0.0")
- def save(self, path):
- """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
- self.write().save(path)
- @classmethod
- @since("2.0.0")
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
-
- @classmethod
- def _from_java(cls, java_stage):
- """
- Given a Java CrossValidatorModel, create and return a Python wrapper of it.
- Used for ML persistence.
- """
-
- # Load information from java_stage to the instance.
- bestModel = JavaParams._from_java(java_stage.bestModel())
- avgMetrics = list(java_stage.avgMetrics())
- estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
- # Create a new instance of this stage.
- py_stage = cls(bestModel=bestModel, avgMetrics=avgMetrics)\
- .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
- py_stage._resetUid(java_stage.uid())
- return py_stage
-
- def _to_java(self):
- """
- Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.
-
- :return: Java object equivalent to this instance.
- """
-
- _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
- self.uid,
- self.bestModel._to_java(),
- self.avgMetrics)
- estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
-
- _java_obj.set("evaluator", evaluator)
- _java_obj.set("estimator", estimator)
- _java_obj.set("estimatorParamMaps", epms)
- return _java_obj
-
-
-class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
+class TrainValidationSplit(Estimator, ValidatorParams):
"""
Train-Validation-Split.
@@ -535,59 +404,8 @@ class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
newTVS.setEvaluator(self.getEvaluator().copy(extra))
return newTVS
- @since("2.0.0")
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
- @since("2.0.0")
- def save(self, path):
- """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
- self.write().save(path)
-
- @classmethod
- @since("2.0.0")
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
-
- @classmethod
- def _from_java(cls, java_stage):
- """
- Given a Java TrainValidationSplit, create and return a Python wrapper of it.
- Used for ML persistence.
- """
-
- estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
- trainRatio = java_stage.getTrainRatio()
- seed = java_stage.getSeed()
- # Create a new instance of this stage.
- py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
- trainRatio=trainRatio, seed=seed)
- py_stage._resetUid(java_stage.uid())
- return py_stage
-
- def _to_java(self):
- """
- Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.
-
- :return: Java object equivalent to this instance.
- """
-
- estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
-
- _java_obj = JavaParams._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
- self.uid)
- _java_obj.setEstimatorParamMaps(epms)
- _java_obj.setEvaluator(evaluator)
- _java_obj.setEstimator(estimator)
- _java_obj.setTrainRatio(self.getTrainRatio())
- _java_obj.setSeed(self.getSeed())
- return _java_obj
-
-
-class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
+class TrainValidationSplitModel(Model, ValidatorParams):
"""
Model from train validation split.
@@ -617,60 +435,6 @@ class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
extra = dict()
return TrainValidationSplitModel(self.bestModel.copy(extra))
- @since("2.0.0")
- def write(self):
- """Returns an MLWriter instance for this ML instance."""
- return JavaMLWriter(self)
-
- @since("2.0.0")
- def save(self, path):
- """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
- self.write().save(path)
-
- @classmethod
- @since("2.0.0")
- def read(cls):
- """Returns an MLReader instance for this class."""
- return JavaMLReader(cls)
-
- @classmethod
- def _from_java(cls, java_stage):
- """
- Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
- Used for ML persistence.
- """
-
- # Load information from java_stage to the instance.
- bestModel = JavaParams._from_java(java_stage.bestModel())
- estimator, epms, evaluator = \
- super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
- # Create a new instance of this stage.
- py_stage = cls(bestModel=bestModel)\
- .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
- py_stage._resetUid(java_stage.uid())
- return py_stage
-
- def _to_java(self):
- """
- Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.
-
- :return: Java object equivalent to this instance.
- """
-
- sc = SparkContext._active_spark_context
-
- _java_obj = JavaParams._new_java_obj(
- "org.apache.spark.ml.tuning.TrainValidationSplitModel",
- self.uid,
- self.bestModel._to_java(),
- _py2java(sc, []))
- estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
-
- _java_obj.set("evaluator", evaluator)
- _java_obj.set("estimator", estimator)
- _java_obj.set("estimatorParamMaps", epms)
- return _java_obj
-
if __name__ == "__main__":
import doctest