diff options
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index bcbeacbe80..0b0ad2377f 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -406,6 +406,22 @@ class FeatureTests(PySparkTestCase):
         transformedDF = stopWordRemover.transform(dataset)
         self.assertEqual(transformedDF.head().output, ["a"])
 
+    def test_count_vectorizer_with_binary(self):
+        sqlContext = SQLContext(self.sc)
+        dataset = sqlContext.createDataFrame([
+            (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
+            (1, "a a".split(' '), SparseVector(3, {0: 1.0}),),
+            (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
+            (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"])
+        cv = CountVectorizer(binary=True, inputCol="words", outputCol="features")
+        model = cv.fit(dataset)
+
+        transformedList = model.transform(dataset).select("features", "expected").collect()
+
+        for r in transformedList:
+            feature, expected = r
+            self.assertEqual(feature, expected)
+
 
 
 class HasInducedError(Params):