diff options
author | Xusen Yin <yinxusen@gmail.com> | 2016-04-06 11:24:11 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-04-06 11:24:11 -0700 |
commit | db0b06c6ea7412266158b1c710bdc8ca30e26430 (patch) | |
tree | 58c218ecdbe61927b7f9c3addf11b0bf245ffb2a /python/pyspark/ml/wrapper.py | |
parent | 3c8d8821654e3d82ef927c55272348e1bcc34a79 (diff) | |
download | spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.tar.gz spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.tar.bz2 spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.zip |
[SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning
## What changes were proposed in this pull request?
https://issues.apache.org/jira/browse/SPARK-13786
Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model.
## How was this patch tested?
Test with Python doctest.
Author: Xusen Yin <yinxusen@gmail.com>
Closes #12020 from yinxusen/SPARK-13786.
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r-- | python/pyspark/ml/wrapper.py | 23 |
1 files changed, 23 insertions, 0 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 35b0eba926..ca93bf7d7d 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -76,6 +76,17 @@ class JavaWrapper(Params): pair = self._make_java_param_pair(param, paramMap[param]) self._java_obj.set(pair) + def _transfer_param_map_to_java(self, pyParamMap): + """ + Transforms a Python ParamMap into a Java ParamMap. + """ + paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap") + for param in self.params: + if param in pyParamMap: + pair = self._make_java_param_pair(param, pyParamMap[param]) + paramMap.put([pair]) + return paramMap + def _transfer_params_from_java(self): """ Transforms the embedded params from the companion Java object. @@ -88,6 +99,18 @@ class JavaWrapper(Params): value = _java2py(sc, self._java_obj.getOrDefault(java_param)) self._paramMap[param] = value + def _transfer_param_map_from_java(self, javaParamMap): + """ + Transforms a Java ParamMap into a Python ParamMap. + """ + sc = SparkContext._active_spark_context + paramMap = dict() + for pair in javaParamMap.toList(): + param = pair.param() + if self.hasParam(str(param.name())): + paramMap[self.getParam(param.name())] = _java2py(sc, pair.value()) + return paramMap + @staticmethod def _empty_java_param_map(): """ |