aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorBryan Cutler <cutlerb@gmail.com>2016-03-29 12:30:30 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2016-03-29 12:30:30 +0200
commit425bcf6d6844732fe402af05472ad87b4e032cb6 (patch)
treefea575f26678d2ae8646b6e4eec913226509e7bf /mllib/src/test
parent83775bc78e183791f75a99cdfbcd68a67ca0d472 (diff)
downloadspark-425bcf6d6844732fe402af05472ad87b4e032cb6.tar.gz
spark-425bcf6d6844732fe402af05472ad87b4e032cb6.tar.bz2
spark-425bcf6d6844732fe402af05472ad87b4e032cb6.zip
[SPARK-13963][ML] Adding binary toggle param to HashingTF
## What changes were proposed in this pull request? Adding binary toggle parameter to ml.feature.HashingTF, as well as mllib.feature.HashingTF since the former wraps this functionality. This parameter, if true, will set non-zero valued term counts to 1 to transform term count features to binary values that are well suited for discrete probability models. ## How was this patch tested? Added unit tests for ML and MLlib Author: Bryan Cutler <cutlerb@gmail.com> Closes #11832 from BryanCutler/binary-param-HashingTF-SPARK-13963.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala24
-rw-r--r--mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala12
2 files changed, 35 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 0dcd0f4946..addd733c20 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -46,12 +46,30 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
require(attrGroup.numAttributes === Some(n))
val features = output.select("features").first().getAs[Vector](0)
// Assume perfect hash on "a", "b", "c", and "d".
- def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n)
+ def idx: Any => Int = featureIdx(n)
val expected = Vectors.sparse(n,
Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
assert(features ~== expected absTol 1e-14)
}
+ test("applying binary term freqs") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, "a a b c c c".split(" ").toSeq)
+ )).toDF("id", "words")
+ val n = 100
+ val hashingTF = new HashingTF()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setNumFeatures(n)
+ .setBinary(true)
+ val output = hashingTF.transform(df)
+ val features = output.select("features").first().getAs[Vector](0)
+ def idx: Any => Int = featureIdx(n) // Assume perfect hash on input features
+ val expected = Vectors.sparse(n,
+ Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
+ assert(features ~== expected absTol 1e-14)
+ }
+
test("read/write") {
val t = new HashingTF()
.setInputCol("myInputCol")
@@ -59,4 +77,8 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
.setNumFeatures(10)
testDefaultReadWrite(t)
}
+
+ private def featureIdx(numFeatures: Int)(term: Any): Int = {
+ Utils.nonNegativeMod(term.##, numFeatures)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
index cf279c0233..6c07e3a5ce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -48,4 +49,15 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
val docs = sc.parallelize(localDocs, 2)
assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet)
}
+
+ test("applying binary term freqs") {
+ val hashingTF = new HashingTF(100).setBinary(true)
+ val doc = "a a b c c c".split(" ")
+ val n = hashingTF.numFeatures
+ val expected = Vectors.sparse(n, Seq(
+ (hashingTF.indexOf("a"), 1.0),
+ (hashingTF.indexOf("b"), 1.0),
+ (hashingTF.indexOf("c"), 1.0)))
+ assert(hashingTF.transform(doc) ~== expected absTol 1e-14)
+ }
}