about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
author Yuhao Yang <hhbyyh@gmail.com> 2016-03-17 11:21:11 +0200
committer Nick Pentreath <nick.pentreath@gmail.com> 2016-03-17 11:21:11 +0200
commit 357d82d84d6372debd28da6ad0a2ee904957a7fe (patch)
tree 1c0facd6a63b865b7ea06ff516f69bf479a26cba /mllib
parent 204c9dec2c3876d20558ef5bda4dbd6edaf59643 (diff)
download spark-357d82d84d6372debd28da6ad0a2ee904957a7fe.tar.gz
spark-357d82d84d6372debd28da6ad0a2ee904957a7fe.tar.bz2
spark-357d82d84d6372debd28da6ad0a2ee904957a7fe.zip
[SPARK-13629][ML] Add binary toggle Param to CountVectorizer
## What changes were proposed in this pull request?

It would be handy to add a binary toggle Param to CountVectorizer, as in the scikit-learn one: http://scikit-learn.org/stable/modules/generated/sklearn.feature_extraction.text.CountVectorizer.html

If set, then all non-zero counts will be set to 1.

## How was this patch tested?

Unit tests.

Author: Yuhao Yang <hhbyyh@gmail.com>

Closes #11536 from hhbyyh/cvToggle.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala | 29
-rw-r--r-- mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala | 19
2 files changed, 46 insertions, 2 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index f7d08b39a9..a3845d3977 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -206,6 +206,27 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
/** @group setParam */
def setMinTF(value: Double): this.type = set(minTF, value)
+ /**
+ * Binary toggle to control the output vector values.
+ * If True, all non zero counts are set to 1. This is useful for discrete probabilistic
+ * models that model binary events rather than integer counts
+ *
+ * Default: false
+ * @group param
+ */
+ val binary: BooleanParam =
+ new BooleanParam(this, "binary", "If True, all non zero counts are set to 1. " +
+ "This is useful for discrete probabilistic models that model binary events rather " +
+ "than integer counts")
+
+ /** @group getParam */
+ def getBinary: Boolean = $(binary)
+
+ /** @group setParam */
+ def setBinary(value: Boolean): this.type = set(binary, value)
+
+ setDefault(binary -> false)
+
/** Dictionary created from [[vocabulary]] and its indices, broadcast once for [[transform()]] */
private var broadcastDict: Option[Broadcast[Map[String, Int]]] = None
@@ -232,7 +253,13 @@ class CountVectorizerModel(override val uid: String, val vocabulary: Array[Strin
} else {
tokenCount * minTf
}
- Vectors.sparse(dictBr.value.size, termCounts.filter(_._2 >= effectiveMinTF).toSeq)
+ val effectiveCounts = if ($(binary)) {
+ termCounts.filter(_._2 >= effectiveMinTF).map(p => (p._1, 1.0)).toSeq
+ }
+ else {
+ termCounts.filter(_._2 >= effectiveMinTF).toSeq
+ }
+ Vectors.sparse(dictBr.value.size, effectiveCounts)
}
dataset.withColumn($(outputCol), vectorizer(col($(inputCol))))
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index 9c99990173..04f165c5f1 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -157,7 +157,7 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
(3, split("e e e e e"), Vectors.sparse(4, Seq())))
).toDF("id", "words", "expected")
- // minTF: count
+ // minTF: set frequency
val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
.setInputCol("words")
.setOutputCol("features")
@@ -168,6 +168,23 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
}
}
+ test("CountVectorizerModel with binary") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, split("a a a b b c"), Vectors.sparse(4, Seq((0, 1.0), (1, 1.0), (2, 1.0)))),
+ (1, split("c c c"), Vectors.sparse(4, Seq((2, 1.0)))),
+ (2, split("a"), Vectors.sparse(4, Seq((0, 1.0))))
+ )).toDF("id", "words", "expected")
+
+ val cv = new CountVectorizerModel(Array("a", "b", "c", "d"))
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setBinary(true)
+ cv.transform(df).select("features", "expected").collect().foreach {
+ case Row(features: Vector, expected: Vector) =>
+ assert(features ~== expected absTol 1e-14)
+ }
+ }
+
test("CountVectorizer read/write") {
val t = new CountVectorizer()
.setInputCol("myInputCol")