diff options
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala | 12 |
1 files changed, 12 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala index cf279c0233..6c07e3a5ce 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala @@ -20,6 +20,7 @@ package org.apache.spark.mllib.feature import org.apache.spark.SparkFunSuite import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.util.MLlibTestSparkContext +import org.apache.spark.mllib.util.TestingUtils._ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -48,4 +49,15 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext { val docs = sc.parallelize(localDocs, 2) assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet) } + + test("applying binary term freqs") { + val hashingTF = new HashingTF(100).setBinary(true) + val doc = "a a b c c c".split(" ") + val n = hashingTF.numFeatures + val expected = Vectors.sparse(n, Seq( + (hashingTF.indexOf("a"), 1.0), + (hashingTF.indexOf("b"), 1.0), + (hashingTF.indexOf("c"), 1.0))) + assert(hashingTF.transform(doc) ~== expected absTol 1e-14) + } } |