diff options
author | Bryan Cutler <cutlerb@gmail.com> | 2016-03-29 12:30:30 +0200 |
---|---|---|
committer | Nick Pentreath <nick.pentreath@gmail.com> | 2016-03-29 12:30:30 +0200 |
commit | 425bcf6d6844732fe402af05472ad87b4e032cb6 (patch) | |
tree | fea575f26678d2ae8646b6e4eec913226509e7bf /mllib/src/main | |
parent | 83775bc78e183791f75a99cdfbcd68a67ca0d472 (diff) | |
download | spark-425bcf6d6844732fe402af05472ad87b4e032cb6.tar.gz spark-425bcf6d6844732fe402af05472ad87b4e032cb6.tar.bz2 spark-425bcf6d6844732fe402af05472ad87b4e032cb6.zip |
[SPARK-13963][ML] Adding binary toggle param to HashingTF
## What changes were proposed in this pull request?
Adding binary toggle parameter to ml.feature.HashingTF, as well as mllib.feature.HashingTF since the former wraps this functionality. This parameter, if true, will set non-zero valued term counts to 1 to transform term count features to binary values that are well suited for discrete probability models.
## How was this patch tested?
Added unit tests for ML and MLlib
Author: Bryan Cutler <cutlerb@gmail.com>
Closes #11832 from BryanCutler/binary-param-HashingTF-SPARK-13963.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala | 23 | ||||
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala | 15 |
2 files changed, 34 insertions, 4 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala index 61a78d73c4..0f7ae5a100 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala @@ -20,7 +20,7 @@ package org.apache.spark.ml.feature import org.apache.spark.annotation.{Experimental, Since} import org.apache.spark.ml.Transformer import org.apache.spark.ml.attribute.AttributeGroup -import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators} +import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidators} import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol} import org.apache.spark.ml.util._ import org.apache.spark.mllib.feature @@ -52,7 +52,18 @@ class HashingTF(override val uid: String) val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)", ParamValidators.gt(0)) - setDefault(numFeatures -> (1 << 18)) + /** + * Binary toggle to control term frequency counts. + * If true, all non-zero counts are set to 1. This is useful for discrete probabilistic + * models that model binary events rather than integer counts. + * (default = false) + * @group param + */ + val binary = new BooleanParam(this, "binary", "If true, all non zero counts are set to 1. " + + "This is useful for discrete probabilistic models that model binary events rather " + + "than integer counts") + + setDefault(numFeatures -> (1 << 18), binary -> false) /** @group getParam */ def getNumFeatures: Int = $(numFeatures) @@ -60,9 +71,15 @@ class HashingTF(override val uid: String) /** @group setParam */ def setNumFeatures(value: Int): this.type = set(numFeatures, value) + /** @group getParam */ + def getBinary: Boolean = $(binary) + + /** @group setParam */ + def setBinary(value: Boolean): this.type = set(binary, value) + override def transform(dataset: DataFrame): DataFrame = { val outputSchema = transformSchema(dataset.schema) - val hashingTF = new feature.HashingTF($(numFeatures)) + val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary)) val t = udf { terms: Seq[_] => hashingTF.transform(terms) } val metadata = outputSchema($(outputCol)).metadata dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata)) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala index c93ed64183..47c9e850a0 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala @@ -36,12 +36,24 @@ import org.apache.spark.util.Utils @Since("1.1.0") class HashingTF(val numFeatures: Int) extends Serializable { + private var binary = false + /** */ @Since("1.1.0") def this() = this(1 << 20) /** + * If true, term frequency vector will be binary such that non-zero term counts will be set to 1 + * (default: false) + */ + @Since("2.0.0") + def setBinary(value: Boolean): this.type = { + binary = value + this + } + + /** * Returns the index of the input term. */ @Since("1.1.0") @@ -53,9 +65,10 @@ class HashingTF(val numFeatures: Int) extends Serializable { @Since("1.1.0") def transform(document: Iterable[_]): Vector = { val termFrequencies = mutable.HashMap.empty[Int, Double] + val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0 document.foreach { term => val i = indexOf(term) - termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0) + termFrequencies.put(i, setTF(i)) } Vectors.sparse(numFeatures, termFrequencies.toSeq) } |