Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala  57
1 file changed, 27 insertions(+), 30 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index 5694b3890f..922670a41b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -26,7 +26,7 @@ import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.linalg.{Vectors, VectorUDT}
import org.apache.spark.rdd.RDD
-import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.{DataFrame, Dataset}
import org.apache.spark.sql.functions._
import org.apache.spark.sql.types._
import org.apache.spark.util.collection.OpenHashMap
@@ -100,6 +100,21 @@ private[feature] trait CountVectorizerParams extends Params with HasInputCol wit
/** @group getParam */
def getMinTF: Double = $(minTF)
+
+ /**
+ * Binary toggle to control the output vector values.
+ * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
+ * discrete probabilistic models that model binary events rather than integer counts.
+ * Default: false
+ * @group param
+ */
+ val binary: BooleanParam =
+ new BooleanParam(this, "binary", "If True, all non zero counts are set to 1.")
+
+ /** @group getParam */
+ def getBinary: Boolean = $(binary)
+
+ setDefault(binary -> false)
}
/**
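
As context for the binary param now shared through CountVectorizerParams, a minimal usage sketch (not part of the commit; it assumes a Spark 2.x SparkSession named spark and made-up sample data):

    // With binary = true, every term that survives the minTF filter is
    // emitted as 1.0 instead of its raw count.
    import org.apache.spark.ml.feature.CountVectorizer

    val df = spark.createDataFrame(Seq(
      (0, Seq("a", "b", "b", "c")),
      (1, Seq("a", "a", "a"))
    )).toDF("id", "words")

    val cvModel = new CountVectorizer()
      .setInputCol("words")
      .setOutputCol("features")
      .setBinary(true)   // toggle introduced on the estimator by this change
      .fit(df)

    cvModel.transform(df).show(truncate = false)

Setting the toggle on the estimator carries over to the fitted CountVectorizerModel via copyValues, which is why the duplicate param declaration on the model is removed further down in this diff.
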
@@ -127,9 +142,13 @@ class CountVectorizer(override val uid: String)
/** @group setParam */
def setMinTF(value: Double): this.type = set(minTF, value)
+ /** @group setParam */
+ def setBinary(value: Boolean): this.type = set(binary, value)
+
setDefault(vocabSize -> (1 << 18), minDF -> 1)
- override def fit(dataset: DataFrame): CountVectorizerModel = {
+ @Since("2.0.0")
+ override def fit(dataset: Dataset[_]): CountVectorizerModel = {
transformSchema(dataset.schema, logging = true)
val vocSize = $(vocabSize)
val input = dataset.select($(inputCol)).rdd.map(_.getAs[Seq[String]](0))
@@ -152,16 +171,10 @@ class CountVectorizer(override val uid: String)
(word, count)
}.cache()
val fullVocabSize = wordCounts.count()
- val vocab: Array[String] = {
- val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
- // Use all terms
- wordCounts.collect().sortBy(-_._2)
- } else {
- // Sort terms to select vocab
- wordCounts.sortBy(_._2, ascending = false).take(vocSize)
- }
- tmpSortedWC.map(_._1)
- }
+
+ val vocab = wordCounts
+ .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
+ .map(_._1)
require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
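
A standalone illustration of the simplified vocabulary selection above (illustration only; assumes an active SparkContext named sc and made-up counts):

    // RDD.top(k)(Ordering.by(_._2)) returns the k pairs with the largest
    // counts in descending order, covering both branches it replaces
    // (the collect-and-sort path and the sortBy/take path).
    val wordCounts = sc.parallelize(Seq(("a", 5L), ("b", 2L), ("c", 9L), ("d", 1L)))
    val vocSize = 2
    val fullVocabSize = wordCounts.count()

    val vocab = wordCounts
      .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
      .map(_._1)
    // vocab: Array[String] = Array(c, a)

top() performs a bounded-size selection per partition rather than a full shuffle sort, so it also avoids the cost of the previous sortBy(...).take(...) branch for large vocabularies.
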
@@ -206,30 +219,14 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
/** @group setParam */
def setMinTF(value: Double): this.type = set(minTF, value)
- /**
- * Binary toggle to control the output vector values.
- * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
- * discrete probabilistic models that model binary events rather than integer counts.
- * Default: false
- * @group param
- */
- val binary: BooleanParam =
- new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. " +
- "This is useful for discrete probabilistic models that model binary events rather " +
- "than integer counts")
-
- /** @group getParam */
- def getBinary: Boolean = $(binary)
-
/** @group setParam */
def setBinary(value: Boolean): this.type = set(binary, value)
- setDefault(binary -> false)
-
/** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
- override def transform(dataset: DataFrame): DataFrame = {
+ @Since("2.0.0")
+ override def transform(dataset: Dataset[_]): DataFrame = {
transformSchema(dataset.schema, logging = true)
if (broadcastDict.isEmpty) {
val dict = vocabulary.zipWithIndex.toMap
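
For reference, a sketch of the model side after this change (illustration only; assumes a Spark 2.x SparkSession named spark):

    import org.apache.spark.ml.feature.CountVectorizerModel

    // Build a model directly from a fixed vocabulary; with setBinary(true)
    // each matching term contributes 1.0 regardless of how often it occurs.
    val cvm = new CountVectorizerModel(Array("a", "b", "c"))
      .setInputCol("words")
      .setOutputCol("features")
      .setBinary(true)

    val df = spark.createDataFrame(Seq(
      (0, Seq("a", "b", "b", "c", "a"))
    )).toDF("id", "words")

    cvm.transform(df).show(truncate = false)
    // features: (3,[0,1,2],[1.0,1.0,1.0]) instead of (3,[0,1,2],[2.0,2.0,1.0])
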