author     Joseph K. Bradley <joseph@databricks.com>  2015-11-10 16:20:10 -0800
committer  Joseph K. Bradley <joseph@databricks.com>  2015-11-10 16:20:10 -0800
commit     e281b87398f1298cc3df8e0409c7040acdddce03 (patch)
tree       0b3c9361181479c47bc61e1000e103c831d52f72 /mllib/src/test/scala/org
parent     1dde39d796bbf42336051a86bedf871c7fddd513 (diff)
[SPARK-5565][ML] LDA wrapper for Pipelines API
This adds LDA to spark.ml, the Pipelines API. It follows the design doc in the JIRA: [https://issues.apache.org/jira/browse/SPARK-5565], with one major change:
* I eliminated doc IDs. These are not necessary with DataFrames since the user can add an ID column as needed.

Note: This will conflict with [https://github.com/apache/spark/pull/9484], but I'll try to merge [https://github.com/apache/spark/pull/9484] first and then rebase this PR.

CC: hhbyyh feynmanliang If you have a chance to make a pass, that'd be really helpful--thanks! Now that I'm done traveling & this PR is almost ready, I'll see about reviewing other PRs critical for 1.6.

CC: mengxr

Author: Joseph K. Bradley <joseph@databricks.com>

Closes #9513 from jkbradley/lda-pipelines.
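As an illustration (not part of this patch) of the "add an ID column as needed" point above: a caller who wants document IDs can attach them before fitting. The input DataFrame `corpus` and the column name `docId` are assumed names for this sketch, not anything this patch defines:

    import org.apache.spark.sql.functions.monotonicallyIncreasingId

    // Hypothetical: tag each document in an existing corpus DataFrame
    // (with a "features" column of term-count vectors) with a unique ID.
    val corpusWithIds = corpus.withColumn("docId", monotonicallyIncreasingId())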
Diffstat (limited to 'mllib/src/test/scala/org')
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala  221
1 file changed, 221 insertions(+), 0 deletions(-)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
new file mode 100644
index 0000000000..edb927495e
--- /dev/null
+++ b/mllib/src/test/scala/org/apache/spark/ml/clustering/LDASuite.scala
@@ -0,0 +1,221 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ml.clustering
+
+import org.apache.spark.SparkFunSuite
+import org.apache.spark.ml.util.MLTestingUtils
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
+import org.apache.spark.mllib.util.MLlibTestSparkContext
+import org.apache.spark.sql.{DataFrame, Row, SQLContext}
+
+
+object LDASuite {
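+  /**
+   * Generates a DataFrame of `rows` random documents for testing: each row holds a dense
+   * vector of `vocabSize` term counts drawn uniformly from [0, 2 * avgWC).
+   */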
+  def generateLDAData(
+      sql: SQLContext,
+      rows: Int,
+      k: Int,
+      vocabSize: Int): DataFrame = {
+    val avgWC = 1 // average instances of each word in a doc
+    val sc = sql.sparkContext
+    val rng = new java.util.Random()
+    rng.setSeed(1)
+    val rdd = sc.parallelize(1 to rows).map { i =>
+      Vectors.dense(Array.fill(vocabSize)(rng.nextInt(2 * avgWC).toDouble))
+    }.map(v => new TestRow(v))
+    sql.createDataFrame(rdd)
+  }
+}
+
+
+class LDASuite extends SparkFunSuite with MLlibTestSparkContext {
+
+  val k: Int = 5
+  val vocabSize: Int = 30
+  @transient var dataset: DataFrame = _
+
+  override def beforeAll(): Unit = {
+    super.beforeAll()
+    dataset = LDASuite.generateLDAData(sqlContext, 50, k, vocabSize)
+  }
+
+ test("default parameters") {
+ val lda = new LDA()
+
+ assert(lda.getFeaturesCol === "features")
+ assert(lda.getMaxIter === 20)
+ assert(lda.isDefined(lda.seed))
+ assert(lda.getCheckpointInterval === 10)
+ assert(lda.getK === 10)
+ assert(!lda.isSet(lda.docConcentration))
+ assert(!lda.isSet(lda.topicConcentration))
+ assert(lda.getOptimizer === "online")
+ assert(lda.getLearningDecay === 0.51)
+ assert(lda.getLearningOffset === 1024)
+ assert(lda.getSubsamplingRate === 0.05)
+ assert(lda.getOptimizeDocConcentration)
+ assert(lda.getTopicDistributionCol === "topicDistribution")
+ }
+
+ test("set parameters") {
+ val lda = new LDA()
+ .setFeaturesCol("test_feature")
+ .setMaxIter(33)
+ .setSeed(123)
+ .setCheckpointInterval(7)
+ .setK(9)
+ .setTopicConcentration(0.56)
+ .setTopicDistributionCol("myOutput")
+
+ assert(lda.getFeaturesCol === "test_feature")
+ assert(lda.getMaxIter === 33)
+ assert(lda.getSeed === 123)
+ assert(lda.getCheckpointInterval === 7)
+ assert(lda.getK === 9)
+ assert(lda.getTopicConcentration === 0.56)
+ assert(lda.getTopicDistributionCol === "myOutput")
+
+
+ // setOptimizer
+ lda.setOptimizer("em")
+ assert(lda.getOptimizer === "em")
+ lda.setOptimizer("online")
+ assert(lda.getOptimizer === "online")
+ lda.setLearningDecay(0.53)
+ assert(lda.getLearningDecay === 0.53)
+ lda.setLearningOffset(1027)
+ assert(lda.getLearningOffset === 1027)
+ lda.setSubsamplingRate(0.06)
+ assert(lda.getSubsamplingRate === 0.06)
+ lda.setOptimizeDocConcentration(false)
+ assert(!lda.getOptimizeDocConcentration)
+ }
+
+ test("parameters validation") {
+ val lda = new LDA()
+
+ // misc Params
+ intercept[IllegalArgumentException] {
+ new LDA().setK(1)
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setOptimizer("no_such_optimizer")
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setDocConcentration(-1.1)
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setTopicConcentration(-1.1)
+ }
+
+ // validateParams()
+ lda.validateParams()
+ lda.setDocConcentration(1.1)
+ lda.validateParams()
+ lda.setDocConcentration(Range(0, lda.getK).map(_ + 2.0).toArray)
+ lda.validateParams()
+ lda.setDocConcentration(Range(0, lda.getK - 1).map(_ + 2.0).toArray)
+ withClue("LDA docConcentration validity check failed for bad array length") {
+ intercept[IllegalArgumentException] {
+ lda.validateParams()
+ }
+ }
+
+ // Online LDA
+ intercept[IllegalArgumentException] {
+ new LDA().setLearningOffset(0)
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setLearningDecay(0)
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setSubsamplingRate(0)
+ }
+ intercept[IllegalArgumentException] {
+ new LDA().setSubsamplingRate(1.1)
+ }
+ }
+
+ test("fit & transform with Online LDA") {
+ val lda = new LDA().setK(k).setSeed(1).setOptimizer("online").setMaxIter(2)
+ val model = lda.fit(dataset)
+
+ MLTestingUtils.checkCopy(model)
+
+ assert(!model.isInstanceOf[DistributedLDAModel])
+ assert(model.vocabSize === vocabSize)
+ assert(model.estimatedDocConcentration.size === k)
+ assert(model.topicsMatrix.numRows === vocabSize)
+ assert(model.topicsMatrix.numCols === k)
+ assert(!model.isDistributed)
+
+ // transform()
+ val transformed = model.transform(dataset)
+ val expectedColumns = Array("features", lda.getTopicDistributionCol)
+ expectedColumns.foreach { column =>
+ assert(transformed.columns.contains(column))
+ }
+ transformed.select(lda.getTopicDistributionCol).collect().foreach { r =>
+ val topicDistribution = r.getAs[Vector](0)
+ assert(topicDistribution.size === k)
+ assert(topicDistribution.toArray.forall(w => w >= 0.0 && w <= 1.0))
+ }
+
+ // logLikelihood, logPerplexity
+ val ll = model.logLikelihood(dataset)
+ assert(ll <= 0.0 && ll != Double.NegativeInfinity)
+ val lp = model.logPerplexity(dataset)
+ assert(lp >= 0.0 && lp != Double.PositiveInfinity)
+
+ // describeTopics
+ val topics = model.describeTopics(3)
+ assert(topics.count() === k)
+ assert(topics.select("topic").map(_.getInt(0)).collect().toSet === Range(0, k).toSet)
+ topics.select("termIndices").collect().foreach { case r: Row =>
+ val termIndices = r.getAs[Seq[Int]](0)
+ assert(termIndices.length === 3 && termIndices.toSet.size === 3)
+ }
+ topics.select("termWeights").collect().foreach { case r: Row =>
+ val termWeights = r.getAs[Seq[Double]](0)
+ assert(termWeights.length === 3 && termWeights.forall(w => w >= 0.0 && w <= 1.0))
+ }
+ }
+
+ test("fit & transform with EM LDA") {
+ val lda = new LDA().setK(k).setSeed(1).setOptimizer("em").setMaxIter(2)
+ val model_ = lda.fit(dataset)
+
+ MLTestingUtils.checkCopy(model_)
+
+ assert(model_.isInstanceOf[DistributedLDAModel])
+ val model = model_.asInstanceOf[DistributedLDAModel]
+ assert(model.vocabSize === vocabSize)
+ assert(model.estimatedDocConcentration.size === k)
+ assert(model.topicsMatrix.numRows === vocabSize)
+ assert(model.topicsMatrix.numCols === k)
+ assert(model.isDistributed)
+
+ val localModel = model.toLocal
+ assert(!localModel.isInstanceOf[DistributedLDAModel])
+
+ // training logLikelihood, logPrior
+ val ll = model.trainingLogLikelihood
+ assert(ll <= 0.0 && ll != Double.NegativeInfinity)
+ val lp = model.logPrior
+ assert(lp <= 0.0 && lp != Double.NegativeInfinity)
+ }
+}
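For readers skimming the patch, a minimal usage sketch of the new spark.ml API, mirroring what the Online LDA test above exercises (`dataset` stands in for any DataFrame with a "features" column of term-count vectors; this is an illustration, not part of the patch):

    import org.apache.spark.ml.clustering.LDA

    // Hypothetical end-to-end flow: fit an online LDA model, then append
    // per-document topic distributions in a "topicDistribution" column.
    val lda = new LDA().setK(5).setOptimizer("online").setMaxIter(20)
    val model = lda.fit(dataset)
    val withTopics = model.transform(dataset)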