From c5172f8205beabe58c0b5392c0d83f9fb9c27f18 Mon Sep 17 00:00:00 2001 From: Bryan Cutler Date: Thu, 14 Apr 2016 20:47:31 +0200 Subject: [SPARK-13967][PYSPARK][ML] Added binary Param to Python CountVectorizer Added binary toggle param to CountVectorizer feature transformer in PySpark. Created a unit test for using CountVectorizer with the binary toggle on. Author: Bryan Cutler Closes #12308 from BryanCutler/binary-param-python-CountVectorizer-SPARK-13967. --- python/pyspark/ml/feature.py | 34 +++++++++++++++++++++++++++++----- python/pyspark/ml/tests.py | 16 ++++++++++++++++ 2 files changed, 45 insertions(+), 5 deletions(-) (limited to 'python') diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py index 86b53285b5..0b0c573eea 100644 --- a/python/pyspark/ml/feature.py +++ b/python/pyspark/ml/feature.py @@ -256,24 +256,33 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, vocabSize = Param( Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.", typeConverter=TypeConverters.toInt) + binary = Param( + Params._dummy(), "binary", "Binary toggle to control the output vector values." + + " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" + + " for discrete probabilistic models that model binary events rather than integer counts." + + " Default False", typeConverter=TypeConverters.toBoolean) @keyword_only - def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, + outputCol=None): """ - __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ + outputCol=None) """ super(CountVectorizer, self).__init__() self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer", self.uid) - self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18) + self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False) kwargs = self.__init__._input_kwargs self.setParams(**kwargs) @keyword_only @since("1.6.0") - def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None): + def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None, + outputCol=None): """ - setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None) + setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\ + outputCol=None) Set the params for the CountVectorizer """ kwargs = self.setParams._input_kwargs @@ -324,6 +333,21 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, """ return self.getOrDefault(self.vocabSize) + @since("2.0.0") + def setBinary(self, value): + """ + Sets the value of :py:attr:`binary`. + """ + self._paramMap[self.binary] = value + return self + + @since("2.0.0") + def getBinary(self): + """ + Gets the value of binary or its default value. + """ + return self.getOrDefault(self.binary) + def _create_model(self, java_model): return CountVectorizerModel(java_model) diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index bcbeacbe80..0b0ad2377f 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -406,6 +406,22 @@ class FeatureTests(PySparkTestCase): transformedDF = stopWordRemover.transform(dataset) self.assertEqual(transformedDF.head().output, ["a"]) + def test_count_vectorizer_with_binary(self): + sqlContext = SQLContext(self.sc) + dataset = sqlContext.createDataFrame([ + (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),), + (1, "a a".split(' '), SparseVector(3, {0: 1.0}),), + (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),), + (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"]) + cv = CountVectorizer(binary=True, inputCol="words", outputCol="features") + model = cv.fit(dataset) + + transformedList = model.transform(dataset).select("features", "expected").collect() + + for r in transformedList: + feature, expected = r + self.assertEqual(feature, expected) + class HasInducedError(Params): -- cgit v1.2.3