author    Yuhao Yang <hhbyyh@gmail.com>    2015-12-08 10:29:51 -0800
committer Joseph K. Bradley <joseph@databricks.com>    2015-12-08 10:29:51 -0800
commit    872a2ee281d84f40a786f765bf772cdb06e8c956 (patch)
tree      6b094a180a299ff23fe487eae2b8823378cbee44 /examples
parent    5d96a710a5ed543ec81e383620fc3b2a808b26a1 (diff)
[SPARK-10393] use ML pipeline in LDA example
jira: https://issues.apache.org/jira/browse/SPARK-10393

Since the logic of the text-processing part has been moved to ML estimators/transformers, replace the related code in the LDA example with the ML pipeline.

Author: Yuhao Yang <hhbyyh@gmail.com>
Author: yuhaoyang <yuhao@zhanglipings-iMac.local>

Closes #8551 from hhbyyh/ldaExUpdate.
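For quick reference, the pipeline-based preprocessing this patch introduces boils down to the following self-contained sketch. The input path, app name, and vocabulary size of 10000 are illustrative assumptions; the diff below is the authoritative version.

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.{SparkConf, SparkContext}

object LDAPreprocessSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("LDAPreprocessSketch"))
    val sqlContext = SQLContext.getOrCreate(sc)
    import sqlContext.implicits._

    // One document per line; "docs.txt" is a placeholder path.
    val df = sc.textFile("docs.txt").toDF("docs")

    // Tokenize, drop English stop words, and build term-count vectors.
    val tokenizer = new RegexTokenizer().setInputCol("docs").setOutputCol("rawTokens")
    val remover = new StopWordsRemover().setInputCol("rawTokens").setOutputCol("tokens")
    val vectorizer = new CountVectorizer()
      .setVocabSize(10000) // illustrative vocabulary cap
      .setInputCol("tokens")
      .setOutputCol("features")

    val pipeline = new Pipeline().setStages(Array(tokenizer, remover, vectorizer))
    val model = pipeline.fit(df)

    // (docId, termCountVector) pairs ready for spark.mllib LDA, plus the learned vocabulary.
    val documents = model.transform(df)
      .select("features")
      .map { case Row(features: Vector) => features }
      .zipWithIndex()
      .map(_.swap)
    val vocab = model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary
    println(s"documents: ${documents.count()}, vocabulary size: ${vocab.length}")

    sc.stop()
  }
}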
Diffstat (limited to 'examples')
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala | 153
1 file changed, 40 insertions(+), 113 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
index 75b0f69cf9..70010b05e4 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/LDAExample.scala
@@ -18,19 +18,16 @@
// scalastyle:off println
package org.apache.spark.examples.mllib
-import java.text.BreakIterator
-
-import scala.collection.mutable
-
import scopt.OptionParser
import org.apache.log4j.{Level, Logger}
-
-import org.apache.spark.{SparkContext, SparkConf}
-import org.apache.spark.mllib.clustering.{EMLDAOptimizer, OnlineLDAOptimizer, DistributedLDAModel, LDA}
-import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.ml.Pipeline
+import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
+import org.apache.spark.mllib.clustering.{DistributedLDAModel, EMLDAOptimizer, LDA, OnlineLDAOptimizer}
+import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.rdd.RDD
-
+import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.{SparkConf, SparkContext}
/**
* An example Latent Dirichlet Allocation (LDA) app. Run with
@@ -192,115 +189,45 @@ object LDAExample {
vocabSize: Int,
stopwordFile: String): (RDD[(Long, Vector)], Array[String], Long) = {
+ val sqlContext = SQLContext.getOrCreate(sc)
+ import sqlContext.implicits._
+
// Get dataset of document texts
// One document per line in each text file. If the input consists of many small files,
// this can result in a large number of small partitions, which can degrade performance.
// In this case, consider using coalesce() to create fewer, larger partitions.
- val textRDD: RDD[String] = sc.textFile(paths.mkString(","))
-
- // Split text into words
- val tokenizer = new SimpleTokenizer(sc, stopwordFile)
- val tokenized: RDD[(Long, IndexedSeq[String])] = textRDD.zipWithIndex().map { case (text, id) =>
- id -> tokenizer.getWords(text)
- }
- tokenized.cache()
-
- // Counts words: RDD[(word, wordCount)]
- val wordCounts: RDD[(String, Long)] = tokenized
- .flatMap { case (_, tokens) => tokens.map(_ -> 1L) }
- .reduceByKey(_ + _)
- wordCounts.cache()
- val fullVocabSize = wordCounts.count()
- // Select vocab
- // (vocab: Map[word -> id], total tokens after selecting vocab)
- val (vocab: Map[String, Int], selectedTokenCount: Long) = {
- val tmpSortedWC: Array[(String, Long)] = if (vocabSize == -1 || fullVocabSize <= vocabSize) {
- // Use all terms
- wordCounts.collect().sortBy(-_._2)
- } else {
- // Sort terms to select vocab
- wordCounts.sortBy(_._2, ascending = false).take(vocabSize)
- }
- (tmpSortedWC.map(_._1).zipWithIndex.toMap, tmpSortedWC.map(_._2).sum)
- }
-
- val documents = tokenized.map { case (id, tokens) =>
- // Filter tokens by vocabulary, and create word count vector representation of document.
- val wc = new mutable.HashMap[Int, Int]()
- tokens.foreach { term =>
- if (vocab.contains(term)) {
- val termIndex = vocab(term)
- wc(termIndex) = wc.getOrElse(termIndex, 0) + 1
- }
- }
- val indices = wc.keys.toArray.sorted
- val values = indices.map(i => wc(i).toDouble)
-
- val sb = Vectors.sparse(vocab.size, indices, values)
- (id, sb)
- }
-
- val vocabArray = new Array[String](vocab.size)
- vocab.foreach { case (term, i) => vocabArray(i) = term }
-
- (documents, vocabArray, selectedTokenCount)
- }
-}
-
-/**
- * Simple Tokenizer.
- *
- * TODO: Formalize the interface, and make this a public class in mllib.feature
- */
-private class SimpleTokenizer(sc: SparkContext, stopwordFile: String) extends Serializable {
-
- private val stopwords: Set[String] = if (stopwordFile.isEmpty) {
- Set.empty[String]
- } else {
- val stopwordText = sc.textFile(stopwordFile).collect()
- stopwordText.flatMap(_.stripMargin.split("\\s+")).toSet
- }
-
- // Matches sequences of Unicode letters
- private val allWordRegex = "^(\\p{L}*)$".r
-
- // Ignore words shorter than this length.
- private val minWordLength = 3
-
- def getWords(text: String): IndexedSeq[String] = {
-
- val words = new mutable.ArrayBuffer[String]()
-
- // Use Java BreakIterator to tokenize text into words.
- val wb = BreakIterator.getWordInstance
- wb.setText(text)
-
- // current,end index start,end of each word
- var current = wb.first()
- var end = wb.next()
- while (end != BreakIterator.DONE) {
- // Convert to lowercase
- val word: String = text.substring(current, end).toLowerCase
- // Remove short words and strings that aren't only letters
- word match {
- case allWordRegex(w) if w.length >= minWordLength && !stopwords.contains(w) =>
- words += w
- case _ =>
- }
-
- current = end
- try {
- end = wb.next()
- } catch {
- case e: Exception =>
- // Ignore remaining text in line.
- // This is a known bug in BreakIterator (for some Java versions),
- // which fails when it sees certain characters.
- end = BreakIterator.DONE
- }
+ val df = sc.textFile(paths.mkString(",")).toDF("docs")
+ val customizedStopWords: Array[String] = if (stopwordFile.isEmpty) {
+ Array.empty[String]
+ } else {
+ val stopWordText = sc.textFile(stopwordFile).collect()
+ stopWordText.flatMap(_.stripMargin.split("\\s+"))
}
- words
+ val tokenizer = new RegexTokenizer()
+ .setInputCol("docs")
+ .setOutputCol("rawTokens")
+ val stopWordsRemover = new StopWordsRemover()
+ .setInputCol("rawTokens")
+ .setOutputCol("tokens")
+ stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
+ val countVectorizer = new CountVectorizer()
+ .setVocabSize(vocabSize)
+ .setInputCol("tokens")
+ .setOutputCol("features")
+
+ val pipeline = new Pipeline()
+ .setStages(Array(tokenizer, stopWordsRemover, countVectorizer))
+
+ val model = pipeline.fit(df)
+ val documents = model.transform(df)
+ .select("features")
+ .map { case Row(features: Vector) => features }
+ .zipWithIndex()
+ .map(_.swap)
+
+ (documents,
+ model.stages(2).asInstanceOf[CountVectorizerModel].vocabulary, // vocabulary
+ documents.map(_._2.numActives).sum().toLong) // total token count
}
-
}
// scalastyle:on println
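As a companion sketch of the consuming side, the (docId, termCountVector) RDD produced by the preprocessing above feeds the existing spark.mllib LDA trainer roughly like this. The toy corpus, topic count, and iteration count are illustrative assumptions, not part of the patch.

import org.apache.spark.mllib.clustering.{EMLDAOptimizer, LDA}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkConf, SparkContext}

object LDATrainSketch {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("LDATrainSketch"))

    // Tiny hand-built corpus standing in for the (docId, termCountVector) RDD
    // produced by the pipeline-based preprocessing in this patch.
    val documents = sc.parallelize(Seq(
      (0L, Vectors.sparse(5, Array(0, 2), Array(3.0, 1.0))),
      (1L, Vectors.sparse(5, Array(1, 3, 4), Array(1.0, 2.0, 1.0)))
    ))

    val ldaModel = new LDA()
      .setK(2)                        // illustrative topic count
      .setOptimizer(new EMLDAOptimizer)
      .setMaxIterations(10)
      .run(documents)

    // Print the top term indices and weights for each topic.
    ldaModel.describeTopics(maxTermsPerTopic = 3).foreach { case (termIndices, weights) =>
      println(termIndices.zip(weights).mkString(", "))
    }

    sc.stop()
  }
}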