about summary refs log tree commit diff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala        | 23
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala     | 15
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala   | 24
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala | 12
4 files changed, 69 insertions(+), 5 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
index 61a78d73c4..0f7ae5a100 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/HashingTF.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.feature
import org.apache.spark.annotation.{Experimental, Since}
import org.apache.spark.ml.Transformer
import org.apache.spark.ml.attribute.AttributeGroup
-import org.apache.spark.ml.param.{IntParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.{BooleanParam, IntParam, ParamMap, ParamValidators}
import org.apache.spark.ml.param.shared.{HasInputCol, HasOutputCol}
import org.apache.spark.ml.util._
import org.apache.spark.mllib.feature
@@ -52,7 +52,18 @@ class HashingTF(override val uid: String)
val numFeatures = new IntParam(this, "numFeatures", "number of features (> 0)",
ParamValidators.gt(0))
- setDefault(numFeatures -> (1 << 18))
+ /**
+ * Binary toggle to control term frequency counts.
+ * If true, all non-zero counts are set to 1. This is useful for discrete probabilistic
+ * models that model binary events rather than integer counts.
+ * (default = false)
+ * @group param
+ */
+ val binary = new BooleanParam(this, "binary", "If true, all non zero counts are set to 1. " +
+ "This is useful for discrete probabilistic models that model binary events rather " +
+ "than integer counts")
+
+ setDefault(numFeatures -> (1 << 18), binary -> false)
/** @group getParam */
def getNumFeatures: Int = $(numFeatures)
@@ -60,9 +71,15 @@ class HashingTF(override val uid: String)
/** @group setParam */
def setNumFeatures(value: Int): this.type = set(numFeatures, value)
+ /** @group getParam */
+ def getBinary: Boolean = $(binary)
+
+ /** @group setParam */
+ def setBinary(value: Boolean): this.type = set(binary, value)
+
override def transform(dataset: DataFrame): DataFrame = {
val outputSchema = transformSchema(dataset.schema)
- val hashingTF = new feature.HashingTF($(numFeatures))
+ val hashingTF = new feature.HashingTF($(numFeatures)).setBinary($(binary))
val t = udf { terms: Seq[_] => hashingTF.transform(terms) }
val metadata = outputSchema($(outputCol)).metadata
dataset.select(col("*"), t(col($(inputCol))).as($(outputCol), metadata))
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
index c93ed64183..47c9e850a0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/HashingTF.scala
@@ -36,12 +36,24 @@ import org.apache.spark.util.Utils
@Since("1.1.0")
class HashingTF(val numFeatures: Int) extends Serializable {
+ private var binary = false
+
/**
*/
@Since("1.1.0")
def this() = this(1 << 20)
/**
+ * If true, term frequency vector will be binary such that non-zero term counts will be set to 1
+ * (default: false)
+ */
+ @Since("2.0.0")
+ def setBinary(value: Boolean): this.type = {
+ binary = value
+ this
+ }
+
+ /**
* Returns the index of the input term.
*/
@Since("1.1.0")
@@ -53,9 +65,10 @@ class HashingTF(val numFeatures: Int) extends Serializable {
@Since("1.1.0")
def transform(document: Iterable[_]): Vector = {
val termFrequencies = mutable.HashMap.empty[Int, Double]
+ val setTF = if (binary) (i: Int) => 1.0 else (i: Int) => termFrequencies.getOrElse(i, 0.0) + 1.0
document.foreach { term =>
val i = indexOf(term)
- termFrequencies.put(i, termFrequencies.getOrElse(i, 0.0) + 1.0)
+ termFrequencies.put(i, setTF(i))
}
Vectors.sparse(numFeatures, termFrequencies.toSeq)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
index 0dcd0f4946..addd733c20 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/HashingTFSuite.scala
@@ -46,12 +46,30 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
require(attrGroup.numAttributes === Some(n))
val features = output.select("features").first().getAs[Vector](0)
// Assume perfect hash on "a", "b", "c", and "d".
- def idx(any: Any): Int = Utils.nonNegativeMod(any.##, n)
+ def idx: Any => Int = featureIdx(n)
val expected = Vectors.sparse(n,
Seq((idx("a"), 2.0), (idx("b"), 2.0), (idx("c"), 1.0), (idx("d"), 1.0)))
assert(features ~== expected absTol 1e-14)
}
+ test("applying binary term freqs") {
+ val df = sqlContext.createDataFrame(Seq(
+ (0, "a a b c c c".split(" ").toSeq)
+ )).toDF("id", "words")
+ val n = 100
+ val hashingTF = new HashingTF()
+ .setInputCol("words")
+ .setOutputCol("features")
+ .setNumFeatures(n)
+ .setBinary(true)
+ val output = hashingTF.transform(df)
+ val features = output.select("features").first().getAs[Vector](0)
+ def idx: Any => Int = featureIdx(n) // Assume perfect hash on input features
+ val expected = Vectors.sparse(n,
+ Seq((idx("a"), 1.0), (idx("b"), 1.0), (idx("c"), 1.0)))
+ assert(features ~== expected absTol 1e-14)
+ }
+
test("read/write") {
val t = new HashingTF()
.setInputCol("myInputCol")
@@ -59,4 +77,8 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext with Defau
.setNumFeatures(10)
testDefaultReadWrite(t)
}
+
+ private def featureIdx(numFeatures: Int)(term: Any): Int = {
+ Utils.nonNegativeMod(term.##, numFeatures)
+ }
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
index cf279c0233..6c07e3a5ce 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/HashingTFSuite.scala
@@ -20,6 +20,7 @@ package org.apache.spark.mllib.feature
import org.apache.spark.SparkFunSuite
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.mllib.util.TestingUtils._
class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -48,4 +49,15 @@ class HashingTFSuite extends SparkFunSuite with MLlibTestSparkContext {
val docs = sc.parallelize(localDocs, 2)
assert(hashingTF.transform(docs).collect().toSet === localDocs.map(hashingTF.transform).toSet)
}
+
+ test("applying binary term freqs") {
+ val hashingTF = new HashingTF(100).setBinary(true)
+ val doc = "a a b c c c".split(" ")
+ val n = hashingTF.numFeatures
+ val expected = Vectors.sparse(n, Seq(
+ (hashingTF.indexOf("a"), 1.0),
+ (hashingTF.indexOf("b"), 1.0),
+ (hashingTF.indexOf("c"), 1.0)))
+ assert(hashingTF.transform(doc) ~== expected absTol 1e-14)
+ }
}