aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-04-06 11:24:11 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-04-06 11:24:11 -0700
commitdb0b06c6ea7412266158b1c710bdc8ca30e26430 (patch)
tree58c218ecdbe61927b7f9c3addf11b0bf245ffb2a /python/pyspark/ml/wrapper.py
parent3c8d8821654e3d82ef927c55272348e1bcc34a79 (diff)
downloadspark-db0b06c6ea7412266158b1c710bdc8ca30e26430.tar.gz
spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.tar.bz2
spark-db0b06c6ea7412266158b1c710bdc8ca30e26430.zip
[SPARK-13786][ML][PYSPARK] Add save/load for pyspark.ml.tuning
## What changes were proposed in this pull request? https://issues.apache.org/jira/browse/SPARK-13786 Add save/load for Python CrossValidator/Model and TrainValidationSplit/Model. ## How was this patch tested? Test with Python doctest. Author: Xusen Yin <yinxusen@gmail.com> Closes #12020 from yinxusen/SPARK-13786.
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--python/pyspark/ml/wrapper.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 35b0eba926..ca93bf7d7d 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -76,6 +76,17 @@ class JavaWrapper(Params):
pair = self._make_java_param_pair(param, paramMap[param])
self._java_obj.set(pair)
+ def _transfer_param_map_to_java(self, pyParamMap):
+ """
+ Transforms a Python ParamMap into a Java ParamMap.
+ """
+ paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+ for param in self.params:
+ if param in pyParamMap:
+ pair = self._make_java_param_pair(param, pyParamMap[param])
+ paramMap.put([pair])
+ return paramMap
+
def _transfer_params_from_java(self):
"""
Transforms the embedded params from the companion Java object.
@@ -88,6 +99,18 @@ class JavaWrapper(Params):
value = _java2py(sc, self._java_obj.getOrDefault(java_param))
self._paramMap[param] = value
+ def _transfer_param_map_from_java(self, javaParamMap):
+ """
+ Transforms a Java ParamMap into a Python ParamMap.
+ """
+ sc = SparkContext._active_spark_context
+ paramMap = dict()
+ for pair in javaParamMap.toList():
+ param = pair.param()
+ if self.hasParam(str(param.name())):
+ paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
+ return paramMap
+
@staticmethod
def _empty_java_param_map():
"""