commit 425f6916462ca5d0595c61101d52686006ed6b8b
tree 57c8247319bf4f01c62f90272d2260a18cb0905c
parent 88e54218d5f0a8696563813feb387c08ec6b13d5
author Yanbo Liang <ybliang8@gmail.com> 2016-04-25 12:08:43 -0700
committer Joseph K. Bradley <joseph@databricks.com> 2016-04-25 12:08:43 -0700
[SPARK-10574][ML][MLLIB] HashingTF supports MurmurHash3
## What changes were proposed in this pull request?
As discussed in [SPARK-10574](https://issues.apache.org/jira/browse/SPARK-10574), ```HashingTF``` should support MurmurHash3 and make it the default hash algorithm. We should also expose a set/get API for ```hashAlgorithm``` so that users can choose the hash method.
Note: The problem that ```mllib.feature.HashingTF``` behaves differently between Scala/Java and Python will be resolved in follow-up work.
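For illustration, here is a minimal sketch of the new param in use (the DataFrame ```df``` and its ```words``` column are hypothetical; the setters are the ones added by this patch):

```scala
import org.apache.spark.ml.feature.HashingTF

// Sketch only: `df` is assumed to be an existing DataFrame whose "words"
// column holds Seq[String] values.
val hashingTF = new HashingTF()
  .setInputCol("words")
  .setOutputCol("features")
  .setNumFeatures(1 << 18)
  .setHashAlgorithm("native") // opt back into the pre-2.0 hashing; default is "murmur3"
val featurized = hashingTF.transform(df)
```

The same switch is exposed on ```mllib.feature.HashingTF``` through its new ```setHashAlgorithm``` method.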
## How was this patch tested?
Unit tests.
cc jkbradley MLnick
Author: Yanbo Liang <ybliang8@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>
Closes #12498 from yanboliang/spark-10574.
5 files changed, 162 insertions, 30 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 467ad73074..6fc08aee13 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
 import org.apache.spark.annotation.{Experimental, Since}
 import org.apache.spark.ml.Transformer
 import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param._
 import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
 import org.apache.spark.ml.util._
 import org.apache.spark.mllib.feature
@@ -31,6 +31,12 @@ import org.apache.spark.sql.types.{ArrayType, StructType}
 /**
  * :: Experimental ::
  * Maps a sequence of terms to their term frequencies using the hashing trick.
+ * Currently we support two hash algorithms: "murmur3" (default) and "native".
+ * "murmur3" calculates a hash code value for the term object using
+ * Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32);
+ * "native" calculates the hash code value using the native Scala implementation.
+ * In Spark 1.6 and earlier, "native" is the default hash algorithm;
+ * after Spark 2.0, we use "murmur3" as the default one.
  */
 @Experimental
 class HashingTF(override val uid: String)
@@ -63,7 +69,20 @@ class HashingTF(override val uid: String)
     "This is useful for discrete probabilistic models that model binary events rather " +
     "than integer counts")

-  setDefault(numFeatures -> (1 << 18), binary -> false)
+  /**
+   * The hash algorithm used when mapping term to integer.
+   * Supported options: "murmur3" and "native". We use "native" as default hash algorithm
+   * in Spark 1.6 and earlier. After Spark 2.0, we use "murmur3" as default one.
+   * (Default = "murmur3")
+   * @group expertParam
+   */
+  val hashAlgorithm = new Param[String](this, "hashAlgorithm", "The hash algorithm used when " +
+    "mapping term to integer. Supported options: " +
+    s"${feature.HashingTF.supportedHashAlgorithms.mkString(",")}.",
+    ParamValidators.inArray[String](feature.HashingTF.supportedHashAlgorithms))
+
+  setDefault(numFeatures -> (1 << 18), binary -> false,
+    hashAlgorithm -> feature.HashingTF.Murmur3)

   /** @group getParam */
   def getNumFeatures: Int = $(numFeatures)
@@ -77,10 +96,18 @@ class HashingTF(override val uid: String)
   /** @group setParam */
   def setBinary(value: Boolean): this.type = set(binary, value)

+  /** @group expertGetParam */
+  def getHashAlgorithm: String = $(hashAlgorithm)
+
+  /** @group expertSetParam */
+  def setHashAlgorithm(value: String): this.type = set(hashAlgorithm, value)
+
   @Since("2.0.0")
   override def transform(dataset: Dataset[_]): DataFrame = {
     val outputSchema = transformSchema(dataset.schema)
-    val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
+    val hashingTF = new feature.HashingTF($(numFeatures))
+      .setBinary($(binary))
+      .setHashAlgorithm($(hashAlgorithm))
     val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
     val metadata = outputSchema($(outputCol)).metadata
     dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
index 47c9e850a0..321f11d9f9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -22,10 +22,13 @@ import java.lang.{Iterable => JavaIterable}
 import scala.collection.JavaConverters._
 import scala.collection.mutable

+import org.apache.spark.SparkException
 import org.apache.spark.annotation.Since
 import org.apache.spark.api.java.JavaRDD
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.rdd.RDD
+import org.apache.spark.unsafe.hash.Murmur3_x86_32._
+import org.apache.spark.unsafe.types.UTF8String
 import org.apache.spark.util.Utils

 /**
@@ -36,7 +39,10 @@ import org.apache.spark.util.Utils
 @Since("1.1.0")
 class HashingTF(val numFeatures: Int) extends Serializable {

+  import HashingTF._
+
   private var binary = false
+  private var hashAlgorithm = HashingTF.Murmur3

   /**
   */
@@ -54,10 +60,34 @@ class HashingTF(val numFeatures: Int) extends Serializable {
   }

   /**
+   * Set the hash algorithm used when mapping term to integer.
+   * (default: murmur3)
+   */
+  @Since("2.0.0")
+  def setHashAlgorithm(value: String): this.type = {
+    hashAlgorithm = value
+    this
+  }
+
+  /**
    * Returns the index of the input term.
    */
   @Since("1.1.0")
-  def indexOf(term: Any): Int = Utils.nonNegativeMod(term.##, numFeatures)
+  def indexOf(term: Any): Int = {
+    Utils.nonNegativeMod(getHashFunction(term), numFeatures)
+  }
+
+  /**
+   * Get the hash function corresponding to the current [[hashAlgorithm]] setting.
+   */
+  private def getHashFunction: Any => Int = hashAlgorithm match {
+    case Murmur3 => murmur3Hash
+    case Native => nativeHash
+    case _ =>
+      // This should never happen.
+      throw new IllegalArgumentException(
+        s"HashingTF does not recognize hash algorithm $hashAlgorithm")
+  }

   /**
    * Transforms the input document into a sparse term frequency vector.
@@ -66,8 +96,9 @@ class HashingTF(val numFeatures: Int) extends Serializable {
   def transform(document: Iterable[_]): Vector = {
     val termFrequencies = mutable.HashMap.empty[Int, Double]
     val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0
+    val hashFunc: Any => Int = getHashFunction
     document.foreach { term =>
-      val i = indexOf(term)
+      val i = Utils.nonNegativeMod(hashFunc(term), numFeatures)
       termFrequencies.put(i, setTF(i))
     }
     Vectors.sparse(numFeatures, termFrequencies.toSeq)
@@ -97,3 +128,41 @@ class HashingTF(val numFeatures: Int) extends Serializable {
     dataset.rdd.map(this.transform).toJavaRDD()
   }
 }
+
+object HashingTF {
+
+  private[spark] val Native: String = "native"
+
+  private[spark] val Murmur3: String = "murmur3"
+
+  private[spark] val supportedHashAlgorithms: Array[String] = Array(Native, Murmur3)
+
+  private val seed = 42
+
+  /**
+   * Calculate a hash code value for the term object using the native Scala implementation.
+   */
+  private[spark] def nativeHash(term: Any): Int = term.##
+
+  /**
+   * Calculate a hash code value for the term object using
+   * Austin Appleby's MurmurHash 3 algorithm (MurmurHash3_x86_32).
+   */
+  private[spark] def murmur3Hash(term: Any): Int = {
+    term match {
+      case null => seed
+      case b: Boolean => hashInt(if (b) 1 else 0, seed)
+      case b: Byte => hashInt(b, seed)
+      case s: Short => hashInt(s, seed)
+      case i: Int => hashInt(i, seed)
+      case l: Long => hashLong(l, seed)
+      case f: Float => hashInt(java.lang.Float.floatToIntBits(f), seed)
+      case d: Double => hashLong(java.lang.Double.doubleToLongBits(d), seed)
+      case s: String =>
+        val utf8 = UTF8String.fromString(s)
+        hashUnsafeBytes(utf8.getBaseObject, utf8.getBaseOffset, utf8.numBytes(), seed)
+      case _ => throw new SparkException("HashingTF with murmur3 algorithm does not " +
+        s"support type ${term.getClass.getCanonicalName} of input data.")
+    }
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index addd733c20..e32b862af7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -21,6 +21,7 @@ import org.apache.spark.SparkFunSuite
 import org.apache.spark.ml.attribute.AttributeGroup
 import org.apache.spark.ml.param.ParamsSuite
 import org.apache.spark.ml.util.DefaultReadWriteTest
+import org.apache.spark.mllib.feature.{HashingTF => MLlibHashingTF}
 import org.apache.spark.mllib.linalg.{Vector, Vectors}
 import org.apache.spark.mllib.util.MLlibTestSparkContext
 import org.apache.spark.mllib.util.TestingUtils._
@@ -37,19 +38,26 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
       (0, "a a b b c d".split(" ").toSeq)
     )).toDF("id", "words")
     val n = 100
-    val hashingTF = new HashingTF()
-      .setInputCol("words")
-      .setOutputCol("features")
-      .setNumFeatures(n)
-    val output = hashingTF.transform(df)
-    val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
-    require(attrGroup.numAttributes === Some(n))
-    val features = output.select("features").first().getAs[Vector](0)
-    // Assume perfect hash on "a", "b", "c", and "d".
-    def idx: Any => Int = featureIdx(n)
-    val expected = Vectors.sparse(n,
-      Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
-    assert(features ~== expected absTol 1e-14)
+    Seq("murmur3", "native").foreach { hashAlgorithm =>
+      val hashingTF = new HashingTF()
+        .setInputCol("words")
+        .setOutputCol("features")
+        .setNumFeatures(n)
+        .setHashAlgorithm(hashAlgorithm)
+      val output = hashingTF.transform(df)
+      val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
+      require(attrGroup.numAttributes === Some(n))
+      val features = output.select("features").first().getAs[Vector](0)
+      // Assume perfect hash on "a", "b", "c", and "d".
+      def idx: Any => Int = if (hashAlgorithm == "murmur3") {
+        murmur3FeatureIdx(n)
+      } else {
+        nativeFeatureIdx(n)
+      }
+      val expected = Vectors.sparse(n,
+        Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
+      assert(features ~== expected absTol 1e-14)
+    }
   }

   test("applying binary term freqs") {
@@ -64,7 +72,7 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
       .setBinary(true)
     val output = hashingTF.transform(df)
     val features = output.select("features").first().getAs[Vector](0)
-    def idx: Any => Int = featureIdx(n) // Assume perfect hash on input features
+    def idx: Any => Int = murmur3FeatureIdx(n) // Assume perfect hash on input features
     val expected = Vectors.sparse(n,
       Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
     assert(features ~== expected absTol 1e-14)
@@ -78,7 +86,11 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
     testDefaultReadWrite(t)
   }

-  private def featureIdx(numFeatures: Int)(term: Any): Int = {
-    Utils.nonNegativeMod(term.##, numFeatures)
+  private def nativeFeatureIdx(numFeatures: Int)(term: Any): Int = {
+    Utils.nonNegativeMod(MLlibHashingTF.nativeHash(term), numFeatures)
+  }
+
+  private def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
+    Utils.nonNegativeMod(MLlibHashingTF.murmur3Hash(term), numFeatures)
   }
 }
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 1b298e639d..0e578d48ca 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -523,12 +523,12 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
     >>> df = sqlContext.createDataFrame([(["a", "b", "c"],)], ["words"])
     >>> hashingTF = HashingTF(numFeatures=10, inputCol="words", outputCol="features")
     >>> hashingTF.transform(df).head().features
-    SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0})
+    SparseVector(10, {0: 1.0, 1: 1.0, 2: 1.0})
     >>> hashingTF.setParams(outputCol="freqs").transform(df).head().freqs
-    SparseVector(10, {7: 1.0, 8: 1.0, 9: 1.0})
+    SparseVector(10, {0: 1.0, 1: 1.0, 2: 1.0})
     >>> params = {hashingTF.numFeatures: 5, hashingTF.outputCol: "vector"}
     >>> hashingTF.transform(df, params).head().vector
-    SparseVector(5, {2: 1.0, 3: 1.0, 4: 1.0})
+    SparseVector(5, {0: 1.0, 1: 1.0, 2: 1.0})
     >>> hashingTFPath = temp_path + "/hashing-tf"
     >>> hashingTF.save(hashingTFPath)
     >>> loadedHashingTF = HashingTF.load(hashingTFPath)
@@ -543,22 +543,30 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
                    "rather than integer counts. Default False.",
                    typeConverter=TypeConverters.toBoolean)

+    hashAlgorithm = Param(Params._dummy(), "hashAlgorithm", "The hash algorithm used when " +
+                          "mapping term to integer. Supported options: murmur3(default) " +
+                          "and native.", typeConverter=TypeConverters.toString)
+
     @keyword_only
-    def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None):
+    def __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None,
+                 hashAlgorithm="murmur3"):
         """
-        __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
+        __init__(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None, \
+                 hashAlgorithm="murmur3")
         """
         super(HashingTF, self).__init__()
         self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.HashingTF", self.uid)
-        self._setDefault(numFeatures=1 << 18, binary=False)
+        self._setDefault(numFeatures=1 << 18, binary=False, hashAlgorithm="murmur3")
         kwargs = self.__init__._input_kwargs
         self.setParams(**kwargs)

     @keyword_only
     @since("1.3.0")
-    def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
+    def setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None,
+                  hashAlgorithm="murmur3"):
         """
-        setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
+        setParams(self, numFeatures=1 << 18, binary=False, inputCol=None, outputCol=None, \
+                  hashAlgorithm="murmur3")
         Sets params for this HashingTF.
         """
         kwargs = self.setParams._input_kwargs
@@ -579,6 +587,21 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures, Java
         """
         return self.getOrDefault(self.binary)

+    @since("2.0.0")
+    def setHashAlgorithm(self, value):
+        """
+        Sets the value of :py:attr:`hashAlgorithm`.
+        """
+        self._set(hashAlgorithm=value)
+        return self
+
+    @since("2.0.0")
+    def getHashAlgorithm(self):
+        """
+        Gets the value of hashAlgorithm or its default value.
+        """
+        return self.getOrDefault(self.hashAlgorithm)
+

 @inherit_doc
 class IDF(JavaEstimator, HasInputCol, HasOutputCol, JavaMLReadable, JavaMLWritable):
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index e95458699d..8954e96df9 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -918,7 +918,8 @@ class HashingTFTest(PySparkTestCase):
         df = sqlContext.createDataFrame([(0, ["a", "a", "b", "c", "c", "c"])], ["id", "words"])
         n = 100
         hashingTF = HashingTF()
-        hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n).setBinary(True)
+        hashingTF.setInputCol("words").setOutputCol("features").setNumFeatures(n)\
+            .setBinary(True).setHashAlgorithm("native")
         output = hashingTF.transform(df)
         features = output.select("features").first().features.toArray()
         expected = Vectors.sparse(n, {(ord("a") % n): 1.0,