aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--python/pyspark/mllib/tests.py16
1 files changed, 16 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 5f515b666c..ac55fbf798 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -58,6 +58,7 @@ from pyspark.mllib.recommendation import Rating
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.mllib.feature import HashingTF
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
@@ -1583,6 +1584,21 @@ class ALSTests(MLlibTestCase):
self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
+class HashingTFTest(MLlibTestCase):
+
+ def test_binary_term_freqs(self):
+ hashingTF = HashingTF(100).setBinary(True)
+ doc = "a a b c c c".split(" ")
+ n = hashingTF.numFeatures
+ output = hashingTF.transform(doc).toArray()
+ expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0,
+ hashingTF.indexOf("b"): 1.0,
+ hashingTF.indexOf("c"): 1.0}).toArray()
+ for i in range(0, n):
+ self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) +
+ ": expected " + str(expected[i]) + ", got " + str(output[i]))
+
+
if __name__ == "__main__":
from pyspark.mllib.tests import *
if not _have_scipy: