aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/wrapper.py
diff options
context:
space:
mode:
author Burak Yavuz <brkyvz@gmail.com> 2015-05-07 10:25:41 -0700
committer Xiangrui Meng <meng@databricks.com> 2015-05-07 10:25:41 -0700
commit 9e2ffb13287e6efe256b8d23a4654e4cc305e20b (patch)
tree 79a13615578199c2907d371c965ef031307c47b9 /python/pyspark/ml/wrapper.py
parent ed9be06a4797bbb678355b361054c8872ac20b75 (diff)
download spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.tar.gz
spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.tar.bz2
spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.zip
[SPARK-7388] [SPARK-7383] wrapper for VectorAssembler in Python
The wrapper required the implementation of `ArrayParam`, because `Array[T]` is hard to obtain from Python. `ArrayParam` has an extra function called `wCast`, which is an internal function to obtain `Array[T]` from `Seq[T]`.

Author: Burak Yavuz <brkyvz@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #5930 from brkyvz/ml-feat and squashes the following commits:

73e745f [Burak Yavuz] Merge pull request #3 from mengxr/SPARK-7388
c221db9 [Xiangrui Meng] overload StringArrayParam.w
c81072d [Burak Yavuz] addressed comments
99c2ebf [Burak Yavuz] add to python_shared_params
39ecb07 [Burak Yavuz] fix scalastyle
7f7ea2a [Burak Yavuz] [SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r-- python/pyspark/ml/wrapper.py 13
1 files changed, 7 insertions, 6 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 0634254bbd..f5ac2a3986 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -67,7 +67,9 @@ class JavaWrapper(Params):
paramMap = self.extractParamMap(params)
for param in self.params:
if param in paramMap:
- java_obj.set(param.name, paramMap[param])
+ value = paramMap[param]
+ java_param = java_obj.getParam(param.name)
+ java_obj.set(java_param.w(value))
def _empty_java_param_map(self):
"""
@@ -79,7 +81,8 @@ class JavaWrapper(Params):
paramMap = self._empty_java_param_map()
for param, value in params.items():
if param.parent is self:
- paramMap.put(java_obj.getParam(param.name), value)
+ java_param = java_obj.getParam(param.name)
+ paramMap.put(java_param.w(value))
return paramMap
@@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper):
def transform(self, dataset, params={}):
java_obj = self._java_obj()
- self._transfer_params_to_java({}, java_obj)
- java_param_map = self._create_java_param_map(params, java_obj)
- return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
- dataset.sql_ctx)
+ self._transfer_params_to_java(params, java_obj)
+ return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)
@inherit_doc