aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r--python/pyspark/ml/wrapper.py23
1 files changed, 23 insertions, 0 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 35b0eba926..ca93bf7d7d 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -76,6 +76,17 @@ class JavaWrapper(Params):
pair = self._make_java_param_pair(param, paramMap[param])
self._java_obj.set(pair)
+ def _transfer_param_map_to_java(self, pyParamMap):
+ """
+ Transforms a Python ParamMap into a Java ParamMap.
+ """
+ paramMap = JavaWrapper._new_java_obj("org.apache.spark.ml.param.ParamMap")
+ for param in self.params:
+ if param in pyParamMap:
+ pair = self._make_java_param_pair(param, pyParamMap[param])
+ paramMap.put([pair])
+ return paramMap
+
def _transfer_params_from_java(self):
"""
Transforms the embedded params from the companion Java object.
@@ -88,6 +99,18 @@ class JavaWrapper(Params):
value = _java2py(sc, self._java_obj.getOrDefault(java_param))
self._paramMap[param] = value
+ def _transfer_param_map_from_java(self, javaParamMap):
+ """
+ Transforms a Java ParamMap into a Python ParamMap.
+ """
+ sc = SparkContext._active_spark_context
+ paramMap = dict()
+ for pair in javaParamMap.toList():
+ param = pair.param()
+ if self.hasParam(str(param.name())):
+ paramMap[self.getParam(param.name())] = _java2py(sc, pair.value())
+ return paramMap
+
@staticmethod
def _empty_java_param_map():
"""