aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--R/pkg/NAMESPACE3
-rw-r--r--R/pkg/R/generics.R14
-rw-r--r--R/pkg/R/mllib.R166
-rw-r--r--R/pkg/inst/tests/testthat/test_mllib.R87
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala216
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala2
7 files changed, 490 insertions, 2 deletions
diff --git a/R/pkg/NAMESPACE b/R/pkg/NAMESPACE
index c71eec5ce0..4404cffc29 100644
--- a/R/pkg/NAMESPACE
+++ b/R/pkg/NAMESPACE
@@ -25,6 +25,9 @@ exportMethods("glm",
"fitted",
"spark.naiveBayes",
"spark.survreg",
+ "spark.lda",
+ "spark.posterior",
+ "spark.perplexity",
"spark.isoreg",
"spark.gaussianMixture")
diff --git a/R/pkg/R/generics.R b/R/pkg/R/generics.R
index 06bb25d62d..fe04bcfc7d 100644
--- a/R/pkg/R/generics.R
+++ b/R/pkg/R/generics.R
@@ -1304,6 +1304,19 @@ setGeneric("spark.naiveBayes", function(data, formula, ...) { standardGeneric("s
#' @export
setGeneric("spark.survreg", function(data, formula, ...) { standardGeneric("spark.survreg") })
+#' @rdname spark.lda
+#' @param ... Additional parameters to tune LDA.
+#' @export
+setGeneric("spark.lda", function(data, ...) { standardGeneric("spark.lda") })
+
+#' @rdname spark.lda
+#' @export
+setGeneric("spark.posterior", function(object, newData) { standardGeneric("spark.posterior") })
+
+#' @rdname spark.lda
+#' @export
+setGeneric("spark.perplexity", function(object, data) { standardGeneric("spark.perplexity") })
+
#' @rdname spark.isoreg
#' @export
setGeneric("spark.isoreg", function(data, formula, ...) { standardGeneric("spark.isoreg") })
@@ -1315,6 +1328,7 @@ setGeneric("spark.gaussianMixture",
standardGeneric("spark.gaussianMixture")
})
+#' write.ml
#' @rdname write.ml
#' @export
setGeneric("write.ml", function(object, path, ...) { standardGeneric("write.ml") })
diff --git a/R/pkg/R/mllib.R b/R/pkg/R/mllib.R
index db74046056..b9527410a9 100644
--- a/R/pkg/R/mllib.R
+++ b/R/pkg/R/mllib.R
@@ -39,6 +39,13 @@ setClass("GeneralizedLinearRegressionModel", representation(jobj = "jobj"))
#' @note NaiveBayesModel since 2.0.0
setClass("NaiveBayesModel", representation(jobj = "jobj"))
+#' S4 class that represents an LDAModel
+#'
+#' @param jobj a Java object reference to the backing Scala LDAWrapper
+#' @export
+#' @note LDAModel since 2.1.0
+setClass("LDAModel", representation(jobj = "jobj"))
+
#' S4 class that represents a AFTSurvivalRegressionModel
#'
#' @param jobj a Java object reference to the backing Scala AFTSurvivalRegressionWrapper
@@ -75,7 +82,7 @@ setClass("GaussianMixtureModel", representation(jobj = "jobj"))
#' @name write.ml
#' @export
#' @seealso \link{spark.glm}, \link{glm}, \link{spark.gaussianMixture}
-#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}
+#' @seealso \link{spark.kmeans}, \link{spark.naiveBayes}, \link{spark.survreg}, \link{spark.lda}
#' @seealso \link{spark.isoreg}
#' @seealso \link{read.ml}
NULL
@@ -315,6 +322,94 @@ setMethod("summary", signature(object = "NaiveBayesModel"),
return(list(apriori = apriori, tables = tables))
})
+# Returns posterior probabilities from a Latent Dirichlet Allocation model produced by spark.lda()
+
+#' @param newData A SparkDataFrame for testing
+#' @return \code{spark.posterior} returns a SparkDataFrame containing posterior probabilities
+#' vectors named "topicDistribution"
+#' @rdname spark.lda
+#' @aliases spark.posterior,LDAModel,SparkDataFrame-method
+#' @export
+#' @note spark.posterior(LDAModel) since 2.1.0
+setMethod("spark.posterior", signature(object = "LDAModel", newData = "SparkDataFrame"),
+ function(object, newData) {
+ return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
+ })
+
+# Returns the summary of a Latent Dirichlet Allocation model produced by \code{spark.lda}
+
+#' @param object A Latent Dirichlet Allocation model fitted by \code{spark.lda}.
+#' @param maxTermsPerTopic Maximum number of terms to collect for each topic. Default value of 10.
+#' @return \code{summary} returns a list containing
+#' \item{\code{docConcentration}}{concentration parameter commonly named \code{alpha} for
+#' the prior placed on documents distributions over topics \code{theta}}
+#' \item{\code{topicConcentration}}{concentration parameter commonly named \code{beta} or
+#' \code{eta} for the prior placed on topic distributions over terms}
+#' \item{\code{logLikelihood}}{log likelihood of the entire corpus}
+#' \item{\code{logPerplexity}}{log perplexity}
+#' \item{\code{isDistributed}}{TRUE for distributed model while FALSE for local model}
+#' \item{\code{vocabSize}}{number of terms in the corpus}
+#' \item{\code{topics}}{top 10 terms and their weights of all topics}
+#' \item{\code{vocabulary}}{whole terms of the training corpus, NULL if libsvm format file
+#' used as training set}
+#' @rdname spark.lda
+#' @aliases summary,LDAModel-method
+#' @export
+#' @note summary(LDAModel) since 2.1.0
+setMethod("summary", signature(object = "LDAModel"),
+ function(object, maxTermsPerTopic) {
+ maxTermsPerTopic <- as.integer(ifelse(missing(maxTermsPerTopic), 10, maxTermsPerTopic))
+ jobj <- object@jobj
+ docConcentration <- callJMethod(jobj, "docConcentration")
+ topicConcentration <- callJMethod(jobj, "topicConcentration")
+ logLikelihood <- callJMethod(jobj, "logLikelihood")
+ logPerplexity <- callJMethod(jobj, "logPerplexity")
+ isDistributed <- callJMethod(jobj, "isDistributed")
+ vocabSize <- callJMethod(jobj, "vocabSize")
+ topics <- dataFrame(callJMethod(jobj, "topics", maxTermsPerTopic))
+ vocabulary <- callJMethod(jobj, "vocabulary")
+ return(list(docConcentration = unlist(docConcentration),
+ topicConcentration = topicConcentration,
+ logLikelihood = logLikelihood, logPerplexity = logPerplexity,
+ isDistributed = isDistributed, vocabSize = vocabSize,
+ topics = topics,
+ vocabulary = unlist(vocabulary)))
+ })
+
+# Returns the log perplexity of a Latent Dirichlet Allocation model produced by \code{spark.lda}
+
+#' @return \code{spark.perplexity} returns the log perplexity of given SparkDataFrame, or the log
+#' perplexity of the training data if missing argument "data".
+#' @rdname spark.lda
+#' @aliases spark.perplexity,LDAModel-method
+#' @export
+#' @note spark.perplexity(LDAModel) since 2.1.0
+setMethod("spark.perplexity", signature(object = "LDAModel", data = "SparkDataFrame"),
+ function(object, data) {
+ return(ifelse(missing(data), callJMethod(object@jobj, "logPerplexity"),
+ callJMethod(object@jobj, "computeLogPerplexity", data@sdf)))
+ })
+
+# Saves the Latent Dirichlet Allocation model to the input path.
+
+#' @param path The directory where the model is saved
+#' @param overwrite Overwrites or not if the output path already exists. Default is FALSE
+#' which means throw exception if the output path exists.
+#'
+#' @rdname spark.lda
+#' @aliases write.ml,LDAModel,character-method
+#' @export
+#' @seealso \link{read.ml}
+#' @note write.ml(LDAModel, character) since 2.1.0
+setMethod("write.ml", signature(object = "LDAModel", path = "character"),
+ function(object, path, overwrite = FALSE) {
+ writer <- callJMethod(object@jobj, "write")
+ if (overwrite) {
+ writer <- callJMethod(writer, "overwrite")
+ }
+ invisible(callJMethod(writer, "save", path))
+ })
+
#' Isotonic Regression Model
#'
#' Fits an Isotonic Regression model against a Spark DataFrame, similarly to R's isoreg().
@@ -700,6 +795,8 @@ read.ml <- function(path) {
return(new("GeneralizedLinearRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.KMeansWrapper")) {
return(new("KMeansModel", jobj = jobj))
+ } else if (isInstanceOf(jobj, "org.apache.spark.ml.r.LDAWrapper")) {
+ return(new("LDAModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.IsotonicRegressionWrapper")) {
return(new("IsotonicRegressionModel", jobj = jobj))
} else if (isInstanceOf(jobj, "org.apache.spark.ml.r.GaussianMixtureWrapper")) {
@@ -751,6 +848,71 @@ setMethod("spark.survreg", signature(data = "SparkDataFrame", formula = "formula
return(new("AFTSurvivalRegressionModel", jobj = jobj))
})
+#' Latent Dirichlet Allocation
+#'
+#' \code{spark.lda} fits a Latent Dirichlet Allocation model on a SparkDataFrame. Users can call
+#' \code{summary} to get a summary of the fitted LDA model, \code{spark.posterior} to compute
+#' posterior probabilities on new data, \code{spark.perplexity} to compute log perplexity on new
+#' data and \code{write.ml}/\code{read.ml} to save/load fitted models.
+#'
+#' @param data A SparkDataFrame for training
+#' @param features Features column name, default "features". Either libSVM-format column or
+#' character-format column is valid.
+#' @param k Number of topics, default 10
+#' @param maxIter Maximum iterations, default 20
+#' @param optimizer Optimizer to train an LDA model, "online" or "em", default "online"
+#' @param subsamplingRate (For online optimizer) Fraction of the corpus to be sampled and used in
+#' each iteration of mini-batch gradient descent, in range (0, 1], default 0.05
+#' @param topicConcentration concentration parameter (commonly named \code{beta} or \code{eta}) for
+#' the prior placed on topic distributions over terms, default -1 to set automatically on the
+#' Spark side. Use \code{summary} to retrieve the effective topicConcentration. Only 1-size
+#' numeric is accepted.
+#' @param docConcentration concentration parameter (commonly named \code{alpha}) for the
+#' prior placed on documents distributions over topics (\code{theta}), default -1 to set
+#' automatically on the Spark side. Use \code{summary} to retrieve the effective
+#' docConcentration. Only 1-size or \code{k}-size numeric is accepted.
+#' @param customizedStopWords stopwords that need to be removed from the given corpus. Ignore the
+#' parameter if libSVM-format column is used as the features column.
+#' @param maxVocabSize maximum vocabulary size, default 1 << 18
+#' @return \code{spark.lda} returns a fitted Latent Dirichlet Allocation model
+#' @rdname spark.lda
+#' @aliases spark.lda,SparkDataFrame-method
+#' @seealso topicmodels: \url{https://cran.r-project.org/web/packages/topicmodels/}
+#' @export
+#' @examples
+#' \dontrun{
+#' text <- read.df("path/to/data", source = "libsvm")
+#' model <- spark.lda(data = text, optimizer = "em")
+#'
+#' # get a summary of the model
+#' summary(model)
+#'
+#' # compute posterior probabilities
+#' posterior <- spark.posterior(model, df)
+#' showDF(posterior)
+#'
+#' # compute perplexity
+#' perplexity <- spark.perplexity(model, df)
+#'
+#' # save and load the model
+#' path <- "path/to/model"
+#' write.ml(model, path)
+#' savedModel <- read.ml(path)
+#' summary(savedModel)
+#' }
+#' @note spark.lda since 2.1.0
+setMethod("spark.lda", signature(data = "SparkDataFrame"),
+ function(data, features = "features", k = 10, maxIter = 20, optimizer = c("online", "em"),
+ subsamplingRate = 0.05, topicConcentration = -1, docConcentration = -1,
+ customizedStopWords = "", maxVocabSize = bitwShiftL(1, 18)) {
+ optimizer <- match.arg(optimizer)
+ jobj <- callJStatic("org.apache.spark.ml.r.LDAWrapper", "fit", data@sdf, features,
+ as.integer(k), as.integer(maxIter), optimizer,
+ as.numeric(subsamplingRate), topicConcentration,
+ as.array(docConcentration), as.array(customizedStopWords),
+ maxVocabSize)
+ return(new("LDAModel", jobj = jobj))
+ })
# Returns a summary of the AFT survival regression model produced by spark.survreg,
# similarly to R's summary().
@@ -891,4 +1053,4 @@ setMethod("summary", signature(object = "GaussianMixtureModel"),
setMethod("predict", signature(object = "GaussianMixtureModel"),
function(object, newData) {
return(dataFrame(callJMethod(object@jobj, "transform", newData@sdf)))
- })
+ }) \ No newline at end of file
diff --git a/R/pkg/inst/tests/testthat/test_mllib.R b/R/pkg/inst/tests/testthat/test_mllib.R
index 96179864a8..8c380fbf15 100644
--- a/R/pkg/inst/tests/testthat/test_mllib.R
+++ b/R/pkg/inst/tests/testthat/test_mllib.R
@@ -570,4 +570,91 @@ test_that("spark.gaussianMixture", {
unlink(modelPath)
})
+test_that("spark.lda with libsvm", {
+ text <- read.df("data/mllib/sample_lda_libsvm_data.txt", source = "libsvm")
+ model <- spark.lda(text, optimizer = "em")
+
+ stats <- summary(model, 10)
+ isDistributed <- stats$isDistributed
+ logLikelihood <- stats$logLikelihood
+ logPerplexity <- stats$logPerplexity
+ vocabSize <- stats$vocabSize
+ topics <- stats$topicTopTerms
+ weights <- stats$topicTopTermsWeights
+ vocabulary <- stats$vocabulary
+
+ expect_false(isDistributed)
+ expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+ expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+ expect_equal(vocabSize, 11)
+ expect_true(is.null(vocabulary))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-lda", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+
+ expect_false(stats2$isDistributed)
+ expect_equal(logLikelihood, stats2$logLikelihood)
+ expect_equal(logPerplexity, stats2$logPerplexity)
+ expect_equal(vocabSize, stats2$vocabSize)
+ expect_equal(vocabulary, stats2$vocabulary)
+
+ unlink(modelPath)
+})
+
+test_that("spark.lda with text input", {
+ text <- read.text("data/mllib/sample_lda_data.txt")
+ model <- spark.lda(text, optimizer = "online", features = "value")
+
+ stats <- summary(model)
+ isDistributed <- stats$isDistributed
+ logLikelihood <- stats$logLikelihood
+ logPerplexity <- stats$logPerplexity
+ vocabSize <- stats$vocabSize
+ topics <- stats$topicTopTerms
+ weights <- stats$topicTopTermsWeights
+ vocabulary <- stats$vocabulary
+
+ expect_false(isDistributed)
+ expect_true(logLikelihood <= 0 & is.finite(logLikelihood))
+ expect_true(logPerplexity >= 0 & is.finite(logPerplexity))
+ expect_equal(vocabSize, 10)
+ expect_true(setequal(stats$vocabulary, c("0", "1", "2", "3", "4", "5", "6", "7", "8", "9")))
+
+ # Test model save/load
+ modelPath <- tempfile(pattern = "spark-lda-text", fileext = ".tmp")
+ write.ml(model, modelPath)
+ expect_error(write.ml(model, modelPath))
+ write.ml(model, modelPath, overwrite = TRUE)
+ model2 <- read.ml(modelPath)
+ stats2 <- summary(model2)
+
+ expect_false(stats2$isDistributed)
+ expect_equal(logLikelihood, stats2$logLikelihood)
+ expect_equal(logPerplexity, stats2$logPerplexity)
+ expect_equal(vocabSize, stats2$vocabSize)
+ expect_true(all.equal(vocabulary, stats2$vocabulary))
+
+ unlink(modelPath)
+})
+
+test_that("spark.posterior and spark.perplexity", {
+ text <- read.text("data/mllib/sample_lda_data.txt")
+ model <- spark.lda(text, features = "value", k = 3)
+
+ # Assert perplexities are equal
+ stats <- summary(model)
+ logPerplexity <- spark.perplexity(model, text)
+ expect_equal(logPerplexity, stats$logPerplexity)
+
+ # Assert the sum of every topic distribution is equal to 1
+ posterior <- spark.posterior(model, text)
+ local.posterior <- collect(posterior)$topicDistribution
+ expect_equal(length(local.posterior), sum(unlist(local.posterior)))
+})
+
sparkR.session.stop()
diff --git a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
index 034f2c3fa2..b5a764b586 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/clustering/LDA.scala
@@ -386,6 +386,10 @@ sealed abstract class LDAModel private[ml] (
@Since("1.6.0")
protected def getModel: OldLDAModel
+ private[ml] def getEffectiveDocConcentration: Array[Double] = getModel.docConcentration.toArray
+
+ private[ml] def getEffectiveTopicConcentration: Double = getModel.topicConcentration
+
/**
* The features for LDA should be a [[Vector]] representing the word counts in a document.
* The vector should be of length vocabSize, with counts for each term (word).
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
new file mode 100644
index 0000000000..cbe6a70500
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
@@ -0,0 +1,216 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.r
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.Path
+import org.json4s._
+import org.json4s.JsonDSL._
+import org.json4s.jackson.JsonMethods._
+
+import org.apache.spark.SparkException
+import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
+import org.apache.spark.ml.clustering.{LDA, LDAModel}
+import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
+import org.apache.spark.ml.linalg.{Vector, VectorUDT}
+import org.apache.spark.ml.param.ParamPair
+import org.apache.spark.ml.util._
+import org.apache.spark.sql.{DataFrame, Dataset}
+import org.apache.spark.sql.functions._
+import org.apache.spark.sql.types.StringType
+
+
+private[r] class LDAWrapper private (
+ val pipeline: PipelineModel,
+ val logLikelihood: Double,
+ val logPerplexity: Double,
+ val vocabulary: Array[String]) extends MLWritable {
+
+ import LDAWrapper._
+
+ private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel]
+ private val preprocessor: PipelineModel =
+ new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1))
+
+ def transform(data: Dataset[_]): DataFrame = {
+ val vec2ary = udf { vec: Vector => vec.toArray }
+ val outputCol = lda.getTopicDistributionCol
+ val tempCol = s"${Identifiable.randomUID(outputCol)}"
+ val preprocessed = preprocessor.transform(data)
+ lda.transform(preprocessed, ParamPair(lda.topicDistributionCol, tempCol))
+ .withColumn(outputCol, vec2ary(col(tempCol)))
+ .drop(TOKENIZER_COL, STOPWORDS_REMOVER_COL, COUNT_VECTOR_COL, tempCol)
+ }
+
+ def computeLogPerplexity(data: Dataset[_]): Double = {
+ lda.logPerplexity(preprocessor.transform(data))
+ }
+
+ def topics(maxTermsPerTopic: Int): DataFrame = {
+ val topicIndices: DataFrame = lda.describeTopics(maxTermsPerTopic)
+ if (vocabulary.isEmpty || vocabulary.length < vocabSize) {
+ topicIndices
+ } else {
+ val index2term = udf { indices: mutable.WrappedArray[Int] => indices.map(i => vocabulary(i)) }
+ topicIndices
+ .select(col("topic"), index2term(col("termIndices")).as("term"), col("termWeights"))
+ }
+ }
+
+ lazy val isDistributed: Boolean = lda.isDistributed
+ lazy val vocabSize: Int = lda.vocabSize
+ lazy val docConcentration: Array[Double] = lda.getEffectiveDocConcentration
+ lazy val topicConcentration: Double = lda.getEffectiveTopicConcentration
+
+ override def write: MLWriter = new LDAWrapper.LDAWrapperWriter(this)
+}
+
+private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
+
+ val TOKENIZER_COL = s"${Identifiable.randomUID("rawTokens")}"
+ val STOPWORDS_REMOVER_COL = s"${Identifiable.randomUID("tokens")}"
+ val COUNT_VECTOR_COL = s"${Identifiable.randomUID("features")}"
+
+ private def getPreStages(
+ features: String,
+ customizedStopWords: Array[String],
+ maxVocabSize: Int): Array[PipelineStage] = {
+ val tokenizer = new RegexTokenizer()
+ .setInputCol(features)
+ .setOutputCol(TOKENIZER_COL)
+ val stopWordsRemover = new StopWordsRemover()
+ .setInputCol(TOKENIZER_COL)
+ .setOutputCol(STOPWORDS_REMOVER_COL)
+ stopWordsRemover.setStopWords(stopWordsRemover.getStopWords ++ customizedStopWords)
+ val countVectorizer = new CountVectorizer()
+ .setVocabSize(maxVocabSize)
+ .setInputCol(STOPWORDS_REMOVER_COL)
+ .setOutputCol(COUNT_VECTOR_COL)
+
+ Array(tokenizer, stopWordsRemover, countVectorizer)
+ }
+
+ def fit(
+ data: DataFrame,
+ features: String,
+ k: Int,
+ maxIter: Int,
+ optimizer: String,
+ subsamplingRate: Double,
+ topicConcentration: Double,
+ docConcentration: Array[Double],
+ customizedStopWords: Array[String],
+ maxVocabSize: Int): LDAWrapper = {
+
+ val lda = new LDA()
+ .setK(k)
+ .setMaxIter(maxIter)
+ .setSubsamplingRate(subsamplingRate)
+
+ val featureSchema = data.schema(features)
+ val stages = featureSchema.dataType match {
+ case d: StringType =>
+ getPreStages(features, customizedStopWords, maxVocabSize) ++
+ Array(lda.setFeaturesCol(COUNT_VECTOR_COL))
+ case d: VectorUDT =>
+ Array(lda.setFeaturesCol(features))
+ case _ =>
+ throw new SparkException(
+ s"Unsupported input features type of ${featureSchema.dataType.typeName}," +
+ s" only String type and Vector type are supported now.")
+ }
+
+ if (topicConcentration != -1) {
+ lda.setTopicConcentration(topicConcentration)
+ } else {
+ // Auto-set topicConcentration
+ }
+
+ if (docConcentration.length == 1) {
+ if (docConcentration.head != -1) {
+ lda.setDocConcentration(docConcentration.head)
+ } else {
+ // Auto-set docConcentration
+ }
+ } else {
+ lda.setDocConcentration(docConcentration)
+ }
+
+ val pipeline = new Pipeline().setStages(stages)
+ val model = pipeline.fit(data)
+
+ val vocabulary: Array[String] = featureSchema.dataType match {
+ case d: StringType =>
+ val countVectorModel = model.stages(2).asInstanceOf[CountVectorizerModel]
+ countVectorModel.vocabulary
+ case _ => Array.empty[String]
+ }
+
+ val ldaModel: LDAModel = model.stages.last.asInstanceOf[LDAModel]
+ val preprocessor: PipelineModel =
+ new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", model.stages.dropRight(1))
+
+ val preprocessedData = preprocessor.transform(data)
+
+ new LDAWrapper(
+ model,
+ ldaModel.logLikelihood(preprocessedData),
+ ldaModel.logPerplexity(preprocessedData),
+ vocabulary)
+ }
+
+ override def read: MLReader[LDAWrapper] = new LDAWrapperReader
+
+ override def load(path: String): LDAWrapper = super.load(path)
+
+ class LDAWrapperWriter(instance: LDAWrapper) extends MLWriter {
+
+ override protected def saveImpl(path: String): Unit = {
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadata = ("class" -> instance.getClass.getName) ~
+ ("logLikelihood" -> instance.logLikelihood) ~
+ ("logPerplexity" -> instance.logPerplexity) ~
+ ("vocabulary" -> instance.vocabulary.toList)
+ val rMetadataJson: String = compact(render(rMetadata))
+ sc.parallelize(Seq(rMetadataJson), 1).saveAsTextFile(rMetadataPath)
+
+ instance.pipeline.save(pipelinePath)
+ }
+ }
+
+ class LDAWrapperReader extends MLReader[LDAWrapper] {
+
+ override def load(path: String): LDAWrapper = {
+ implicit val format = DefaultFormats
+ val rMetadataPath = new Path(path, "rMetadata").toString
+ val pipelinePath = new Path(path, "pipeline").toString
+
+ val rMetadataStr = sc.textFile(rMetadataPath, 1).first()
+ val rMetadata = parse(rMetadataStr)
+ val logLikelihood = (rMetadata \ "logLikelihood").extract[Double]
+ val logPerplexity = (rMetadata \ "logPerplexity").extract[Double]
+ val vocabulary = (rMetadata \ "vocabulary").extract[List[String]].toArray
+
+ val pipeline = PipelineModel.load(pipelinePath)
+ new LDAWrapper(pipeline, logLikelihood, logPerplexity, vocabulary)
+ }
+ }
+}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
index 88ac26bc5e..e23af51df5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/RWrappers.scala
@@ -44,6 +44,8 @@ private[r] object RWrappers extends MLReader[Object] {
GeneralizedLinearRegressionWrapper.load(path)
case "org.apache.spark.ml.r.KMeansWrapper" =>
KMeansWrapper.load(path)
+ case "org.apache.spark.ml.r.LDAWrapper" =>
+ LDAWrapper.load(path)
case "org.apache.spark.ml.r.IsotonicRegressionWrapper" =>
IsotonicRegressionWrapper.load(path)
case "org.apache.spark.ml.r.GaussianMixtureWrapper" =>