From db0b06c6ea7412266158b1c710bdc8ca30e26430 Mon Sep 17 00:00:00 2001 From: Xusen Yin Date: Wed, 6 Apr 2016 11:24:11 -0700 Subject: [SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning ## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-13786 Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model. ## How was this patch tested? Test with Python doctest. Author: Xusen Yin Closes #12020 from yinxusen/SPARK-13786. --- python/pyspark/ml/wrapper.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) (limited to 'python/pyspark/ml/wrapper.py') diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 35b0eba926..ca93bf7d7d 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -76,6 +76,17 @@ class JavaWrapper(Params): pair = self._make_java_param_pair(param, paramMap[param]) self._java_obj.set(pair) + def _transfer_param_map_to_java(self, pyParamMap): + """ + Transforms a Python ParamMap into a Java ParamMap. + """ + paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap") + for param in self.params: + if param in pyParamMap: + pair = self._make_java_param_pair(param, pyParamMap[param]) + paramMap.put([pair]) + return paramMap + def _transfer_params_from_java(self): """ Transforms the embedded params from the companion Java object. @@ -88,6 +99,18 @@ class JavaWrapper(Params): value = _java2py(sc, self._java_obj.getOrDefault(java_param)) self._paramMap[param] = value + def _transfer_param_map_from_java(self, javaParamMap): + """ + Transforms a Java ParamMap into a Python ParamMap. + """ + sc = SparkContext._active_spark_context + paramMap = dict() + for pair in javaParamMap.toList(): + param = pair.param() + if self.hasParam(str(param.name())): + paramMap[self.getParam(param.name())] = _java2py(sc, pair.value()) + return paramMap + @staticmethod def _empty_java_param_map(): """ -- cgit v1.2.3