author     wm624@hotmail.com <wm624@hotmail.com>  2017-01-16 06:05:59 -0800
committer  Yanbo Liang <ybliang8@gmail.com>       2017-01-16 06:05:59 -0800
commit     12c8c2160829ad8ccdab1741530361cdabdcd39d (patch)
tree       dd8dcfbc084b620ce8e1ea68ec6b26491757e20f /mllib/src
parent     e635cbb6e61dee450db0ef45f3d82ac282a85d3c (diff)
[SPARK-19066][SPARKR] SparkR LDA doesn't set optimizer correctly
## What changes were proposed in this pull request?

spark.lda passes the optimizer ("em" or "online") to the backend as a string, but LDAWrapper never sets the optimizer from that value. As a result, for optimizer "em" the `isDistributed` field is FALSE when, per the Scala code, it should be TRUE. In addition, the `summary` method should return the results specific to `DistributedLDAModel` (`trainingLogLikelihood` and `logPrior`).

## How was this patch tested?

Manual tests comparing the output with the Scala example. Also updated the existing unit tests: fixed the incorrect test and added the tests needed for the `summary` method.

Author: wm624@hotmail.com <wm624@hotmail.com>

Closes #16464 from wangmiao1981/new.
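For context, the Scala behavior this fix relies on, as a minimal standalone sketch (the object name and toy data below are illustrative, not part of this patch): with optimizer "em", `LDA.fit` returns a `DistributedLDAModel`, so `isDistributed` is true and `trainingLogLikelihood`/`logPrior` are available; with "online" it returns a local model and the cast would fail.

```scala
import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA}
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.sql.SparkSession

object LdaOptimizerDemo {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder()
      .master("local[2]").appName("lda-optimizer-demo").getOrCreate()
    import spark.implicits._

    // Toy term-count vectors; each row is one document.
    val docs = Seq(
      Vectors.dense(1.0, 2.0, 0.0, 5.0),
      Vectors.dense(0.0, 1.0, 3.0, 1.0),
      Vectors.dense(4.0, 0.0, 2.0, 0.0)
    ).map(Tuple1.apply).toDF("features")

    // "em" yields a DistributedLDAModel; "online" would yield a local model.
    val model = new LDA().setK(2).setMaxIter(5).setOptimizer("em").fit(docs)
    println(model.isDistributed) // true for "em", false for "online"

    // These two metrics exist only on the distributed model, which is
    // exactly what the wrapper's new lazy fields expose to R.
    val distModel = model.asInstanceOf[DistributedLDAModel]
    println(distModel.trainingLogLikelihood)
    println(distModel.logPrior)

    spark.stop()
  }
}
```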
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala  10
1 file changed, 9 insertions, 1 deletion
diff --git a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
index cbe6a70500..e096bf1f29 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/r/LDAWrapper.scala
@@ -26,7 +26,7 @@ import org.json4s.jackson.JsonMethods._
import org.apache.spark.SparkException
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
-import org.apache.spark.ml.clustering.{LDA, LDAModel}
+import org.apache.spark.ml.clustering.{DistributedLDAModel, LDA, LDAModel}
import org.apache.spark.ml.feature.{CountVectorizer, CountVectorizerModel, RegexTokenizer, StopWordsRemover}
import org.apache.spark.ml.linalg.{Vector, VectorUDT}
import org.apache.spark.ml.param.ParamPair
@@ -45,6 +45,13 @@ private[r] class LDAWrapper private (
import LDAWrapper._
private val lda: LDAModel = pipeline.stages.last.asInstanceOf[LDAModel]
+
+  // The following fields are used by the R-side code only when the LDA model is distributed
+  lazy private val distributedModel =
+    pipeline.stages.last.asInstanceOf[DistributedLDAModel]
+  lazy val trainingLogLikelihood: Double = distributedModel.trainingLogLikelihood
+  lazy val logPrior: Double = distributedModel.logPrior
+
private val preprocessor: PipelineModel =
new PipelineModel(s"${Identifiable.randomUID(pipeline.uid)}", pipeline.stages.dropRight(1))
@@ -122,6 +129,7 @@ private[r] object LDAWrapper extends MLReadable[LDAWrapper] {
.setK(k)
.setMaxIter(maxIter)
.setSubsamplingRate(subsamplingRate)
+ .setOptimizer(optimizer)
val featureSchema = data.schema(features)
val stages = featureSchema.dataType match {
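Why the new fields in the second hunk are `lazy`: with the "online" optimizer the last pipeline stage is a local model, so an eager cast to `DistributedLDAModel` in the wrapper's constructor would throw a `ClassCastException` for every online model. Marking the fields `lazy` defers the cast until R actually reads them. A minimal sketch of the pattern, with hypothetical `Model`/`Wrapper` types standing in for the Spark classes:

```scala
// Not Spark code: toy types that mirror the shape of LDAWrapper's new fields.
object LazyCastSketch extends App {
  trait Model
  class LocalModel extends Model
  class DistModel extends Model { val logPrior: Double = -42.0 }

  class Wrapper(model: Model) {
    // The cast runs only on first access, so wrapping a LocalModel is safe
    // as long as the distributed-only fields are never read.
    lazy private val dist = model.asInstanceOf[DistModel]
    lazy val logPrior: Double = dist.logPrior
  }

  println(new Wrapper(new DistModel).logPrior) // -42.0
  val local = new Wrapper(new LocalModel)      // no exception: cast not yet evaluated
  // local.logPrior                            // would throw ClassCastException
}
```

Constructing the wrapper is always safe; only reading a distributed-only field on a local model triggers the cast, and the R side guards those reads with `isDistributed`.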