aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/test
diff options
context:
space:
mode:
authorYuhao Yang <hhbyyh@gmail.com>2016-03-17 11:21:11 +0200
committerNick Pentreath <nick.pentreath@gmail.com>2016-03-17 11:21:11 +0200
commit357d82d84d6372debd28da6ad0a2ee904957a7fe (patch)
tree1c0facd6a63b865b7ea06ff516f69bf479a26cba /mllib/src/test
parent204c9dec2c3876d20558ef5bda4dbd6edaf59643 (diff)
downloadspark-357d82d84d6372debd28da6ad0a2ee904957a7fe.tar.gz
spark-357d82d84d6372debd28da6ad0a2ee904957a7fe.tar.bz2
spark-357d82d84d6372debd28da6ad0a2ee904957a7fe.zip
[SPARK-13629][ML] Add binary toggle Param to CountVectorizer
## What changes were proposed in this pull request? It would be handy to add a binary toggle Param to CountVectorizer, as in the scikit-learn one: http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html If set, then all non-zero counts will be set to 1. ## How was this patch tested? unit tests Author: Yuhao Yang <hhbyyh@gmail.com> Closes #11536 from hhbyyh/cvToggle.
Diffstat (limited to 'mllib/src/test')
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala19
1 files changed, 18 insertions, 1 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index 9c99990173..04f165c5f1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -157,7 +157,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
(3, split("e e e e e"), Vectors.sparse(4, Seq())))
).toDF("id", "words", "expected")
- // minTF: count
+ // minTF: set frequency
val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
.setInputCol("words")
.setOutputCol("features")
@@ -168,6 +168,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
}
+ test("CountVectorizerModel with binary") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, split("a a a b b c"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
+ (1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))),
+ (2, split("a"), Vectors.sparse(4, Seq((0, 1.0))))
+ )).toDF("id", "words", "expected")
+
+ val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setBinary(true)
+ cv.transform(df).select("features", "expected").collect().foreach {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
+ }
+
test("CountVectorizer read/write") {
val t = new CountVectorizer()
.setInputCol("myInputCol")