path: root/python/pyspark/mllib
author     Yong Tang <yong.tang.github@outlook.com>    2016-04-14 21:53:32 +0200
committer  Nick Pentreath <nickp@za.ibm.com>           2016-04-14 21:53:32 +0200
commit     bc748b7b8f3b5aee28aff9ea078c216ca137a5b7 (patch)
tree       2255d50eea81c6117024152451285fccb96b80f9 /python/pyspark/mllib
parent     bf65c87f706019d235d7093637341668a13b1be1 (diff)
[SPARK-14238][ML][MLLIB][PYSPARK] Add binary toggle Param to PySpark HashingTF in ML & MLlib
## What changes were proposed in this pull request?

This fix adds a binary toggle Param to PySpark HashingTF in ML & MLlib. If this toggle is set, all non-zero term counts are set to 1.

Note: This fix (SPARK-14238) extends SPARK-13963, where the Scala implementation was done.

## How was this patch tested?

This fix adds two tests to cover the code changes: one for HashingTF in PySpark's ML and one for HashingTF in PySpark's MLlib.

Author: Yong Tang <yong.tang.github@outlook.com>

Closes #12079 from yongtang/SPARK-14238.
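For reference, a minimal usage sketch of the new toggle on the MLlib side (a sketch only, assuming a PySpark build that includes this patch; the single-document `transform` path shown here does not need a running SparkContext):

```python
from pyspark.mllib.feature import HashingTF

# Hash each term into 100 buckets; with the binary toggle on, any term that
# appears contributes 1.0 rather than its raw occurrence count.
hashingTF = HashingTF(numFeatures=100).setBinary(True)

doc = "spark spark hashing tf".split(" ")
vec = hashingTF.transform(doc)   # SparseVector whose non-zero values are all 1.0
print(vec)
```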
Diffstat (limited to 'python/pyspark/mllib')
-rw-r--r--  python/pyspark/mllib/feature.py | 13
-rw-r--r--  python/pyspark/mllib/tests.py   | 16
2 files changed, 28 insertions(+), 1 deletion(-)
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 6129353525..b3dd2f63a5 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -379,6 +379,17 @@ class HashingTF(object):
"""
def __init__(self, numFeatures=1 << 20):
self.numFeatures = numFeatures
+ self.binary = False
+
+ @since("2.0.0")
+ def setBinary(self, value):
+ """
+ If True, term frequency vector will be binary such that non-zero
+ term counts will be set to 1
+ (default: False)
+ """
+ self.binary = value
+ return self
@since('1.2.0')
def indexOf(self, term):
@@ -398,7 +409,7 @@ class HashingTF(object):
freq = {}
for term in document:
i = self.indexOf(term)
- freq[i] = freq.get(i, 0) + 1.0
+ freq[i] = 1.0 if self.binary else freq.get(i, 0) + 1.0
return Vectors.sparse(self.numFeatures, freq.items())
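The single changed line above is the whole behavioural difference. A standalone sketch of the same term-frequency loop (plain Python, no Spark required; `term_freqs` is a hypothetical helper, not part of the patch) makes the effect of the toggle visible:

```python
def term_freqs(document, num_features=100, binary=False):
    # Mirrors HashingTF.transform for one document: hash each term to a bucket,
    # then either accumulate raw counts or record a 1.0 presence indicator.
    freq = {}
    for term in document:
        i = hash(term) % num_features
        freq[i] = 1.0 if binary else freq.get(i, 0) + 1.0
    return freq

doc = "a a b c c c".split(" ")
print(term_freqs(doc))               # raw counts: values 2.0, 1.0, 3.0 (bucket order may vary)
print(term_freqs(doc, binary=True))  # binary toggle: every value is 1.0
```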
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 5f515b666c..ac55fbf798 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -58,6 +58,7 @@ from pyspark.mllib.recommendation import Rating
from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.mllib.feature import HashingTF
from pyspark.mllib.feature import Word2Vec
from pyspark.mllib.feature import IDF
from pyspark.mllib.feature import StandardScaler, ElementwiseProduct
@@ -1583,6 +1584,21 @@ class ALSTests(MLlibTestCase):
self.assertRaises(Py4JJavaError, self.sc._jvm.SerDe.loads, bytearray(ser.dumps(r)))
+class HashingTFTest(MLlibTestCase):
+
+ def test_binary_term_freqs(self):
+ hashingTF = HashingTF(100).setBinary(True)
+ doc = "a a b c c c".split(" ")
+ n = hashingTF.numFeatures
+ output = hashingTF.transform(doc).toArray()
+ expected = Vectors.sparse(n, {hashingTF.indexOf("a"): 1.0,
+ hashingTF.indexOf("b"): 1.0,
+ hashingTF.indexOf("c"): 1.0}).toArray()
+ for i in range(0, n):
+ self.assertAlmostEqual(output[i], expected[i], 14, "Error at " + str(i) +
+ ": expected " + str(expected[i]) + ", got " + str(output[i]))
+
+
if __name__ == "__main__":
from pyspark.mllib.tests import *
if not _have_scipy: