Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala         37
-rw-r--r--  mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java 20
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala     36
3 files changed, 88 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
index d40d5553c1..720bb70b08 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/feature/IDF.scala
@@ -30,9 +30,18 @@ import org.apache.spark.rdd.RDD
* Inverse document frequency (IDF).
* The standard formulation is used: `idf = log((m + 1) / (d(t) + 1))`, where `m` is the total
* number of documents and `d(t)` is the number of documents that contain term `t`.
+ *
+ * This implementation supports filtering out terms which do not appear in a minimum number
+ * of documents (controlled by the variable `minDocFreq`). For terms that appear in fewer
+ * than `minDocFreq` documents, the IDF is set to 0, resulting in TF-IDF values of 0.
+ *
+ * @param minDocFreq minimum number of documents in which a term
+ *                   must appear to be retained
*/
@Experimental
-class IDF {
+class IDF(val minDocFreq: Int) {
+
+ def this() = this(0)
// TODO: Allow different IDF formulations.
@@ -41,7 +50,8 @@ class IDF {
* @param dataset an RDD of term frequency vectors
*/
def fit(dataset: RDD[Vector]): IDFModel = {
- val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator)(
+ val idf = dataset.treeAggregate(new IDF.DocumentFrequencyAggregator(
+ minDocFreq = minDocFreq))(
seqOp = (df, v) => df.add(v),
combOp = (df1, df2) => df1.merge(df2)
).idf()
@@ -60,13 +70,16 @@ class IDF {
private object IDF {
/** Document frequency aggregator. */
- class DocumentFrequencyAggregator extends Serializable {
+ class DocumentFrequencyAggregator(val minDocFreq: Int) extends Serializable {
/** number of documents */
private var m = 0L
/** document frequency vector */
private var df: BDV[Long] = _
+
+ def this() = this(0)
+
/** Adds a new document. */
def add(doc: Vector): this.type = {
if (isEmpty) {
@@ -123,7 +136,18 @@ private object IDF {
val inv = new Array[Double](n)
var j = 0
while (j < n) {
- inv(j) = math.log((m + 1.0)/ (df(j) + 1.0))
+ /*
+ * If the term is not present in the minimum
+ * number of documents, set IDF to 0. This
+ * will cause multiplication in IDFModel to
+ * set TF-IDF to 0.
+ *
+ * Since arrays are initialized to 0 by default,
+ * we just omit changing those entries.
+ */
+ if (df(j) >= minDocFreq) {
+ inv(j) = math.log((m + 1.0) / (df(j) + 1.0))
+ }
j += 1
}
Vectors.dense(inv)
@@ -140,6 +164,11 @@ class IDFModel private[mllib] (val idf: Vector) extends Serializable {
/**
* Transforms term frequency (TF) vectors to TF-IDF vectors.
+ *
+ * If `minDocFreq` was set for the IDF calculation,
+ * the terms which occur in fewer than `minDocFreq`
+ * documents will have an entry of 0.
+ *
* @param dataset an RDD of term frequency vectors
* @return an RDD of TF-IDF vectors
*/
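
For context, a minimal end-to-end sketch of how the new parameter is intended to be used from a Spark driver. This is illustrative only: the corpus path is hypothetical and an existing SparkContext `sc` is assumed; HashingTF, IDF.fit, and IDFModel.transform are the existing MLlib APIs.

    import org.apache.spark.mllib.feature.{HashingTF, IDF}
    import org.apache.spark.mllib.linalg.Vector
    import org.apache.spark.rdd.RDD

    // Hypothetical corpus path; assumes an existing SparkContext `sc`
    // and a whitespace-tokenized corpus, one document per line.
    val documents: RDD[Seq[String]] =
      sc.textFile("data/corpus.txt").map(_.split(" ").toSeq)
    val tf: RDD[Vector] = new HashingTF().transform(documents)
    tf.cache()
    // Terms appearing in fewer than 2 documents get an IDF of 0, so the
    // final transform zeroes out their TF-IDF entries.
    val tfidf: RDD[Vector] = new IDF(minDocFreq = 2).fit(tf).transform(tf)
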
diff --git a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
index e8d99f4ae4..064263e02c 100644
--- a/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
+++ b/mllib/src/test/java/org/apache/spark/mllib/feature/JavaTfIdfSuite.java
@@ -63,4 +63,24 @@ public class JavaTfIdfSuite implements Serializable {
Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
}
}
+
+ @Test
+ public void tfIdfMinimumDocumentFrequency() {
+ // The tests are to check Java compatibility.
+ HashingTF tf = new HashingTF();
+ JavaRDD<ArrayList<String>> documents = sc.parallelize(Lists.newArrayList(
+ Lists.newArrayList("this is a sentence".split(" ")),
+ Lists.newArrayList("this is another sentence".split(" ")),
+ Lists.newArrayList("this is still a sentence".split(" "))), 2);
+ JavaRDD<Vector> termFreqs = tf.transform(documents);
+ termFreqs.collect();
+ IDF idf = new IDF(2);
+ JavaRDD<Vector> tfIdfs = idf.fit(termFreqs).transform(termFreqs);
+ List<Vector> localTfIdfs = tfIdfs.collect();
+ int indexOfThis = tf.indexOf("this");
+ for (Vector v: localTfIdfs) {
+ Assert.assertEquals(0.0, v.apply(indexOfThis), 1e-15);
+ }
+ }
+
}
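
A note on why the new Java test's assertion holds: "this" occurs in all m = 3 documents of the corpus, so its IDF is log((3 + 1) / (3 + 1)) = log(1) = 0, and every TF-IDF entry for "this" is 0 whether or not minDocFreq filtering kicks in. A quick Scala check of that arithmetic:

    val m = 3
    val dfThis = 3 // "this" appears in every document
    assert(math.log((m + 1.0) / (dfThis + 1.0)) == 0.0) // log(1) = 0
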
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
index 53d9c0c640..43974f84e3 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/feature/IDFSuite.scala
@@ -38,7 +38,7 @@ class IDFSuite extends FunSuite with LocalSparkContext {
val idf = new IDF
val model = idf.fit(termFrequencies)
val expected = Vectors.dense(Array(0, 3, 1, 2).map { x =>
- math.log((m.toDouble + 1.0) / (x + 1.0))
+ math.log((m + 1.0) / (x + 1.0))
})
assert(model.idf ~== expected absTol 1e-12)
val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
@@ -54,4 +54,38 @@ class IDFSuite extends FunSuite with LocalSparkContext {
assert(tfidf2.indices === Array(1))
assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
}
+
+ test("idf minimum document frequency filtering") {
+ val n = 4
+ val localTermFrequencies = Seq(
+ Vectors.sparse(n, Array(1, 3), Array(1.0, 2.0)),
+ Vectors.dense(0.0, 1.0, 2.0, 3.0),
+ Vectors.sparse(n, Array(1), Array(1.0))
+ )
+ val m = localTermFrequencies.size
+ val termFrequencies = sc.parallelize(localTermFrequencies, 2)
+ val idf = new IDF(minDocFreq = 1)
+ val model = idf.fit(termFrequencies)
+ val expected = Vectors.dense(Array(0, 3, 1, 2).map { x =>
+ if (x > 0) {
+ math.log((m + 1.0) / (x + 1.0))
+ } else {
+ 0.0
+ }
+ })
+ assert(model.idf ~== expected absTol 1e-12)
+ val tfidf = model.transform(termFrequencies).cache().zipWithIndex().map(_.swap).collectAsMap()
+ assert(tfidf.size === 3)
+ val tfidf0 = tfidf(0L).asInstanceOf[SparseVector]
+ assert(tfidf0.indices === Array(1, 3))
+ assert(Vectors.dense(tfidf0.values) ~==
+ Vectors.dense(1.0 * expected(1), 2.0 * expected(3)) absTol 1e-12)
+ val tfidf1 = tfidf(1L).asInstanceOf[DenseVector]
+ assert(Vectors.dense(tfidf1.values) ~==
+ Vectors.dense(0.0, 1.0 * expected(1), 2.0 * expected(2), 3.0 * expected(3)) absTol 1e-12)
+ val tfidf2 = tfidf(2L).asInstanceOf[SparseVector]
+ assert(tfidf2.indices === Array(1))
+ assert(tfidf2.values(0) ~== (1.0 * expected(1)) absTol 1e-12)
+ }
+
}
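
As a companion to the new Scala test, a standalone sketch of the expected-IDF computation it encodes, using the test's own numbers (m = 3 documents, per-term document frequencies (0, 3, 1, 2), minDocFreq = 1):

    val m = 3
    val minDocFreq = 1
    def idf(df: Int): Double =
      if (df >= minDocFreq) math.log((m + 1.0) / (df + 1.0)) else 0.0

    val expected = Array(0, 3, 1, 2).map(idf)
    // df = 0 falls below minDocFreq and is zeroed by the filter;
    // df = 3 gives log(4/4) = 0 anyway, since the term is in every document.
    assert(expected(0) == 0.0)
    assert(expected(1) == 0.0)
    assert(expected(2) == math.log(4.0 / 2.0))
    assert(expected(3) == math.log(4.0 / 3.0))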