aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-08-18 05:33:52 -0700
committerFelix Cheung <felixcheung@apache.org>2016-08-18 05:33:52 -0700
commitb72bb62d421840f82d663c6b8e3922bd14383fbb (patch)
tree1445a4e605794d84a606661dcfbd68decb3df657 /mllib
parent68f5087d2107d6afec5d5745f0cb0e9e3bdd6a0b (diff)
downloadspark-b72bb62d421840f82d663c6b8e3922bd14383fbb.tar.gz
spark-b72bb62d421840f82d663c6b8e3922bd14383fbb.tar.bz2
spark-b72bb62d421840f82d663c6b8e3922bd14383fbb.zip
[SPARK-16447][ML][SPARKR] LDA wrapper in SparkR
## What changes were proposed in this pull request? Add LDA Wrapper in SparkR with the following interfaces: - spark.lda(data, ...) - spark.posterior(object, newData, ...) - spark.perplexity(object, ...) - summary(object) - write.ml(object) - read.ml(path) ## How was this patch tested? Test with SparkR unit test. Author: Xusen Yin <yinxusen@gmail.com> Closes #14229 from yinxusen/SPARK-16447.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala216
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala2
3 files changed, 222 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 034f2c3fa2..b5a764b586 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -386,6 +386,10 @@ sealed abstract class LDAModel private[ml] (
@Since("1.6.0")
protected def getModel: OldLDAModel
+ private[ml] def getEffectiveDocConcentration: Array[Double] = getModel.docConcentration.toArray
+
+ private[ml] def getEffectiveTopicConcentration: Double = getModel.topicConcentration
+
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
* The vector should be of length vocabSize, with counts for each term (word).
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
new file mode 100644
index 0000000000..cbe6a70500
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
+import org.apache.spark.ml.clustering.{LDA, LDAModel}
+import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.ParamPair
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StringType
+
+
+private[r] class LDAWrapper private (
+ val pipeline: PipelineModel,
+ val logLikelihood: Double,
+ val logPerplexity: Double,
+ val vocabulary: Array[String]) extends MLWritable {
+
+ import LDAWrapper._
+
+ private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel]
+ private val preprocessor: PipelineModel =
+ new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1))
+
+ def transform(data: Dataset[_]): DataFrame = {
+ val vec2ary = udf { vec: Vector => vec.toArray }
+ val outputCol = lda.getTopicDistributionCol
+ val tempCol = s"${Identifiable.randomUID(outputCol)}"
+ val preprocessed = preprocessor.transform(data)
+ lda.transform(preprocessed, ParamPair(lda.topicDistributionCol, tempCol))
+ .withColumn(outputCol, vec2ary(col(tempCol)))
+ .drop(TOKENIZER_COL, STOPWORDS_REMOVER_COL, COUNT_VECTOR_COL, tempCol)
+ }
+
+ def computeLogPerplexity(data: Dataset[_]): Double = {
+ lda.logPerplexity(preprocessor.transform(data))
+ }
+
+ def topics(maxTermsPerTopic: Int): DataFrame = {
+ val topicIndices: DataFrame = lda.describeTopics(maxTermsPerTopic)
+ if (vocabulary.isEmpty || vocabulary.length < vocabSize) {
+ topicIndices
+ } else {
+ val index2term = udf { indices: mutable.WrappedArray[Int] => indices.map(i => vocabulary(i)) }
+ topicIndices
+ .select(col("topic"), index2term(col("termIndices")).as("term"), col("termWeights"))
+ }
+ }
+
+ lazy val isDistributed: Boolean = lda.isDistributed
+ lazy val vocabSize: Int = lda.vocabSize
+ lazy val docConcentration: Array[Double] = lda.getEffectiveDocConcentration
+ lazy val topicConcentration: Double = lda.getEffectiveTopicConcentration
+
+ override def write: MLWriter = new LDAWrapper.LDAWrapperWriter(this)
+}
+
+private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
+
+ val TOKENIZER_COL = s"${Identifiable.randomUID("rawTokens")}"
+ val STOPWORDS_REMOVER_COL = s"${Identifiable.randomUID("tokens")}"
+ val COUNT_VECTOR_COL = s"${Identifiable.randomUID("features")}"
+
+ private def getPreStages(
+ features: String,
+ customizedStopWords: Array[String],
+ maxVocabSize: Int): Array[PipelineStage] = {
+ val tokenizer = new RegexTokenizer()
+ .setInputCol(features)
+ .setOutputCol(TOKENIZER_COL)
+ val stopWordsRemover = new StopWordsRemover()
+ .setInputCol(TOKENIZER_COL)
+ .setOutputCol(STOPWORDS_REMOVER_COL)
+ stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
+ val countVectorizer = new CountVectorizer()
+ .setVocabSize(maxVocabSize)
+ .setInputCol(STOPWORDS_REMOVER_COL)
+ .setOutputCol(COUNT_VECTOR_COL)
+
+ Array(tokenizer, stopWordsRemover, countVectorizer)
+ }
+
+ def fit(
+ data: DataFrame,
+ features: String,
+ k: Int,
+ maxIter: Int,
+ optimizer: String,
+ subsamplingRate: Double,
+ topicConcentration: Double,
+ docConcentration: Array[Double],
+ customizedStopWords: Array[String],
+ maxVocabSize: Int): LDAWrapper = {
+
+ val lda = new LDA()
+ .setK(k)
+ .setMaxIter(maxIter)
+ .setSubsamplingRate(subsamplingRate)
+
+ val featureSchema = data.schema(features)
+ val stages = featureSchema.dataType match {
+ case d: StringType =>
+ getPreStages(features, customizedStopWords, maxVocabSize) ++
+ Array(lda.setFeaturesCol(COUNT_VECTOR_COL))
+ case d: VectorUDT =>
+ Array(lda.setFeaturesCol(features))
+ case _ =>
+ throw new SparkException(
+ s"Unsupported input features type of ${featureSchema.dataType.typeName}," +
+ s" only String type and Vector type are supported now.")
+ }
+
+ if (topicConcentration != -1) {
+ lda.setTopicConcentration(topicConcentration)
+ } else {
+ // Auto-set topicConcentration
+ }
+
+ if (docConcentration.length == 1) {
+ if (docConcentration.head != -1) {
+ lda.setDocConcentration(docConcentration.head)
+ } else {
+ // Auto-set docConcentration
+ }
+ } else {
+ lda.setDocConcentration(docConcentration)
+ }
+
+ val pipeline = new Pipeline().setStages(stages)
+ val model = pipeline.fit(data)
+
+ val vocabulary: Array[String] = featureSchema.dataType match {
+ case d: StringType =>
+ val countVectorModel = model.stages(2).asInstanceOf[CountVectorizerModel]
+ countVectorModel.vocabulary
+ case _ => Array.empty[String]
+ }
+
+ val ldaModel: LDAModel = model.stages.last.asInstanceOf[LDAModel]
+ val preprocessor: PipelineModel =
+ new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", model.stages.dropRight(1))
+
+ val preprocessedData = preprocessor.transform(data)
+
+ new LDAWrapper(
+ model,
+ ldaModel.logLikelihood(preprocessedData),
+ ldaModel.logPerplexity(preprocessedData),
+ vocabulary)
+ }
+
+ override def read: MLReader[LDAWrapper] = new LDAWrapperReader
+
+ override def load(path: String): LDAWrapper = super.load(path)
+
+ class LDAWrapperWriter(instance: LDAWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("logLikelihood" -> instance.logLikelihood) ~
+ ("logPerplexity" -> instance.logPerplexity) ~
+ ("vocabulary" -> instance.vocabulary.toList)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class LDAWrapperReader extends MLReader[LDAWrapper] {
+
+ override def load(path: String): LDAWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
+ val logPerplexity = (rMetadata \ "logPerplexity").extract[Double]
+ val vocabulary = (rMetadata \ "vocabulary").extract[List[String]].toArray
+
+ val pipeline = PipelineModel.load(pipelinePath)
+ new LDAWrapper(pipeline, logLikelihood, logPerplexity, vocabulary)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 88ac26bc5e..e23af51df5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -44,6 +44,8 @@ private[r] object RWrappers extends MLReader[Object] {
GeneralizedLinearRegressionWrapper.load(path)
case "org.apache.spark.ml.r.KMeansWrapper" =>
KMeansWrapper.load(path)
+ case "org.apache.spark.ml.r.LDAWrapper" =>
+ LDAWrapper.load(path)
case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
IsotonicRegressionWrapper.load(path)
case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>