aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2015-02-10 21:51:15 -0800
committerXiangrui Meng <meng@databricks.com>2015-02-10 21:51:15 -0800
commitf86a89a2e081ee4593ce03398c2283fd77daac6e (patch)
tree8564bce0d6f0ec4dedd5f681626e2bab313214dd
parentb8f88d32723eaea4807c10b5b79d0c76f30b0510 (diff)
downloadspark-f86a89a2e081ee4593ce03398c2283fd77daac6e.tar.gz
spark-f86a89a2e081ee4593ce03398c2283fd77daac6e.tar.bz2
spark-f86a89a2e081ee4593ce03398c2283fd77daac6e.zip
[SPARK-5714][Mllib] Refactor initial step of LDA to remove redundant operations
The `initialState` of LDA performs several RDD operations that looks redundant. This pr tries to simplify these operations. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #4501 from viirya/sim_lda and squashes the following commits: 4870fe4 [Liang-Chi Hsieh] For comments. 9af1487 [Liang-Chi Hsieh] Refactor initial step of LDA to remove redundant operations.
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala37
1 files changed, 13 insertions, 24 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
index a1d3df03a1..5e17c8da61 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/clustering/LDA.scala
@@ -450,34 +450,23 @@ private[clustering] object LDA {
// Create vertices.
// Initially, we use random soft assignments of tokens to topics (random gamma).
- val edgesWithGamma: RDD[(Edge[TokenCount], TopicCounts)] =
- edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
- val random = new Random(partIndex + randomSeed)
- partEdges.map { edge =>
- // Create a random gamma_{wjk}
- (edge, normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0))
+ def createVertices(): RDD[(VertexId, TopicCounts)] = {
+ val verticesTMP: RDD[(VertexId, TopicCounts)] =
+ edges.mapPartitionsWithIndex { case (partIndex, partEdges) =>
+ val random = new Random(partIndex + randomSeed)
+ partEdges.flatMap { edge =>
+ val gamma = normalize(BDV.fill[Double](k)(random.nextDouble()), 1.0)
+ val sum = gamma * edge.attr
+ Seq((edge.srcId, sum), (edge.dstId, sum))
+ }
}
- }
- def createVertices(sendToWhere: Edge[TokenCount] => VertexId): RDD[(VertexId, TopicCounts)] = {
- val verticesTMP: RDD[(VertexId, (TokenCount, TopicCounts))] =
- edgesWithGamma.map { case (edge, gamma: TopicCounts) =>
- (sendToWhere(edge), (edge.attr, gamma))
- }
- verticesTMP.aggregateByKey(BDV.zeros[Double](k))(
- (sum, t) => {
- brzAxpy(t._1, t._2, sum)
- sum
- },
- (sum0, sum1) => {
- sum0 += sum1
- }
- )
+ verticesTMP.reduceByKey(_ + _)
}
- val docVertices = createVertices(_.srcId)
- val termVertices = createVertices(_.dstId)
+
+ val docTermVertices = createVertices()
// Partition such that edges are grouped by document
- val graph = Graph(docVertices ++ termVertices, edges)
+ val graph = Graph(docTermVertices, edges)
.partitionBy(PartitionStrategy.EdgePartition1D)
new EMOptimizer(graph, k, vocabSize, docConcentration, topicConcentration, checkpointInterval)