diff options
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- | python/pyspark/ml/tests.py | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py index 0b0ad2377f..86c0254a2b 100644 --- a/python/pyspark/ml/tests.py +++ b/python/pyspark/ml/tests.py @@ -847,6 +847,25 @@ class TrainingSummaryTest(PySparkTestCase): self.assertAlmostEqual(sameSummary.areaUnderROC, s.areaUnderROC) +class HashingTFTest(PySparkTestCase): + + def test_apply_binary_term_freqs(self): + sqlContext = SQLContext(self.sc) + + df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"]) + n = 100 + hashingTF = HashingTF() + hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True) + output = hashingTF.transform(df) + features = output.select("features").first().features.toArray() + expected = Vectors.sparse(n, {(ord("a") % n): 1.0, + (ord("b") % n): 1.0, + (ord("c") % n): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(features[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(features[i])) + + if __name__ == "__main__": from pyspark.ml.tests import * if xmlrunner: |