diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala | 24 |
1 files changed, 23 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala index 0dcd0f4946..addd733c20 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala @@ -46,12 +46,30 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau require(attrGroup.numAttributes === Some(n)) val features = output.select("features").first().getAs[Vector](0) // Assume perfect hash on "a", "b", "c", and "d". - def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n) + def idx: Any => Int = featureIdx(n) val expected = Vectors.sparse(n, Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0))) assert(features ~== expected absTol 1e-14) } + test("applying binary term freqs") { + val df = sqlContext.createDataFrame(Seq( + (0, "a a b c c c".split(" ").toSeq) + )).toDF("id", "words") + val n = 100 + val hashingTF = new HashingTF() + .setInputCol("words") + .setOutputCol("features") + .setNumFeatures(n) + .setBinary(true) + val output = hashingTF.transform(df) + val features = output.select("features").first().getAs[Vector](0) + def idx: Any => Int = featureIdx(n) // Assume perfect hash on input features + val expected = Vectors.sparse(n, + Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0))) + assert(features ~== expected absTol 1e-14) + } + test("read/write") { val t = new HashingTF() .setInputCol("myInputCol") @@ -59,4 +77,8 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau .setNumFeatures(10) testDefaultReadWrite(t) } + + private def featureIdx(numFeatures: Int)(term: Any): Int = { + Utils.nonNegativeMod(term.##, numFeatures) + } } |