about | summary | refs | log | tree | commit | diff
path: root/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala')
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala | 37
1 files changed, 13 insertions, 24 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index e32b862af7..44bad4aba4 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -38,26 +38,19 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
(0, "a a b b c d".split(" ").toSeq)
)).toDF("id", "words")
val n = 100
- Seq("murmur3", "native").foreach { hashAlgorithm =>
- val hashingTF = new HashingTF()
- .setInputCol("words")
- .setOutputCol("features")
- .setNumFeatures(n)
- .setHashAlgorithm(hashAlgorithm)
- val output = hashingTF.transform(df)
- val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
- require(attrGroup.numAttributes === Some(n))
- val features = output.select("features").first().getAs[Vector](0)
- // Assume perfect hash on "a", "b", "c", and "d".
- def idx: Any => Int = if (hashAlgorithm == "murmur3") {
- murmur3FeatureIdx(n)
- } else {
- nativeFeatureIdx(n)
- }
- val expected = Vectors.sparse(n,
- Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
- assert(features ~== expected absTol 1e-14)
- }
+ val hashingTF = new HashingTF()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setNumFeatures(n)
+ val output = hashingTF.transform(df)
+ val attrGroup = AttributeGroup.fromStructField(output.schema("features"))
+ require(attrGroup.numAttributes === Some(n))
+ val features = output.select("features").first().getAs[Vector](0)
+ // Assume perfect hash on "a", "b", "c", and "d".
+ def idx: Any => Int = murmur3FeatureIdx(n)
+ val expected = Vectors.sparse(n,
+ Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
+ assert(features ~== expected absTol 1e-14)
}
test("applying binary term freqs") {
@@ -86,10 +79,6 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
testDefaultReadWrite(t)
}
- private def nativeFeatureIdx(numFeatures: Int)(term: Any): Int = {
- Utils.nonNegativeMod(MLlibHashingTF.nativeHash(term), numFeatures)
- }
-
private def murmur3FeatureIdx(numFeatures: Int)(term: Any): Int = {
Utils.nonNegativeMod(MLlibHashingTF.murmur3Hash(term), numFeatures)
}