diff options
author | Burak Yavuz <brkyvz@gmail.com> | 2015-05-07 10:25:41 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2015-05-07 10:25:41 -0700 |
commit | 9e2ffb13287e6efe256b8d23a4654e4cc305e20b (patch) | |
tree | 79a13615578199c2907d371c965ef031307c47b9 /python/pyspark | |
parent | ed9be06a4797bbb678355b361054c8872ac20b75 (diff) | |
download | spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.tar.gz spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.tar.bz2 spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.zip |
[SPARK-7388] [SPARK-7383] wrapper for VectorAssembler in Python
The wrapper required the implementation of `ArrayParam`, because `Array[T]` is hard to obtain from Python. `ArrayParam` has an extra function called `wCast`, which is an internal function used to obtain an `Array[T]` from a `Seq[T]`.
Author: Burak Yavuz <brkyvz@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>
Closes #5930 from brkyvz/ml-feat and squashes the following commits:
73e745f [Burak Yavuz] Merge pull request #3 from mengxr/SPARK-7388
c221db9 [Xiangrui Meng] overload StringArrayParam.w
c81072d [Burak Yavuz] addressed comments
99c2ebf [Burak Yavuz] add to python_shared_params
39ecb07 [Burak Yavuz] fix scalastyle
7f7ea2a [Burak Yavuz] [SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/ml/feature.py | 43 | ||||
-rw-r--r-- | python/pyspark/ml/param/_shared_params_code_gen.py | 1 | ||||
-rw-r--r-- | python/pyspark/ml/param/shared.py | 29 | ||||
-rw-r--r-- | python/pyspark/ml/wrapper.py | 13 |
4 files changed, 78 insertions, 8 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 4e4614b859..8a0fdddd2d 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -16,12 +16,12 @@ # from pyspark.rdd import ignore_unicode_prefix -from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures +from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures from pyspark.ml.util import keyword_only from pyspark.ml.wrapper import JavaTransformer from pyspark.mllib.common import inherit_doc -__all__ = ['Tokenizer', 'HashingTF'] +__all__ = ['Tokenizer', 'HashingTF', 'VectorAssembler'] @inherit_doc @@ -112,6 +112,45 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures): return self._set(**kwargs) +@inherit_doc +class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol): + """ + A feature transformer that merges multiple columns into a vector column. + + >>> from pyspark.sql import Row + >>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF() + >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features") + >>> vecAssembler.transform(df).head().features + SparseVector(3, {0: 1.0, 2: 3.0}) + >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs + SparseVector(3, {0: 1.0, 2: 3.0}) + >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"} + >>> vecAssembler.transform(df, params).head().vector + SparseVector(2, {1: 1.0}) + """ + + _java_class = "org.apache.spark.ml.feature.VectorAssembler" + + @keyword_only + def __init__(self, inputCols=None, outputCol=None): + """ + __init__(self, inputCols=None, outputCol=None) + """ + super(VectorAssembler, self).__init__() + self._setDefault() + kwargs = self.__init__._input_kwargs + self.setParams(**kwargs) + + @keyword_only + def setParams(self, inputCols=None, outputCol=None): + """ + setParams(self, inputCols=None, outputCol=None) + Sets params for this VectorAssembler. 
+ """ + kwargs = self.setParams._input_kwargs + return self._set(**kwargs) + + if __name__ == "__main__": import doctest from pyspark.context import SparkContext diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py index c71c823db2..c1c8e921dd 100644 --- a/python/pyspark/ml/param/_shared_params_code_gen.py +++ b/python/pyspark/ml/param/_shared_params_code_gen.py @@ -95,6 +95,7 @@ if __name__ == "__main__": ("predictionCol", "prediction column name", "'prediction'"), ("rawPredictionCol", "raw prediction column name", "'rawPrediction'"), ("inputCol", "input column name", None), + ("inputCols", "input column names", None), ("outputCol", "output column name", None), ("numFeatures", "number of features", None)] code = [] diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py index 4f243844f8..aaf80f0008 100644 --- a/python/pyspark/ml/param/shared.py +++ b/python/pyspark/ml/param/shared.py @@ -223,6 +223,35 @@ class HasInputCol(Params): return self.getOrDefault(self.inputCol) +class HasInputCols(Params): + """ + Mixin for param inputCols: input column names. + """ + + # a placeholder to make it appear in the generated doc + inputCols = Param(Params._dummy(), "inputCols", "input column names") + + def __init__(self): + super(HasInputCols, self).__init__() + #: param for input column names + self.inputCols = Param(self, "inputCols", "input column names") + if None is not None: + self._setDefault(inputCols=None) + + def setInputCols(self, value): + """ + Sets the value of :py:attr:`inputCols`. + """ + self.paramMap[self.inputCols] = value + return self + + def getInputCols(self): + """ + Gets the value of inputCols or its default value. + """ + return self.getOrDefault(self.inputCols) + + class HasOutputCol(Params): """ Mixin for param outputCol: output column name. 
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py index 0634254bbd..f5ac2a3986 100644 --- a/python/pyspark/ml/wrapper.py +++ b/python/pyspark/ml/wrapper.py @@ -67,7 +67,9 @@ class JavaWrapper(Params): paramMap = self.extractParamMap(params) for param in self.params: if param in paramMap: - java_obj.set(param.name, paramMap[param]) + value = paramMap[param] + java_param = java_obj.getParam(param.name) + java_obj.set(java_param.w(value)) def _empty_java_param_map(self): """ @@ -79,7 +81,8 @@ class JavaWrapper(Params): paramMap = self._empty_java_param_map() for param, value in params.items(): if param.parent is self: - paramMap.put(java_obj.getParam(param.name), value) + java_param = java_obj.getParam(param.name) + paramMap.put(java_param.w(value)) return paramMap @@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper): def transform(self, dataset, params={}): java_obj = self._java_obj() - self._transfer_params_to_java({}, java_obj) - java_param_map = self._create_java_param_map(params, java_obj) - return DataFrame(java_obj.transform(dataset._jdf, java_param_map), - dataset.sql_ctx) + self._transfer_params_to_java(params, java_obj) + return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx) @inherit_doc |