author | fwang1 <desperado.wf@gmail.com> | 2016-04-10 01:13:25 -0700 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2016-04-10 01:13:25 -0700 |
commit | f4344582ba28983bf3892d08e11236f090f5bf92 (patch) | |
tree | 8c2b41c51fa233534cee346c1f96020f452ffab1 /mllib/src/main/scala | |
parent | 22014e6fb919a35c31d852b7c2f5b7eb05751208 (diff) | |
download | spark-f4344582ba28983bf3892d08e11236f090f5bf92.tar.gz spark-f4344582ba28983bf3892d08e11236f090f5bf92.tar.bz2 spark-f4344582ba28983bf3892d08e11236f090f5bf92.zip |
[SPARK-14497][ML] Use top() instead of sortBy() to get top N frequent words as dict in CountVectorizer
## What changes were proposed in this pull request?
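Replace sortBy() with top() when computing the top N most frequent words used as the dictionary. This avoids sorting the full set of word counts when only the first N entries are needed.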
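To illustrate why top() can be cheaper than a full sort, here is a minimal sketch on plain Scala collections rather than RDDs. The object and method names (`TopNSketch`, `topBySort`, `topByHeap`) are hypothetical and not part of the patch; the bounded-heap approach is an assumption about how a top-N selection is typically done, not a copy of Spark's internals.

```scala
import scala.collection.mutable

// Hypothetical helper, not part of the patch: two ways to pick the n most frequent words.
object TopNSketch {
  // Full sort, then truncate: O(m log m) for m distinct words.
  def topBySort(wordCounts: Seq[(String, Long)], n: Int): Seq[String] =
    wordCounts.sortBy(-_._2).take(n).map(_._1)

  // Bounded heap: keep at most n entries, evicting the smallest count.
  // O(m log n) time and O(n) extra memory.
  def topByHeap(wordCounts: Seq[(String, Long)], n: Int): Seq[String] = {
    // PriorityQueue is a max-heap, so reverse the ordering to keep the
    // smallest retained count at the head.
    val smallestFirst: Ordering[(String, Long)] =
      Ordering.by[(String, Long), Long](_._2).reverse
    val heap = mutable.PriorityQueue.empty[(String, Long)](smallestFirst)
    for (wc <- wordCounts) {
      if (heap.size < n) {
        heap.enqueue(wc)
      } else if (n > 0 && wc._2 > heap.head._2) {
        heap.dequeue()   // drop the current smallest count
        heap.enqueue(wc) // keep the larger one
      }
    }
    // Only the n survivors need sorting for the final descending order.
    heap.toSeq.sortBy(-_._2).map(_._1)
  }
}
```

With distinct counts both functions return the same words in the same order; with tied counts the order among the ties may differ between them.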
## How was this patch tested?
Existing unit tests. Terms with the same TF are sorted in descending order. A test would fail if it hardcoded the dictionary order of terms with the same TF, like "c", "d"...
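The sensitivity to tie order can be sketched as follows. The patch orders by count only, so this example uses a hypothetical tie-breaker (`byCountThenTerm`, comparing by count and then by term) purely to show one way to make the order of equal-count terms explicit; `TieOrderSketch` and its sample data are assumptions for illustration.

```scala
// Hypothetical illustration, not part of the patch.
object TieOrderSketch {
  // "c" and "d" share the same count, so ordering by count alone
  // does not say which of them comes first.
  val counts: Seq[(String, Long)] =
    Seq(("a", 3L), ("c", 1L), ("d", 1L), ("b", 2L))

  // Assumed tie-breaker: compare by count first, then by the term itself.
  val byCountThenTerm: Ordering[(String, Long)] =
    Ordering.by[(String, Long), (Long, String)](wc => (wc._2, wc._1))

  // Descending by count; equal counts fall back to descending term order.
  def vocab(n: Int): Seq[String] =
    counts.sorted(byCountThenTerm.reverse).take(n).map(_._1)
}
```

A test that compares the tied terms as a set, rather than as a fixed sequence, would not depend on this choice at all.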
Author: fwang1 <desperado.wf@gmail.com>
Closes #12265 from lionelfeng/master.
Diffstat (limited to 'mllib/src/main/scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala | 14 |
1 files changed, 4 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index f1be971a6a..00abbbe29c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String)
       (word, count)
     }.cache()
     val fullVocabSize = wordCounts.count()
-    val vocab: Array[String] = {
-      val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
-        // Use all terms
-        wordCounts.collect().sortBy(-_._2)
-      } else {
-        // Sort terms to select vocab
-        wordCounts.sortBy(_._2, ascending = false).take(vocSize)
-      }
-      tmpSortedWC.map(_._1)
-    }
+
+    val vocab = wordCounts
+      .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
+      .map(_._1)
 
     require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
     copyValues(new CountVectorizerModel(uid, vocab).setParent(this))