author     Bryan Cutler <cutlerb@gmail.com>    2016-04-14 20:47:31 +0200
committer  Nick Pentreath <nickp@za.ibm.com>   2016-04-14 20:47:31 +0200
commit     c5172f8205beabe58c0b5392c0d83f9fb9c27f18 (patch)
tree       d8acee14c00a9bdb3b5826a5f32bcd0e39b20a9e
parent     28efdd3fd789fa2ebed5be03b36ca0f682e37669 (diff)
[SPARK-13967][PYSPARK][ML] Added binary Param to Python CountVectorizer
Added binary toggle param to CountVectorizer feature transformer in PySpark. Created a unit test
for using CountVectorizer with the binary toggle on.

Author: Bryan Cutler <cutlerb@gmail.com>

Closes #12308 from BryanCutler/binary-param-python-CountVectorizer-SPARK-13967.
-rw-r--r--  python/pyspark/ml/feature.py  34
-rw-r--r--  python/pyspark/ml/tests.py    16
2 files changed, 45 insertions, 5 deletions
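
For context (not part of the patch), a minimal usage sketch of the new toggle; the toy data and
the existing SQLContext named sqlContext are assumptions, mirroring the unit test added below:

    from pyspark.ml.feature import CountVectorizer

    # assumed: an existing SQLContext named sqlContext
    df = sqlContext.createDataFrame([(0, "a a b".split(" "))], ["id", "words"])

    # binary=True: any nonzero count that survives the minTF filter becomes 1.0
    cv = CountVectorizer(binary=True, inputCol="words", outputCol="features")
    model = cv.fit(df)
    model.transform(df).show(truncate=False)  # "a" occurs twice but its feature value is 1.0
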
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 86b53285b5..0b0c573eea 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -256,24 +256,33 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
     vocabSize = Param(
         Params._dummy(), "vocabSize", "max size of the vocabulary. Default 1 << 18.",
         typeConverter=TypeConverters.toInt)
+    binary = Param(
+        Params._dummy(), "binary", "Binary toggle to control the output vector values." +
+        " If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful" +
+        " for discrete probabilistic models that model binary events rather than integer counts." +
+        " Default False", typeConverter=TypeConverters.toBoolean)

     @keyword_only
-    def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None):
+    def __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
+                 outputCol=None):
         """
-        __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None)
+        __init__(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\
+                 outputCol=None)
         """
         super(CountVectorizer, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.CountVectorizer",
                                             self.uid)
-        self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18)
+        self._setDefault(minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False)
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)

     @keyword_only
     @since("1.6.0")
-    def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None):
+    def setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,
+                  outputCol=None):
         """
-        setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, inputCol=None, outputCol=None)
+        setParams(self, minTF=1.0, minDF=1.0, vocabSize=1 << 18, binary=False, inputCol=None,\
+                  outputCol=None)
         Set the params for the CountVectorizer
         """
         kwargs = self.setParams._input_kwargs
@@ -324,6 +333,21 @@ class CountVectorizer(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable,
         """
         return self.getOrDefault(self.vocabSize)

+    @since("2.0.0")
+    def setBinary(self, value):
+        """
+        Sets the value of :py:attr:`binary`.
+        """
+        self._paramMap[self.binary] = value
+        return self
+
+    @since("2.0.0")
+    def getBinary(self):
+        """
+        Gets the value of binary or its default value.
+        """
+        return self.getOrDefault(self.binary)
+
     def _create_model(self, java_model):
         return CountVectorizerModel(java_model)
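
The new accessors can also be used after construction; a small sketch (again an assumption, not
part of the patch):

    cv = CountVectorizer(inputCol="words", outputCol="features")
    cv.setBinary(True)                         # explicit setter added by this patch
    assert cv.getBinary() is True
    assert cv.getOrDefault(cv.binary) is True  # same value through the generic Param API
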
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index bcbeacbe80..0b0ad2377f 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -406,6 +406,22 @@ class FeatureTests(PySparkTestCase):
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEqual(transformedDF.head().output, ["a"])

+    def test_count_vectorizer_with_binary(self):
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame([
+            (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
+            (1, "a a".split(' '), SparseVector(3, {0: 1.0}),),
+            (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
+            (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"])
+        cv = CountVectorizer(binary=True, inputCol="words", outputCol="features")
+        model = cv.fit(dataset)
+
+        transformedList = model.transform(dataset).select("features", "expected").collect()
+
+        for r in transformedList:
+            feature, expected = r
+            self.assertEqual(feature, expected)
+

 class HasInducedError(Params):