author     Yuhao Yang <hhbyyh@gmail.com>               2015-08-18 11:00:09 -0700
committer  Joseph K. Bradley <joseph@databricks.com>   2015-08-18 11:00:09 -0700
commit     354f4582b637fa25d3892ec2b12869db50ed83c9 (patch)
tree       a0e4202868d5b34b59a5789cd60d0d0ccbaa74bf /mllib
parent     1968276af0f681fe51328b7dd795bd21724a5441 (diff)
[SPARK-9028] [ML] Add CountVectorizer as an estimator to generate CountVectorizerModel
jira: https://issues.apache.org/jira/browse/SPARK-9028

Add an estimator for CountVectorizerModel. The estimator extracts a vocabulary from document collections according to term frequency. I changed the meaning of minCount to be a filter across the corpus; this aligns with Word2Vec and the similar parameter in sklearn.

Author: Yuhao Yang <hhbyyh@gmail.com>
Author: Joseph K. Bradley <joseph@databricks.com>

Closes #7388 from hhbyyh/cvEstimator.
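For orientation, a minimal usage sketch of the estimator API this patch adds (the corpus, column names, and parameter values below are illustrative, not part of the change):

    import org.apache.spark.ml.feature.CountVectorizer

    // Illustrative input: a DataFrame with a column of tokenized documents.
    val df = sqlContext.createDataFrame(Seq(
      (0, Seq("a", "b", "c")),
      (1, Seq("a", "b", "b", "c", "a"))
    )).toDF("id", "words")

    // Fitting extracts the vocabulary; vocabSize caps its size and minDF
    // filters terms by the number (or fraction) of documents containing them.
    val cvModel = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setVocabSize(3)
      .setMinDF(2)
      .fit(df)

    // The fitted model counts only vocabulary terms per document.
    cvModel.transform(df).show()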
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala       | 235
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala  |  82
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala  | 167
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala  |  73
4 files changed, 402 insertions(+), 155 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
new file mode 100644
index 0000000000..49028e4b85
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -0,0 +1,235 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.feature
+
+import org.apache.spark.annotation.Experimental
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.ml.param._
+import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
+import org.apache.spark.ml.util.{Identifiable, SchemaUtils}
+import org.apache.spark.ml.{Estimator, Model}
+import org.apache.spark.mllib.linalg.{VectorUDT, Vectors}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.DataFrame
+import org.apache.spark.util.collection.OpenHashMap
+
+/**
+ * Params for [[CountVectorizer]] and [[CountVectorizerModel]].
+ */
+private[feature] trait CountVectorizerParams extends Params with HasInputCol with HasOutputCol {
+
+ /**
+ * Max size of the vocabulary.
+ * CountVectorizer will build a vocabulary that only considers the top
+ * vocabSize terms ordered by term frequency across the corpus.
+ *
+ * Default: 2^18^
+ * @group param
+ */
+ val vocabSize: IntParam =
+ new IntParam(this, "vocabSize", "max size of the vocabulary", ParamValidators.gt(0))
+
+ /** @group getParam */
+ def getVocabSize: Int = $(vocabSize)
+
+ /**
+ * Specifies the minimum number of different documents a term must appear in to be included
+ * in the vocabulary.
+ * If this is an integer >= 1, this specifies the number of documents the term must appear in;
+ * if this is a double in [0,1), then this specifies the fraction of documents.
+ *
+ * Default: 1
+ * @group param
+ */
+ val minDF: DoubleParam = new DoubleParam(this, "minDF", "Specifies the minimum number of" +
+ " different documents a term must appear in to be included in the vocabulary." +
+ " If this is an integer >= 1, this specifies the number of documents the term must" +
+ " appear in; if this is a double in [0,1), then this specifies the fraction of documents.",
+ ParamValidators.gtEq(0.0))
+
+ /** @group getParam */
+ def getMinDF: Double = $(minDF)
+
+ /** Validates and transforms the input schema. */
+ protected def validateAndTransformSchema(schema: StructType): StructType = {
+ SchemaUtils.checkColumnType(schema, $(inputCol), new ArrayType(StringType, true))
+ SchemaUtils.appendColumn(schema, $(outputCol), new VectorUDT)
+ }
+
+ /**
+ * Filter to ignore rare words in a document. For each document, terms with
+ * frequency/count less than the given threshold are ignored.
+ * If this is an integer >= 1, then this specifies a count (of times the term must appear
+ * in the document);
+ * if this is a double in [0,1), then this specifies a fraction (out of the document's token
+ * count).
+ *
+ * Note that the parameter is only used in transform of [[CountVectorizerModel]] and does not
+ * affect fitting.
+ *
+ * Default: 1
+ * @group param
+ */
+ val minTF: DoubleParam = new DoubleParam(this, "minTF", "Filter to ignore rare words in" +
+ " a document. For each document, terms with frequency/count less than the given threshold are" +
+ " ignored. If this is an integer >= 1, then this specifies a count (of times the term must" +
+ " appear in the document); if this is a double in [0,1), then this specifies a fraction (out" +
+ " of the document's token count). Note that the parameter is only used in transform of" +
+ " CountVectorizerModel and does not affect fitting.", ParamValidators.gtEq(0.0))
+
+ setDefault(minTF -> 1)
+
+ /** @group getParam */
+ def getMinTF: Double = $(minTF)
+}
+
+/**
+ * :: Experimental ::
+ * Extracts a vocabulary from document collections and generates a [[CountVectorizerModel]].
+ */
+@Experimental
+class CountVectorizer(override val uid: String)
+ extends Estimator[CountVectorizerModel] with CountVectorizerParams {
+
+ def this() = this(Identifiable.randomUID("cntVec"))
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /** @group setParam */
+ def setVocabSize(value: Int): this.type = set(vocabSize, value)
+
+ /** @group setParam */
+ def setMinDF(value: Double): this.type = set(minDF, value)
+
+ /** @group setParam */
+ def setMinTF(value: Double): this.type = set(minTF, value)
+
+ setDefault(vocabSize -> (1 << 18), minDF -> 1)
+
+ override def fit(dataset: DataFrame): CountVectorizerModel = {
+ transformSchema(dataset.schema, logging = true)
+ val vocSize = $(vocabSize)
+ val input = dataset.select($(inputCol)).map(_.getAs[Seq[String]](0))
+ val minDf = if ($(minDF) >= 1.0) {
+ $(minDF)
+ } else {
+ $(minDF) * input.cache().count()
+ }
+ val wordCounts: RDD[(String, Long)] = input.flatMap { case (tokens) =>
+ val wc = new OpenHashMap[String, Long]
+ tokens.foreach { w =>
+ wc.changeValue(w, 1L, _ + 1L)
+ }
+ wc.map { case (word, count) => (word, (count, 1)) }
+ }.reduceByKey { case ((wc1, df1), (wc2, df2)) =>
+ (wc1 + wc2, df1 + df2)
+ }.filter { case (word, (wc, df)) =>
+ df >= minDf
+ }.map { case (word, (count, dfCount)) =>
+ (word, count)
+ }.cache()
+ val fullVocabSize = wordCounts.count()
+ val vocab: Array[String] = {
+ val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
+ // Use all terms
+ wordCounts.collect().sortBy(-_._2)
+ } else {
+ // Sort terms to select vocab
+ wordCounts.sortBy(_._2, ascending = false).take(vocSize)
+ }
+ tmpSortedWC.map(_._1)
+ }
+
+ require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
+ copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ override def copy(extra: ParamMap): CountVectorizer = defaultCopy(extra)
+}
+
+/**
+ * :: Experimental ::
+ * Converts a text document to a sparse vector of token counts.
+ * @param vocabulary An array of terms. Only terms in the vocabulary will be counted.
+ */
+@Experimental
+class CountVectorizerModel(override val uid: String, val vocabulary: Array[String])
+ extends Model[CountVectorizerModel] with CountVectorizerParams {
+
+ def this(vocabulary: Array[String]) = {
+ this(Identifiable.randomUID("cntVecModel"), vocabulary)
+ set(vocabSize, vocabulary.length)
+ }
+
+ /** @group setParam */
+ def setInputCol(value: String): this.type = set(inputCol, value)
+
+ /** @group setParam */
+ def setOutputCol(value: String): this.type = set(outputCol, value)
+
+ /** @group setParam */
+ def setMinTF(value: Double): this.type = set(minTF, value)
+
+ /** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
+ private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
+
+ override def transform(dataset: DataFrame): DataFrame = {
+ if (broadcastDict.isEmpty) {
+ val dict = vocabulary.zipWithIndex.toMap
+ broadcastDict = Some(dataset.sqlContext.sparkContext.broadcast(dict))
+ }
+ val dictBr = broadcastDict.get
+ val minTf = $(minTF)
+ val vectorizer = udf { (document: Seq[String]) =>
+ val termCounts = new OpenHashMap[Int, Double]
+ var tokenCount = 0L
+ document.foreach { term =>
+ dictBr.value.get(term) match {
+ case Some(index) => termCounts.changeValue(index, 1.0, _ + 1.0)
+ case None => // ignore terms not in the vocabulary
+ }
+ tokenCount += 1
+ }
+ val effectiveMinTF = if (minTf >= 1.0) {
+ minTf
+ } else {
+ tokenCount * minTf
+ }
+ Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq)
+ }
+ dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
+ }
+
+ override def transformSchema(schema: StructType): StructType = {
+ validateAndTransformSchema(schema)
+ }
+
+ override def copy(extra: ParamMap): CountVectorizerModel = {
+ val copied = new CountVectorizerModel(uid, vocabulary).setParent(parent)
+ copyValues(copied, extra)
+ }
+}
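A note on the minDF handling in fit() above: a threshold >= 1 is an absolute document count, while a value in [0, 1) is scaled by the corpus size before filtering. A sketch of that arithmetic (the 4-document corpus is an assumption for illustration):

    // Mirrors the minDf computation in fit().
    val numDocs = 4L  // illustrative corpus size
    def effectiveMinDF(minDF: Double): Double =
      if (minDF >= 1.0) minDF else minDF * numDocs

    effectiveMinDF(3.0)   // 3.0: the term must appear in at least 3 documents
    effectiveMinDF(0.75)  // 3.0: the same threshold for a 4-document corpus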
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
deleted file mode 100644
index 6b77de89a0..0000000000
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizerModel.scala
+++ /dev/null
@@ -1,82 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.ml.feature
-
-import scala.collection.mutable
-
-import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.UnaryTransformer
-import org.apache.spark.ml.param.{ParamMap, ParamValidators, IntParam}
-import org.apache.spark.ml.util.Identifiable
-import org.apache.spark.mllib.linalg.{Vectors, VectorUDT, Vector}
-import org.apache.spark.sql.types.{StringType, ArrayType, DataType}
-
-/**
- * :: Experimental ::
- * Converts a text document to a sparse vector of token counts.
- * @param vocabulary An Array over terms. Only the terms in the vocabulary will be counted.
- */
-@Experimental
-class CountVectorizerModel (override val uid: String, val vocabulary: Array[String])
- extends UnaryTransformer[Seq[String], Vector, CountVectorizerModel] {
-
- def this(vocabulary: Array[String]) =
- this(Identifiable.randomUID("cntVec"), vocabulary)
-
- /**
- * Corpus-specific filter to ignore scarce words in a document. For each document, terms with
- * frequency (count) less than the given threshold are ignored.
- * Default: 1
- * @group param
- */
- val minTermFreq: IntParam = new IntParam(this, "minTermFreq",
- "minimum frequency (count) filter used to neglect scarce words (>= 1). For each document, " +
- "terms with frequency less than the given threshold are ignored.", ParamValidators.gtEq(1))
-
- /** @group setParam */
- def setMinTermFreq(value: Int): this.type = set(minTermFreq, value)
-
- /** @group getParam */
- def getMinTermFreq: Int = $(minTermFreq)
-
- setDefault(minTermFreq -> 1)
-
- override protected def createTransformFunc: Seq[String] => Vector = {
- val dict = vocabulary.zipWithIndex.toMap
- document =>
- val termCounts = mutable.HashMap.empty[Int, Double]
- document.foreach { term =>
- dict.get(term) match {
- case Some(index) => termCounts.put(index, termCounts.getOrElse(index, 0.0) + 1.0)
- case None => // ignore terms not in the vocabulary
- }
- }
- Vectors.sparse(dict.size, termCounts.filter(_._2 >= $(minTermFreq)).toSeq)
- }
-
- override protected def validateInputType(inputType: DataType): Unit = {
- require(inputType.sameType(ArrayType(StringType)),
- s"Input type must be ArrayType(StringType) but got $inputType.")
- }
-
- override protected def outputDataType: DataType = new VectorUDT()
-
- override def copy(extra: ParamMap): CountVectorizerModel = {
- val copied = new CountVectorizerModel(uid, vocabulary)
- copyValues(copied, extra)
- }
-}
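The transformer deleted above is superseded by the Model subclass in the new CountVectorizer.scala: the integer-only minTermFreq becomes the more general minTF. A hedged migration sketch (vocabulary and column names are illustrative):

    // Before: setMinTermFreq(3) accepted only an absolute count.
    // After: minTF is a DoubleParam; >= 1 is a count, and [0, 1) is a
    // fraction of the document's token count.
    val byCount = new CountVectorizerModel(Array("a", "b", "c"))
      .setInputCol("words")
      .setOutputCol("features")
      .setMinTF(3.0)   // keep terms occurring at least 3 times in a document

    val byFraction = new CountVectorizerModel(Array("a", "b", "c"))
      .setInputCol("words")
      .setOutputCol("features")
      .setMinTF(0.3)   // keep terms making up at least 30% of a document's tokens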
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
new file mode 100644
index 0000000000..e192fa4850
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -0,0 +1,167 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.spark.ml.feature
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.param.ParamsSuite
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
+import org.apache.spark.sql.Row
+
+class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
+
+ test("params") {
+ ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
+ }
+
+ private def split(s: String): Seq[String] = s.split("\\s+")
+
+ test("CountVectorizerModel common cases") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, split("a b c d"),
+ Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
+ (1, split("a b b c d a"),
+ Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))),
+ (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
+ (3, split(""), Vectors.sparse(4, Seq())), // empty string
+ (4, split("a notInDict d"),
+ Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary
+ )).toDF("id", "words", "expected")
+ val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+ .setInputCol("words")
+ .setOutputCol("features")
+ cv.transform(df).select("features", "expected").collect().foreach {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
+ }
+
+ test("CountVectorizer common cases") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, split("a b c d e"),
+ Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
+ (1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
+ (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))),
+ (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
+ ).toDF("id", "words", "expected")
+ val cv = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .fit(df)
+ assert(cv.vocabulary === Array("a", "b", "c", "d", "e"))
+
+ cv.transform(df).select("features", "expected").collect().foreach {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
+ }
+
+ test("CountVectorizer vocabSize and minDF") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, split("a b c d"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
+ (1, split("a b c"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
+ (2, split("a b"), Vectors.sparse(3, Seq((0, 1.0), (1, 1.0)))),
+ (3, split("a"), Vectors.sparse(3, Seq((0, 1.0)))))
+ ).toDF("id", "words", "expected")
+ val cvModel = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setVocabSize(3) // limit vocab size to 3
+ .fit(df)
+ assert(cvModel.vocabulary === Array("a", "b", "c"))
+
+ // minDF: ignore terms appearing in fewer than 3 documents
+ val cvModel2 = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setMinDF(3)
+ .fit(df)
+ assert(cvModel2.vocabulary === Array("a", "b"))
+
+ cvModel2.transform(df).select("features", "expected").collect().foreach {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
+
+ // minDF: ignore terms with document frequency below 0.75 (fewer than 3 of the 4 docs)
+ val cvModel3 = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setMinDF(3.0 / df.count())
+ .fit(df)
+ assert(cvModel3.vocabulary === Array("a", "b"))
+
+ cvModel3.transform(df).select("features", "expected").collect().foreach {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
+ }
+
+ test("CountVectorizer throws exception when vocab is empty") {
+ intercept[IllegalArgumentException] {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, split("a a b b c c")),
+ (1, split("aa bb cc")))
+ ).toDF("id", "words")
+ val cvModel = new CountVectorizer()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setVocabSize(3) // limit vocab size to 3
+ .setMinDF(3)
+ .fit(df)
+ }
+ }
+
+ test("CountVectorizerModel with minTF count") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
+ (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
+ (2, split("a"), Vectors.sparse(4, Seq())),
+ (3, split("e e e e e"), Vectors.sparse(4, Seq())))
+ ).toDF("id", "words", "expected")
+
+ // minTF: count
+ val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setMinTF(3)
+ cv.transform(df).select("features", "expected").collect().foreach {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
+ }
+
+ test("CountVectorizerModel with minTF freq") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, split("a a a b b c c c d "), Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
+ (1, split("c c c c c c"), Vectors.sparse(4, Seq((2, 6.0)))),
+ (2, split("a"), Vectors.sparse(4, Seq((0, 1.0)))),
+ (3, split("e e e e e"), Vectors.sparse(4, Seq())))
+ ).toDF("id", "words", "expected")
+
+ // minTF: fraction of the document's token count
+ val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setMinTF(0.3)
+ cv.transform(df).select("features", "expected").collect().foreach {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
+ }
+}
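For reference, the arithmetic behind the expected vectors in the minTF-as-fraction test above:

    // "a a a b b c c c d" has 9 tokens, so with minTF = 0.3 the effective
    // threshold computed in transform() is 9 * 0.3 = 2.7.
    // Counts: a -> 3, b -> 2, c -> 3, d -> 1; only "a" and "c" pass,
    // giving Vectors.sparse(4, Seq((0, 3.0), (2, 3.0))).
    val effectiveMinTF = 9 * 0.3  // 2.7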
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
deleted file mode 100644
index e90d9d4ef2..0000000000
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizorSuite.scala
+++ /dev/null
@@ -1,73 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-package org.apache.spark.ml.feature
-
-import org.apache.spark.SparkFunSuite
-import org.apache.spark.ml.param.ParamsSuite
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
-import org.apache.spark.mllib.util.MLlibTestSparkContext
-import org.apache.spark.mllib.util.TestingUtils._
-
-class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext {
-
- test("params") {
- ParamsSuite.checkParams(new CountVectorizerModel(Array("empty")))
- }
-
- test("CountVectorizerModel common cases") {
- val df = sqlContext.createDataFrame(Seq(
- (0, "a b c d".split(" ").toSeq,
- Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0)))),
- (1, "a b b c d a".split(" ").toSeq,
- Vectors.sparse(4, Seq((0, 2.0), (1, 2.0), (2, 1.0), (3, 1.0)))),
- (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq((0, 1.0)))),
- (3, "".split(" ").toSeq, Vectors.sparse(4, Seq())), // empty string
- (4, "a notInDict d".split(" ").toSeq,
- Vectors.sparse(4, Seq((0, 1.0), (3, 1.0)))) // with words not in vocabulary
- )).toDF("id", "words", "expected")
- val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
- .setInputCol("words")
- .setOutputCol("features")
- val output = cv.transform(df).collect()
- output.foreach { p =>
- val features = p.getAs[Vector]("features")
- val expected = p.getAs[Vector]("expected")
- assert(features ~== expected absTol 1e-14)
- }
- }
-
- test("CountVectorizerModel with minTermFreq") {
- val df = sqlContext.createDataFrame(Seq(
- (0, "a a a b b c c c d ".split(" ").toSeq, Vectors.sparse(4, Seq((0, 3.0), (2, 3.0)))),
- (1, "c c c c c c".split(" ").toSeq, Vectors.sparse(4, Seq((2, 6.0)))),
- (2, "a".split(" ").toSeq, Vectors.sparse(4, Seq())),
- (3, "e e e e e".split(" ").toSeq, Vectors.sparse(4, Seq())))
- ).toDF("id", "words", "expected")
- val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
- .setInputCol("words")
- .setOutputCol("features")
- .setMinTermFreq(3)
- val output = cv.transform(df).collect()
- output.foreach { p =>
- val features = p.getAs[Vector]("features")
- val expected = p.getAs[Vector]("expected")
- assert(features ~== expected absTol 1e-14)
- }
- }
-}
-
-