diff options
author | Burak Yavuz <brkyvz@gmail.com> | 2015-05-07 10:25:41 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-07 10:25:41 -0700 |
commit | 9e2ffb13287e6efe256b8d23a4654e4cc305e20b (patch) | |
tree | 79a13615578199c2907d371c965ef031307c47b9 /python/pyspark/ml/wrapper.py | |
parent | ed9be06a4797bbb678355b361054c8872ac20b75 (diff) | |
download | spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.tar.gz spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.tar.bz2 spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.zip |
[SPARK-7388] [SPARK-7383] wrapper for VectorAssembler in Python
The wrapper required the implementation of the `ArrayParam`, because `Array[T]` is hard to obtain from Python. `ArrayParam` has an extra function called `wCast` which is an internal function to obtain `Array[T]` from `Seq[T]`
Author: Burak Yavuz <brkyvz@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>
Closes #5930 from brkyvz/ml-feat and squashes the following commits:
73e745f [Burak Yavuz] Merge pull request #3 from mengxr/SPARK-7388
c221db9 [Xiangrui Meng] overload StringArrayParam.w
c81072d [Burak Yavuz] addressed comments
99c2ebf [Burak Yavuz] add to python_shared_params
39ecb07 [Burak Yavuz] fix scalastyle
7f7ea2a [Burak Yavuz] [SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python
Diffstat (limited to 'python/pyspark/ml/wrapper.py')
-rw-r--r-- | python/pyspark/ml/wrapper.py | 13 |
1 files changed, 7 insertions, 6 deletions
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 0634254bbd..f5ac2a3986 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -67,7 +67,9 @@ class JavaWrapper(Params): paramMap = self.extractParamMap(params) for param in self.params: if param in paramMap: - java_obj.set(param.name, paramMap[param]) + value = paramMap[param] + java_param = java_obj.getParam(param.name) + java_obj.set(java_param.w(value)) def _empty_java_param_map(self): """ @@ -79,7 +81,8 @@ class JavaWrapper(Params): paramMap = self._empty_java_param_map() for param, value in params.items(): if param.parent is self: - paramMap.put(java_obj.getParam(param.name), value) + java_param = java_obj.getParam(param.name) + paramMap.put(java_param.w(value)) return paramMap @@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper): def transform(self, dataset, params={}): java_obj = self._java_obj() - self._transfer_params_to_java({}, java_obj) - java_param_map = self._create_java_param_map(params, java_obj) - return DataFrame(java_obj.transform(dataset._jdf, java_param_map), - dataset.sql_ctx) + self._transfer_params_to_java(params, java_obj) + return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx) @inherit_doc |