diff options
Diffstat (limited to 'mllib/src/test')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala | 19 |
1 file changed, 18 insertions, 1 deletion
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index 9c99990173..04f165c5f1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -157,7 +157,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
       (3, split("e e e e e"), Vectors.sparse(4, Seq())))
     ).toDF("id", "words", "expected")
 
-    // minTF: count
+    // minTF: set frequency
     val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
       .setInputCol("words")
       .setOutputCol("features")
@@ -168,6 +168,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
     }
   }
 
+  test("CountVectorizerModel with binary") {
+    val df = sqlContext.createDataFrame(Seq(
+      (0, split("a a a b b c"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
+      (1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))),
+      (2, split("a"), Vectors.sparse(4, Seq((0, 1.0))))
+    )).toDF("id", "words", "expected")
+
+    val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+      .setInputCol("words")
+      .setOutputCol("features")
+      .setBinary(true)
+    cv.transform(df).select("features", "expected").collect().foreach {
+      case Row(features: Vector, expected: Vector) =>
+        assert(features ~== expected absTol 1e-14)
+    }
+  }
+
   test("CountVectorizer read/write") {
     val t = new CountVectorizer()
       .setInputCol("myInputCol")