aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala14
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala7
2 files changed, 8 insertions, 13 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
index f1be971a6a..00abbbe29c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/feature/CountVectorizer.scala
@@ -170,16 +170,10 @@ class CountVectorizer(override val uid: String)
(word, count)
}.cache()
val fullVocabSize = wordCounts.count()
- val vocab: Array[String] = {
- val tmpSortedWC: Array[(String, Long)] = if (fullVocabSize <= vocSize) {
- // Use all terms
- wordCounts.collect().sortBy(-_._2)
- } else {
- // Sort terms to select vocab
- wordCounts.sortBy(_._2, ascending = false).take(vocSize)
- }
- tmpSortedWC.map(_._1)
- }
+
+ val vocab = wordCounts
+ .top(math.min(fullVocabSize, vocSize).toInt)(Ordering.by(_._2))
+ .map(_._1)
require(vocab.length > 0, "The vocabulary size should be > 0. Lower minDF as necessary.")
copyValues(new CountVectorizerModel(uid, vocab).setParent(this))
diff --git a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
index ff0de06e27..7641e3b8cf 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/feature/CountVectorizerSuite.scala
@@ -59,14 +59,15 @@ class CountVectorizerSuite extends SparkFunSuite with MLlibTestSparkContext
(0, split("a b c d e"),
Vectors.sparse(5, Seq((0, 1.0), (1, 1.0), (2, 1.0), (3, 1.0), (4, 1.0)))),
(1, split("a a a a a a"), Vectors.sparse(5, Seq((0, 6.0)))),
- (2, split("c"), Vectors.sparse(5, Seq((2, 1.0)))),
- (3, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
+ (2, split("c c"), Vectors.sparse(5, Seq((2, 2.0)))),
+ (3, split("d"), Vectors.sparse(5, Seq((3, 1.0)))),
+ (4, split("b b b b b"), Vectors.sparse(5, Seq((1, 5.0)))))
).toDF("id", "words", "expected")
val cv = new CountVectorizer()
.setInputCol("words")
.setOutputCol("features")
.fit(df)
- assert(cv.vocabulary === Array("a", "b", "c", "d", "e"))
+ assert(cv.vocabulary.toSet === Set("a", "b", "c", "d", "e"))
cv.transform(df).select("features", "expected").collect().foreach {
case Row(features: Vector, expected: Vector) =>