aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index bcbeacbe80..0b0ad2377f 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -406,6 +406,22 @@ class FeatureTests(PySparkTestCase):
transformedDF = stopWordRemover.transform(dataset)
self.assertEqual(transformedDF.head().output, ["a"])
+ def test_count_vectorizer_with_binary(self):
+ sqlContext = SQLContext(self.sc)
+ dataset = sqlContext.createDataFrame([
+ (0, "a a a b b c".split(' '), SparseVector(3, {0: 1.0, 1: 1.0, 2: 1.0}),),
+ (1, "a a".split(' '), SparseVector(3, {0: 1.0}),),
+ (2, "a b".split(' '), SparseVector(3, {0: 1.0, 1: 1.0}),),
+ (3, "c".split(' '), SparseVector(3, {2: 1.0}),)], ["id", "words", "expected"])
+ cv = CountVectorizer(binary=True, inputCol="words", outputCol="features")
+ model = cv.fit(dataset)
+
+ transformedList = model.transform(dataset).select("features", "expected").collect()
+
+ for r in transformedList:
+ feature, expected = r
+ self.assertEqual(feature, expected)
+
class HasInducedError(Params):