aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala24
1 files changed, 23 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 0dcd0f4946..addd733c20 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -46,12 +46,30 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
require(attrGroup.numAttributes === Some(n))
val features = output.select("features").first().getAs[Vector](0)
// Assume perfect hash on "a", "b", "c", and "d".
- def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n)
+ def idx: Any => Int = featureIdx(n)
val expected = Vectors.sparse(n,
Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
assert(features ~== expected absTol 1e-14)
}
+ test("applying binary term freqs") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, "a a b c c c".split(" ").toSeq)
+ )).toDF("id", "words")
+ val n = 100
+ val hashingTF = new HashingTF()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setNumFeatures(n)
+ .setBinary(true)
+ val output = hashingTF.transform(df)
+ val features = output.select("features").first().getAs[Vector](0)
+ def idx: Any => Int = featureIdx(n) // Assume perfect hash on input features
+ val expected = Vectors.sparse(n,
+ Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
+ assert(features ~== expected absTol 1e-14)
+ }
+
test("read/write") {
val t = new HashingTF()
.setInputCol("myInputCol")
@@ -59,4 +77,8 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
.setNumFeatures(10)
testDefaultReadWrite(t)
}
+
+ private def featureIdx(numFeatures: Int)(term: Any): Int = {
+ Utils.nonNegativeMod(term.##, numFeatures)
+ }
}