diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala | 37 |
1 files changed, 13 insertions, 24 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index e32b862af7..44bad4aba4 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -38,26 +38,19 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau (0, "a a b b c d".split(" ").toSeq) )).toDF("id", "words") val n = 100 - Seq("murmur3", "native").foreach { hashAlgorithm => - val hashingTF = new HashingTF() - .setInputCol("words") - .setOutputCol("features") - .setNumFeatures(n) - .setHashAlgorithm(hashAlgorithm) - val output = hashingTF.transform(df) - val attrGroup = AttributeGroup.fromStructField(output.schema("features")) - require(attrGroup.numAttributes === Some(n)) - val features = output.select("features").first().getAs[Vector](0) - // Assume perfect hash on "a", "b", "c", and "d". - def idx: Any => Int = if (hashAlgorithm == "murmur3") { - murmur3FeatureIdx(n) - } else { - nativeFeatureIdx(n) - } - val expected = Vectors.sparse(n, - Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) - assert(features ~== expected absTol 1e-14) - } + val hashingTF = new HashingTF() + .setInputCol("words") + .setOutputCol("features") + .setNumFeatures(n) + val output = hashingTF.transform(df) + val attrGroup = AttributeGroup.fromStructField(output.schema("features")) + require(attrGroup.numAttributes === Some(n)) + val features = output.select("features").first().getAs[Vector](0) + // Assume perfect hash on "a", "b", "c", and "d". + def idx: Any => Int = murmur3FeatureIdx(n) + val expected = Vectors.sparse(n, + Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) + assert(features ~== expected absTol 1e-14) } test("applying binary term freqs") { @@ -86,10 +79,6 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau testDefaultReadWrite(t) } - private def nativeFeatureIdx(numFeatures: Int)(term: Any): Int = { - Utils.nonNegativeMod(MLlibHashingTF.nativeHash(term), numFeatures) - } - private def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = { Utils.nonNegativeMod(MLlibHashingTF.murmur3Hash(term), numFeatures) } |