aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala23
1 files changed, 9 insertions, 14 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index a3845d3977..5694b3890f 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -207,13 +207,12 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
def setMinTF(value: Double): this.type = set(minTF, value)
/**
- * Binary toggle to control the output vector values.
- * If True, all non zero counts are set to 1. This is useful for discrete probabilistic
- * models that model binary events rather than integer counts
- *
- * Default: false
- * @group param
- */
+ * Binary toggle to control the output vector values.
+ * If True, all nonzero counts (after minTF filter applied) are set to 1. This is useful for
+ * discrete probabilistic models that model binary events rather than integer counts.
+ * Default: false
+ * @group param
+ */
val binary: BooleanParam =
new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. " +
"This is useful for discrete probabilistic models that model binary events rather " +
@@ -248,17 +247,13 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
}
tokenCount += 1
}
- val effectiveMinTF = if (minTf >= 1.0) {
- minTf
- } else {
- tokenCount * minTf
- }
+ val effectiveMinTF = if (minTf >= 1.0) minTf else tokenCount * minTf
val effectiveCounts = if ($(binary)) {
termCounts.filter(_._2 >= effectiveMinTF).map(p => (p._1, 1.0)).toSeq
- }
- else {
+ } else {
termCounts.filter(_._2 >= effectiveMinTF).toSeq
}
+
Vectors.sparse(dictBr.value.size, effectiveCounts)
}
dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))