aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala')
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala12
1 files changed, 12 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
index cf279c0233..6c07e3a5ce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -48,4 +49,15 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
val docs = sc.parallelize(localDocs, 2)
assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet)
}
+
+ test("applying binary term freqs") {
+ val hashingTF = new HashingTF(100).setBinary(true)
+ val doc = "a a b c c c".split(" ")
+ val n = hashingTF.numFeatures
+ val expected = Vectors.sparse(n, Seq(
+ (hashingTF.indexOf("a"), 1.0),
+ (hashingTF.indexOf("b"), 1.0),
+ (hashingTF.indexOf("c"), 1.0)))
+ assert(hashingTF.transform(doc) ~== expected absTol 1e-14)
+ }
}