aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py19
1 files changed, 19 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 0b0ad2377f..86c0254a2b 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -847,6 +847,25 @@ class TrainingSummaryTest(PySparkTestCase):
self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC)
+class HashingTFTest(PySparkTestCase):
+
+ def test_apply_binary_term_freqs(self):
+ sqlContext = SQLContext(self.sc)
+
+ df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])
+ n = 100
+ hashingTF = HashingTF()
+ hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True)
+ output = hashingTF.transform(df)
+ features = output.select("features").first().features.toArray()
+ expected = Vectors.sparse(n, {(ord("a") % n): 1.0,
+ (ord("b") % n): 1.0,
+ (ord("c") % n): 1.0}).toArray()
+ for i in range(0, n):
+ self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) +
+ ": expected " + str(expected[i]) + ", got " + str(features[i]))
+
+
if __name__ == "__main__":
from pyspark.ml.tests import *
if xmlrunner: