diff options
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- | python/pyspark/mllib/tests.py | 16 |
1 files changed, 16 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 5f515b666c..ac55fbf798 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -58,6 +58,7 @@ from pyspark.mllib.recommendation import Rating from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.mllib.feature import HashingTF from pyspark.mllib.feature import Word2Vec from pyspark.mllib.feature import IDF from pyspark.mllib.feature import StandardScaler, ElementwiseProduct @@ -1583,6 +1584,21 @@ class ALSTests(MLlibTestCase): self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r))) +class HashingTFTest(MLlibTestCase): + + def test_binary_term_freqs(self): + hashingTF = HashingTF(100).setBinary(True) + doc = "a a b c c c".split(" ") + n = hashingTF.numFeatures + output = hashingTF.transform(doc).toArray() + expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0, + hashingTF.indexOf("b"): 1.0, + hashingTF.indexOf("c"): 1.0}).toArray() + for i in range(0, n): + self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) + + ": expected " + str(expected[i]) + ", got " + str(output[i])) + + if __name__ == "__main__": from pyspark.mllib.tests import * if not _have_scipy: |