author     Xusen Yin <yinxusen@gmail.com>  2016-04-06 11:24:11 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2016-04-06 11:24:11 -0700
commit     db0b06c6ea7412266158b1c710bdc8ca30e26430 (patch)
tree       58c218ecdbe61927b7f9c3addf11b0bf245ffb2a /python
parent     3c8d8821654e3d82ef927c55272348e1bcc34a79 (diff)
download   spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.tar.gz
           spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.tar.bz2
           spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.zip
[SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning
## What changes were proposed in this pull request?

https://issues.apache.org/jira/browse/SPARK-13786

Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model.

## How was this patch tested?

Test with Python doctest.

Author: Xusen Yin <yinxusen@gmail.com>

Closes #12020 from yinxusen/SPARK-13786.
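For readers skimming the patch: the new Python API mirrors the Scala one, so tuning estimators and their models gain write()/save() and read()/load(). A minimal sketch of the CrossValidator round trip, adapted from the test_save_load test added below (temp_path and dataset are placeholders for a scratch directory and a labeled DataFrame, as set up in that test):

    from pyspark.ml.classification import LogisticRegression
    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import CrossValidator, CrossValidatorModel, ParamGridBuilder

    lr = LogisticRegression()
    grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
    cv = CrossValidator(estimator=lr, estimatorParamMaps=grid,
                        evaluator=BinaryClassificationEvaluator())

    cvPath = temp_path + "/cv"            # temp_path: scratch dir (assumption)
    cv.save(cvPath)                       # shortcut for cv.write().save(cvPath)
    loadedCV = CrossValidator.load(cvPath)

    cvModel = cv.fit(dataset)             # dataset: labeled DataFrame (assumption)
    cvModelPath = temp_path + "/cvModel"
    cvModel.save(cvModelPath)
    loadedModel = CrossValidatorModel.load(cvModelPath)

Persistence goes through the Java implementation (JavaMLWriter/JavaMLReader), which is why the patch also adds ParamMap transfer helpers to JavaWrapper.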
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/tests.py    56
-rw-r--r--  python/pyspark/ml/tuning.py  407
-rw-r--r--  python/pyspark/ml/wrapper.py  23
3 files changed, 375 insertions, 111 deletions
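TrainValidationSplit and TrainValidationSplitModel follow the same pattern; a sketch reusing lr, grid, temp_path, and dataset from the CrossValidator example above:

    from pyspark.ml.evaluation import BinaryClassificationEvaluator
    from pyspark.ml.tuning import TrainValidationSplit, TrainValidationSplitModel

    tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid,
                               evaluator=BinaryClassificationEvaluator())
    tvs.save(temp_path + "/tvs")
    loadedTvs = TrainValidationSplit.load(temp_path + "/tvs")

    tvsModel = tvs.fit(dataset)
    tvsModel.save(temp_path + "/tvsModel")
    loadedModel = TrainValidationSplitModel.load(temp_path + "/tvsModel")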
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index f6159b2c95..e3f873e3a7 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -44,7 +44,7 @@ import numpy as np
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
from pyspark.ml.classification import LogisticRegression, DecisionTreeClassifier
from pyspark.ml.clustering import KMeans
-from pyspark.ml.evaluation import RegressionEvaluator
+from pyspark.ml.evaluation import BinaryClassificationEvaluator, RegressionEvaluator
from pyspark.ml.feature import *
from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
@@ -53,7 +53,7 @@ from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.ml.util import MLWritable, MLWriter
from pyspark.ml.wrapper import JavaWrapper
-from pyspark.mllib.linalg import DenseVector, SparseVector
+from pyspark.mllib.linalg import Vectors, DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -479,6 +479,32 @@ class CrossValidatorTests(PySparkTestCase):
"Best model should have zero induced error")
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+ def test_save_load(self):
+ temp_path = tempfile.mkdtemp()
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ cv = CrossValidator(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ cvModel = cv.fit(dataset)
+ cvPath = temp_path + "/cv"
+ cv.save(cvPath)
+ loadedCV = CrossValidator.load(cvPath)
+ self.assertEqual(loadedCV.getEstimator().uid, cv.getEstimator().uid)
+ self.assertEqual(loadedCV.getEvaluator().uid, cv.getEvaluator().uid)
+ self.assertEqual(loadedCV.getEstimatorParamMaps(), cv.getEstimatorParamMaps())
+ cvModelPath = temp_path + "/cvModel"
+ cvModel.save(cvModelPath)
+ loadedModel = CrossValidatorModel.load(cvModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, cvModel.bestModel.uid)
+
class TrainValidationSplitTests(PySparkTestCase):
@@ -530,6 +556,32 @@ class TrainValidationSplitTests(PySparkTestCase):
"Best model should have zero induced error")
self.assertEqual(1.0, bestModelMetric, "Best model has R-squared of 1")
+ def test_save_load(self):
+ temp_path = tempfile.mkdtemp()
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame(
+ [(Vectors.dense([0.0]), 0.0),
+ (Vectors.dense([0.4]), 1.0),
+ (Vectors.dense([0.5]), 0.0),
+ (Vectors.dense([0.6]), 1.0),
+ (Vectors.dense([1.0]), 1.0)] * 10,
+ ["features", "label"])
+ lr = LogisticRegression()
+ grid = ParamGridBuilder().addGrid(lr.maxIter, [0, 1]).build()
+ evaluator = BinaryClassificationEvaluator()
+ tvs = TrainValidationSplit(estimator=lr, estimatorParamMaps=grid, evaluator=evaluator)
+ tvsModel = tvs.fit(dataset)
+ tvsPath = temp_path + "/tvs"
+ tvs.save(tvsPath)
+ loadedTvs = TrainValidationSplit.load(tvsPath)
+ self.assertEqual(loadedTvs.getEstimator().uid, tvs.getEstimator().uid)
+ self.assertEqual(loadedTvs.getEvaluator().uid, tvs.getEvaluator().uid)
+ self.assertEqual(loadedTvs.getEstimatorParamMaps(), tvs.getEstimatorParamMaps())
+ tvsModelPath = temp_path + "/tvsModel"
+ tvsModel.save(tvsModelPath)
+ loadedModel = TrainValidationSplitModel.load(tvsModelPath)
+ self.assertEqual(loadedModel.bestModel.uid, tvsModel.bestModel.uid)
+
class PersistenceTest(PySparkTestCase):
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index a528d22e18..da00f317b3 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -18,12 +18,15 @@
import itertools
import numpy as np
+from pyspark import SparkContext
from pyspark import since
from pyspark.ml import Estimator, Model
from pyspark.ml.param import Params, Param, TypeConverters
from pyspark.ml.param.shared import HasSeed
-from pyspark.ml.util import keyword_only
+from pyspark.ml.util import keyword_only, JavaMLWriter, JavaMLReader, MLReadable, MLWritable
+from pyspark.ml.wrapper import JavaWrapper
from pyspark.sql.functions import rand
+from pyspark.mllib.common import inherit_doc, _py2java
__all__ = ['ParamGridBuilder', 'CrossValidator', 'CrossValidatorModel', 'TrainValidationSplit',
'TrainValidationSplitModel']
@@ -91,7 +94,84 @@ class ParamGridBuilder(object):
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
-class CrossValidator(Estimator, HasSeed):
+class ValidatorParams(HasSeed):
+ """
+ Common params for TrainValidationSplit and CrossValidator.
+ """
+
+ estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
+ estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
+ evaluator = Param(
+ Params._dummy(), "evaluator",
+ "evaluator used to select hyper-parameters that maximize the validator metric")
+
+ def setEstimator(self, value):
+ """
+ Sets the value of :py:attr:`estimator`.
+ """
+ return self._set(estimator=value)
+
+ def getEstimator(self):
+ """
+ Gets the value of estimator or its default value.
+ """
+ return self.getOrDefault(self.estimator)
+
+ def setEstimatorParamMaps(self, value):
+ """
+ Sets the value of :py:attr:`estimatorParamMaps`.
+ """
+ return self._set(estimatorParamMaps=value)
+
+ def getEstimatorParamMaps(self):
+ """
+ Gets the value of estimatorParamMaps or its default value.
+ """
+ return self.getOrDefault(self.estimatorParamMaps)
+
+ def setEvaluator(self, value):
+ """
+ Sets the value of :py:attr:`evaluator`.
+ """
+ return self._set(evaluator=value)
+
+ def getEvaluator(self):
+ """
+ Gets the value of evaluator or its default value.
+ """
+ return self.getOrDefault(self.evaluator)
+
+ @classmethod
+ def _from_java_impl(cls, java_stage):
+ """
+ Return Python estimator, estimatorParamMaps, and evaluator from a Java ValidatorParams.
+ """
+
+ # Load information from java_stage to the instance.
+ estimator = JavaWrapper._from_java(java_stage.getEstimator())
+ evaluator = JavaWrapper._from_java(java_stage.getEvaluator())
+ epms = [estimator._transfer_param_map_from_java(epm)
+ for epm in java_stage.getEstimatorParamMaps()]
+ return estimator, epms, evaluator
+
+ def _to_java_impl(self):
+ """
+ Return Java estimator, estimatorParamMaps, and evaluator from this Python instance.
+ """
+
+ gateway = SparkContext._gateway
+ cls = SparkContext._jvm.org.apache.spark.ml.param.ParamMap
+
+ java_epms = gateway.new_array(cls, len(self.getEstimatorParamMaps()))
+ for idx, epm in enumerate(self.getEstimatorParamMaps()):
+ java_epms[idx] = self.getEstimator()._transfer_param_map_to_java(epm)
+
+ java_estimator = self.getEstimator()._to_java()
+ java_evaluator = self.getEvaluator()._to_java()
+ return java_estimator, java_epms, java_evaluator
+
+
+class CrossValidator(Estimator, ValidatorParams, MLReadable, MLWritable):
"""
K-fold cross validation.
@@ -116,11 +196,6 @@ class CrossValidator(Estimator, HasSeed):
.. versionadded:: 1.4.0
"""
- estimator = Param(Params._dummy(), "estimator", "estimator to be cross-validated")
- estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
- evaluator = Param(
- Params._dummy(), "evaluator",
- "evaluator used to select hyper-parameters that maximize the cross-validated metric")
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation",
typeConverter=TypeConverters.toInt)
@@ -149,51 +224,6 @@ class CrossValidator(Estimator, HasSeed):
return self._set(**kwargs)
@since("1.4.0")
- def setEstimator(self, value):
- """
- Sets the value of :py:attr:`estimator`.
- """
- self._paramMap[self.estimator] = value
- return self
-
- @since("1.4.0")
- def getEstimator(self):
- """
- Gets the value of estimator or its default value.
- """
- return self.getOrDefault(self.estimator)
-
- @since("1.4.0")
- def setEstimatorParamMaps(self, value):
- """
- Sets the value of :py:attr:`estimatorParamMaps`.
- """
- self._paramMap[self.estimatorParamMaps] = value
- return self
-
- @since("1.4.0")
- def getEstimatorParamMaps(self):
- """
- Gets the value of estimatorParamMaps or its default value.
- """
- return self.getOrDefault(self.estimatorParamMaps)
-
- @since("1.4.0")
- def setEvaluator(self, value):
- """
- Sets the value of :py:attr:`evaluator`.
- """
- self._paramMap[self.evaluator] = value
- return self
-
- @since("1.4.0")
- def getEvaluator(self):
- """
- Gets the value of evaluator or its default value.
- """
- return self.getOrDefault(self.evaluator)
-
- @since("1.4.0")
def setNumFolds(self, value):
"""
Sets the value of :py:attr:`numFolds`.
@@ -236,7 +266,7 @@ class CrossValidator(Estimator, HasSeed):
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return CrossValidatorModel(bestModel)
+ return self._copyValues(CrossValidatorModel(bestModel))
@since("1.4.0")
def copy(self, extra=None):
@@ -258,8 +288,58 @@ class CrossValidator(Estimator, HasSeed):
newCV.setEvaluator(self.getEvaluator().copy(extra))
return newCV
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java CrossValidator, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ estimator, epms, evaluator = super(CrossValidator, cls)._from_java_impl(java_stage)
+ numFolds = java_stage.getNumFolds()
+ seed = java_stage.getSeed()
+ # Create a new instance of this stage.
+ py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+ numFolds=numFolds, seed=seed)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
-class CrossValidatorModel(Model):
+ def _to_java(self):
+ """
+ Transfer this instance to a Java CrossValidator. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ estimator, epms, evaluator = super(CrossValidator, self)._to_java_impl()
+
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidator", self.uid)
+ _java_obj.setEstimatorParamMaps(epms)
+ _java_obj.setEvaluator(evaluator)
+ _java_obj.setEstimator(estimator)
+ _java_obj.setSeed(self.getSeed())
+ _java_obj.setNumFolds(self.getNumFolds())
+
+ return _java_obj
+
+
+class CrossValidatorModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
Model from k-fold cross validation.
@@ -289,8 +369,60 @@ class CrossValidatorModel(Model):
extra = dict()
return CrossValidatorModel(self.bestModel.copy(extra))
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
-class TrainValidationSplit(Estimator, HasSeed):
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java CrossValidatorModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ # Load information from java_stage to the instance.
+ bestModel = JavaWrapper._from_java(java_stage.bestModel())
+ estimator, epms, evaluator = super(CrossValidatorModel, cls)._from_java_impl(java_stage)
+ # Create a new instance of this stage.
+ py_stage = cls(bestModel=bestModel)\
+ .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java CrossValidatorModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ sc = SparkContext._active_spark_context
+
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.CrossValidatorModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
+ estimator, epms, evaluator = super(CrossValidatorModel, self)._to_java_impl()
+
+ _java_obj.set("evaluator", evaluator)
+ _java_obj.set("estimator", estimator)
+ _java_obj.set("estimatorParamMaps", epms)
+ return _java_obj
+
+
+class TrainValidationSplit(Estimator, ValidatorParams, MLReadable, MLWritable):
"""
Train-Validation-Split.
@@ -315,11 +447,6 @@ class TrainValidationSplit(Estimator, HasSeed):
.. versionadded:: 2.0.0
"""
- estimator = Param(Params._dummy(), "estimator", "estimator to be tested")
- estimatorParamMaps = Param(Params._dummy(), "estimatorParamMaps", "estimator param maps")
- evaluator = Param(
- Params._dummy(), "evaluator",
- "evaluator used to select hyper-parameters that maximize the validated metric")
trainRatio = Param(Params._dummy(), "trainRatio", "Param for ratio between train and\
validation data. Must be between 0 and 1.")
@@ -348,51 +475,6 @@ class TrainValidationSplit(Estimator, HasSeed):
return self._set(**kwargs)
@since("2.0.0")
- def setEstimator(self, value):
- """
- Sets the value of :py:attr:`estimator`.
- """
- self._paramMap[self.estimator] = value
- return self
-
- @since("2.0.0")
- def getEstimator(self):
- """
- Gets the value of estimator or its default value.
- """
- return self.getOrDefault(self.estimator)
-
- @since("2.0.0")
- def setEstimatorParamMaps(self, value):
- """
- Sets the value of :py:attr:`estimatorParamMaps`.
- """
- self._paramMap[self.estimatorParamMaps] = value
- return self
-
- @since("2.0.0")
- def getEstimatorParamMaps(self):
- """
- Gets the value of estimatorParamMaps or its default value.
- """
- return self.getOrDefault(self.estimatorParamMaps)
-
- @since("2.0.0")
- def setEvaluator(self, value):
- """
- Sets the value of :py:attr:`evaluator`.
- """
- self._paramMap[self.evaluator] = value
- return self
-
- @since("2.0.0")
- def getEvaluator(self):
- """
- Gets the value of evaluator or its default value.
- """
- return self.getOrDefault(self.evaluator)
-
- @since("2.0.0")
def setTrainRatio(self, value):
"""
Sets the value of :py:attr:`trainRatio`.
@@ -429,7 +511,7 @@ class TrainValidationSplit(Estimator, HasSeed):
else:
bestIndex = np.argmin(metrics)
bestModel = est.fit(dataset, epm[bestIndex])
- return TrainValidationSplitModel(bestModel)
+ return self._copyValues(TrainValidationSplitModel(bestModel))
@since("2.0.0")
def copy(self, extra=None):
@@ -451,8 +533,59 @@ class TrainValidationSplit(Estimator, HasSeed):
newTVS.setEvaluator(self.getEvaluator().copy(extra))
return newTVS
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java TrainValidationSplit, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ estimator, epms, evaluator = super(TrainValidationSplit, cls)._from_java_impl(java_stage)
+ trainRatio = java_stage.getTrainRatio()
+ seed = java_stage.getSeed()
+ # Create a new instance of this stage.
+ py_stage = cls(estimator=estimator, estimatorParamMaps=epms, evaluator=evaluator,
+ trainRatio=trainRatio, seed=seed)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java TrainValidationSplit. Used for ML persistence.
-class TrainValidationSplitModel(Model):
+ :return: Java object equivalent to this instance.
+ """
+
+ estimator, epms, evaluator = super(TrainValidationSplit, self)._to_java_impl()
+
+ _java_obj = JavaWrapper._new_java_obj("org.apache.spark.ml.tuning.TrainValidationSplit",
+ self.uid)
+ _java_obj.setEstimatorParamMaps(epms)
+ _java_obj.setEvaluator(evaluator)
+ _java_obj.setEstimator(estimator)
+ _java_obj.setTrainRatio(self.getTrainRatio())
+ _java_obj.setSeed(self.getSeed())
+
+ return _java_obj
+
+
+class TrainValidationSplitModel(Model, ValidatorParams, MLReadable, MLWritable):
"""
Model from train validation split.
"""
@@ -480,19 +613,75 @@ class TrainValidationSplitModel(Model):
extra = dict()
return TrainValidationSplitModel(self.bestModel.copy(extra))
+ @since("2.0.0")
+ def write(self):
+ """Returns an MLWriter instance for this ML instance."""
+ return JavaMLWriter(self)
+
+ @since("2.0.0")
+ def save(self, path):
+ """Save this ML instance to the given path, a shortcut of `write().save(path)`."""
+ self.write().save(path)
+
+ @classmethod
+ @since("2.0.0")
+ def read(cls):
+ """Returns an MLReader instance for this class."""
+ return JavaMLReader(cls)
+
+ @classmethod
+ def _from_java(cls, java_stage):
+ """
+ Given a Java TrainValidationSplitModel, create and return a Python wrapper of it.
+ Used for ML persistence.
+ """
+
+ # Load information from java_stage to the instance.
+ bestModel = JavaWrapper._from_java(java_stage.bestModel())
+ estimator, epms, evaluator = \
+ super(TrainValidationSplitModel, cls)._from_java_impl(java_stage)
+ # Create a new instance of this stage.
+ py_stage = cls(bestModel=bestModel)\
+ .setEstimator(estimator).setEstimatorParamMaps(epms).setEvaluator(evaluator)
+ py_stage._resetUid(java_stage.uid())
+ return py_stage
+
+ def _to_java(self):
+ """
+ Transfer this instance to a Java TrainValidationSplitModel. Used for ML persistence.
+
+ :return: Java object equivalent to this instance.
+ """
+
+ sc = SparkContext._active_spark_context
+
+ _java_obj = JavaWrapper._new_java_obj(
+ "org.apache.spark.ml.tuning.TrainValidationSplitModel",
+ self.uid,
+ self.bestModel._to_java(),
+ _py2java(sc, []))
+ estimator, epms, evaluator = super(TrainValidationSplitModel, self)._to_java_impl()
+
+ _java_obj.set("evaluator", evaluator)
+ _java_obj.set("estimator", estimator)
+ _java_obj.set("estimatorParamMaps", epms)
+ return _java_obj
+
+
if __name__ == "__main__":
import doctest
+
from pyspark.context import SparkContext
from pyspark.sql import SQLContext
globs = globals().copy()
+
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
sc = SparkContext("local[2]", "ml.tuning tests")
sqlContext = SQLContext(sc)
globs['sc'] = sc
globs['sqlContext'] = sqlContext
- (failure_count, test_count) = doctest.testmod(
- globs=globs, optionflags=doctest.ELLIPSIS)
+ (failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
sc.stop()
if failure_count:
exit(-1)
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 35b0eba926..ca93bf7d7d 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -76,6 +76,17 @@ class JavaWrapper(Params):
pair = self._make_java_param_pair(param, paramMap[param])
self._java_obj.set(pair)
+ def _transfer_param_map_to_java(self, pyParamMap):
+ """
+ Transforms a Python ParamMap into a Java ParamMap.
+ """
+ paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+ for param in self.params:
+ if param in pyParamMap:
+ pair = self._make_java_param_pair(param, pyParamMap[param])
+ paramMap.put([pair])
+ return paramMap
+
def _transfer_params_from_java(self):
"""
Transforms the embedded params from the companion Java object.
@@ -88,6 +99,18 @@ class JavaWrapper(Params):
value = _java2py(sc, self._java_obj.getOrDefault(java_param))
self._paramMap[param] = value
+ def _transfer_param_map_from_java(self, javaParamMap):
+ """
+ Transforms a Java ParamMap into a Python ParamMap.
+ """
+ sc = SparkContext._active_spark_context
+ paramMap = dict()
+ for pair in javaParamMap.toList():
+ param = pair.param()
+ if self.hasParam(str(param.name())):
+ paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
+ return paramMap
+
@staticmethod
def _empty_java_param_map():
"""