about summary refs log tree commit diff
path: root/python/pyspark/ml/feature.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/feature.py')
-rw-r--r--  python/pyspark/ml/feature.py  43
1 file changed, 41 insertions, 2 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 4e4614b859..8a0fdddd2d 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -16,12 +16,12 @@
#
from pyspark.rdd import ignore_unicode_prefix
-from pyspark.ml.param.shared import HasInputCol, HasOutputCol, HasNumFeatures
+from pyspark.ml.param.shared import HasInputCol, HasInputCols, HasOutputCol, HasNumFeatures
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaTransformer
from pyspark.mllib.common import inherit_doc
-__all__ = ['Tokenizer', 'HashingTF']
+__all__ = ['Tokenizer', 'HashingTF', 'VectorAssembler']
@inherit_doc
@@ -112,6 +112,45 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
return self._set(**kwargs)
+@inherit_doc
+class VectorAssembler(JavaTransformer, HasInputCols, HasOutputCol):
+ """
+ A feature transformer that merges multiple columns into a vector column.
+
+ >>> from pyspark.sql import Row
+ >>> df = sc.parallelize([Row(a=1, b=0, c=3)]).toDF()
+ >>> vecAssembler = VectorAssembler(inputCols=["a", "b", "c"], outputCol="features")
+ >>> vecAssembler.transform(df).head().features
+ SparseVector(3, {0: 1.0, 2: 3.0})
+ >>> vecAssembler.setParams(outputCol="freqs").transform(df).head().freqs
+ SparseVector(3, {0: 1.0, 2: 3.0})
+ >>> params = {vecAssembler.inputCols: ["b", "a"], vecAssembler.outputCol: "vector"}
+ >>> vecAssembler.transform(df, params).head().vector
+ SparseVector(2, {1: 1.0})
+ """
+
+ _java_class = "org.apache.spark.ml.feature.VectorAssembler"
+
+ @keyword_only
+ def __init__(self, inputCols=None, outputCol=None):
+ """
+ __init__(self, inputCols=None, outputCol=None)
+ """
+ super(VectorAssembler, self).__init__()
+ self._setDefault()
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, inputCols=None, outputCol=None):
+ """
+ setParams(self, inputCols=None, outputCol=None)
+ Sets params for this VectorAssembler.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+
if __name__ == "__main__":
import doctest
from pyspark.context import SparkContext