aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/feature.py
diff options
context:
space:
mode:
author: Yanbo Liang <ybliang8@gmail.com> 2016-04-25 12:08:43 -0700
committer: Joseph K. Bradley <joseph@databricks.com> 2016-04-25 12:08:43 -0700
commit: 425f6916462ca5d0595c61101d52686006ed6b8b (patch)
tree: 57c8247319bf4f01c62f90272d2260a18cb0905c /python/pyspark/ml/feature.py
parent: 88e54218d5f0a8696563813feb387c08ec6b13d5 (diff)
download: spark-425f6916462ca5d0595c61101d52686006ed6b8b.tar.gz
spark-425f6916462ca5d0595c61101d52686006ed6b8b.tar.bz2
spark-425f6916462ca5d0595c61101d52686006ed6b8b.zip
[SPARK-10574][ML][MLLIB] HashingTF supports MurmurHash3
## What changes were proposed in this pull request? As discussed in [SPARK-10574](https://issues.apache.org/jira/browse/SPARK-10574), ```HashingTF``` should support MurmurHash3 and use it as the default hash algorithm. We should also expose set/get APIs for ```hashAlgorithm``` so that users can choose the hash method. Note: The problem that ```mllib.feature.HashingTF``` behaves differently between Scala/Java and Python will be resolved in follow-up work. ## How was this patch tested? Unit tests. cc jkbradley MLnick Author: Yanbo Liang <ybliang8@gmail.com> Author: Joseph K. Bradley <joseph@databricks.com> Closes #12498 from yanboliang/spark-10574.
Diffstat (limited to 'python/pyspark/ml/feature.py')
-rw-r--r--python/pyspark/ml/feature.py39
1 files changed, 31 insertions, 8 deletions
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 1b298e639d..0e578d48ca 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -523,12 +523,12 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
>>> df = sqlContext.createDataFrame([(["a", "b", "c"],)], ["words"])
>>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
>>> hashingTF.transform(df).head().features
- SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0})
+ SparseVector(10, {0: 1.0, 1: 1.0, 2: 1.0})
>>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
- SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0})
+ SparseVector(10, {0: 1.0, 1: 1.0, 2: 1.0})
>>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
>>> hashingTF.transform(df, params).head().vector
- SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0})
+ SparseVector(5, {0: 1.0, 1: 1.0, 2: 1.0})
>>> hashingTFPath = temp_path + "/hashing-tf"
>>> hashingTF.save(hashingTFPath)
>>> loadedHashingTF = HashingTF.load(hashingTFPath)
@@ -543,22 +543,30 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
"rather than integer counts. Default False.",
typeConverter=TypeConverters.toBoolean)
+ hashAlgorithm = Param(Params._dummy(), "hashAlgorithm", "The hash algorithm used when " +
+ "mapping term to integer. Supported options: murmur3(default) " +
+ "and native.", typeConverter=TypeConverters.toString)
+
@keyword_only
- def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None):
+ def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None,
+ hashAlgorithm="murmur3"):
"""
- __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
+ __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None, \
+ hashAlgorithm="murmur3")
"""
super(HashingTF, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid)
- self._setDefault(numFeatures=1 << 18, binary=False)
+ self._setDefault(numFeatures=1 << 18, binary=False, hashAlgorithm="murmur3")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
@since("1.3.0")
- def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
+ def setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None,
+ hashAlgorithm="murmur3"):
"""
- setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
+ setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None, \
+ hashAlgorithm="murmur3")
Sets params for this HashingTF.
"""
kwargs = self.setParams._input_kwargs
@@ -579,6 +587,21 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
"""
return self.getOrDefault(self.binary)
+ @since("2.0.0")
+ def setHashAlgorithm(self, value):
+ """
+ Sets the value of :py:attr:`hashAlgorithm`.
+ """
+ self._set(hashAlgorithm=value)
+ return self
+
+ @since("2.0.0")
+ def getHashAlgorithm(self):
+ """
+ Gets the value of hashAlgorithm or its default value.
+ """
+ return self.getOrDefault(self.hashAlgorithm)
+
@inherit_doc
class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):