aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorfwang1 <desperado.wf@gmail.com>2016-04-10 01:13:25 -0700
committerXiangrui Meng <meng@databricks.com>2016-04-10 01:13:25 -0700
commitf4344582ba28983bf3892d08e11236f090f5bf92 (patch)
tree8c2b41c51fa233534cee346c1f96020f452ffab1 /mllib
parent22014e6fb919a35c31d852b7c2f5b7eb05751208 (diff)
downloadspark-f4344582ba28983bf3892d08e11236f090f5bf92.tar.gz
spark-f4344582ba28983bf3892d08e11236f090f5bf92.tar.bz2
spark-f4344582ba28983bf3892d08e11236f090f5bf92.zip
[SPARK-14497][ML] Use top instead of sortBy() to get top N frequent words as dict in ConutVectorizer
## What changes were proposed in this pull request? Replace sortBy() with top() to calculate the top N frequent words as dictionary. ## How was this patch tested? existing unit tests. The terms with same TF would be sorted in descending order. The test would fail if hardcode the terms with same TF the dictionary like "c", "d"... Author: fwang1 <desperado.wf@gmail.com> Closes #12265 from lionelfeng/master.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala14
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala7
2 files changed, 8 insertions, 13 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index f1be971a6a..00abbbe29c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String)
(word, count)
}.cache()
val fullVocabSize = wordCounts.count()
- val vocab: Array[String] = {
- val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
- // Use all terms
- wordCounts.collect().sortBy(-_._2)
- } else {
- // Sort terms to select vocab
- wordCounts.sortBy(_._2, ascending = false).take(vocSize)
- }
- tmpSortedWC.map(_._1)
- }
+
+ val vocab = wordCounts
+ .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
+ .map(_._1)
require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index ff0de06e27..7641e3b8cf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
(0, split("a b c d e"),
Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
(1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
- (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))),
- (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
+ (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))),
+ (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))),
+ (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
).toDF("id", "words", "expected")
val cv = new CountVectorizer()
.setInputCol("words")
.setOutputCol("features")
.fit(df)
- assert(cv.vocabulary === Array("a", "b", "c", "d", "e"))
+ assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
cv.transform(df).select("features", "expected").collect().foreach {
case Row(features: Vector, expected: Vector) =>