about summary refs log tree commit diff
path: root/python
diff options
context:
space:
mode:
authorBurak Yavuz <brkyvz@gmail.com>2015-05-07 10:25:41 -0700
committerXiangrui Meng <meng@databricks.com>2015-05-07 10:25:41 -0700
commit9e2ffb13287e6efe256b8d23a4654e4cc305e20b (patch)
tree79a13615578199c2907d371c965ef031307c47b9 /python
parented9be06a4797bbb678355b361054c8872ac20b75 (diff)
downloadspark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.tar.gz
spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.tar.bz2
spark-9e2ffb13287e6efe256b8d23a4654e4cc305e20b.zip
[SPARK-7388] [SPARK-7383] wrapper for VectorAssembler in Python
The wrapper required the implementation of the `ArrayParam`, because `Array[T]` is hard to obtain from Python. `ArrayParam` has an extra function called `wCast` which is an internal function to obtain `Array[T]` from `Seq[T]`

Author: Burak Yavuz <brkyvz@gmail.com>
Author: Xiangrui Meng <meng@databricks.com>

Closes #5930 from brkyvz/ml-feat and squashes the following commits:

73e745f [Burak Yavuz] Merge pull request #3 from mengxr/SPARK-7388
c221db9 [Xiangrui Meng] overload StringArrayParam.w
c81072d [Burak Yavuz] addressed comments
99c2ebf [Burak Yavuz] add to python_shared_params
39ecb07 [Burak Yavuz] fix scalastyle
7f7ea2a [Burak Yavuz] [SPARK-7388][SPARK-7383] wrapper for VectorAssembler in Python
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/feature.py43
-rw-r--r--python/pyspark/ml/param/_shared_params_code_gen.py1
-rw-r--r--python/pyspark/ml/param/shared.py29
-rw-r--r--python/pyspark/ml/wrapper.py13
4 files changed, 78 insertions(+), 8 deletions(-)
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 4e4614b859..8a0fdddd2d 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -16,12 +16,12 @@
#
from pyspark.rdd import ignore_unicode_prefix
-from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
+from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaTransformer
from pyspark.mllib.common import inherit_doc
-__all__ = ['Tokenizer', 'HashingTF']
+__all__ = ['Tokenizer', 'HashingTF', 'VectorAssembler']
@inherit_doc
@@ -112,6 +112,45 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
return self._set(**kwargs)
+@inherit_doc
+class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
+ """
+ A feature transformer that merges multiple columns into a vector column.
+
+ >>> from pyspark.sql import Row
+ >>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
+ >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
+ >>> vecAssembler.transform(df).head().features
+ SparseVector(3, {0: 1.0, 2: 3.0})
+ >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
+ SparseVector(3, {0: 1.0, 2: 3.0})
+ >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
+ >>> vecAssembler.transform(df, params).head().vector
+ SparseVector(2, {1: 1.0})
+ """
+
+ _java_class = "org.apache.spark.ml.feature.VectorAssembler"
+
+ @keyword_only
+ def __init__(self, inputCols=None, outputCol=None):
+ """
+ __init__(self, inputCols=None, outputCol=None)
+ """
+ super(VectorAssembler, self).__init__()
+ self._setDefault()
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, inputCols=None, outputCol=None):
+ """
+ setParams(self, inputCols=None, outputCol=None)
+ Sets params for this VectorAssembler.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index c71c823db2..c1c8e921dd 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -95,6 +95,7 @@ if __name__ == "__main__":
("predictionCol", "prediction column name", "'prediction'"),
("rawPredictionCol", "raw prediction column name", "'rawPrediction'"),
("inputCol", "input column name", None),
+ ("inputCols", "input column names", None),
("outputCol", "output column name", None),
("numFeatures", "number of features", None)]
code = []
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 4f243844f8..aaf80f0008 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -223,6 +223,35 @@ class HasInputCol(Params):
return self.getOrDefault(self.inputCol)
+class HasInputCols(Params):
+ """
+ Mixin for param inputCols: input column names.
+ """
+
+ # a placeholder to make it appear in the generated doc
+ inputCols = Param(Params._dummy(), "inputCols", "input column names")
+
+ def __init__(self):
+ super(HasInputCols, self).__init__()
+ #: param for input column names
+ self.inputCols = Param(self, "inputCols", "input column names")
+ if None is not None:
+ self._setDefault(inputCols=None)
+
+ def setInputCols(self, value):
+ """
+ Sets the value of :py:attr:`inputCols`.
+ """
+ self.paramMap[self.inputCols] = value
+ return self
+
+ def getInputCols(self):
+ """
+ Gets the value of inputCols or its default value.
+ """
+ return self.getOrDefault(self.inputCols)
+
+
class HasOutputCol(Params):
"""
Mixin for param outputCol: output column name.
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 0634254bbd..f5ac2a3986 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -67,7 +67,9 @@ class JavaWrapper(Params):
paramMap = self.extractParamMap(params)
for param in self.params:
if param in paramMap:
- java_obj.set(param.name, paramMap[param])
+ value = paramMap[param]
+ java_param = java_obj.getParam(param.name)
+ java_obj.set(java_param.w(value))
def _empty_java_param_map(self):
"""
@@ -79,7 +81,8 @@ class JavaWrapper(Params):
paramMap = self._empty_java_param_map()
for param, value in params.items():
if param.parent is self:
- paramMap.put(java_obj.getParam(param.name), value)
+ java_param = java_obj.getParam(param.name)
+ paramMap.put(java_param.w(value))
return paramMap
@@ -126,10 +129,8 @@ class JavaTransformer(Transformer, JavaWrapper):
def transform(self, dataset, params={}):
java_obj = self._java_obj()
- self._transfer_params_to_java({}, java_obj)
- java_param_map = self._create_java_param_map(params, java_obj)
- return DataFrame(java_obj.transform(dataset._jdf, java_param_map),
- dataset.sql_ctx)
+ self._transfer_params_to_java(params, java_obj)
+ return DataFrame(java_obj.transform(dataset._jdf), dataset.sql_ctx)
@inherit_doc