aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorSandeep <sandeep@techaddict.me>2014-04-23 22:47:59 -0700
committerReynold Xin <rxin@apache.org>2014-04-23 22:48:05 -0700
commite8907718a6c811ede565f3faf78c49d8ea083a95 (patch)
treecef596e825345fb6a28f73c7e20a0b0927d8c61b /mllib/src
parent9716a72cb927bd630188b142237c64bc48371d0c (diff)
downloadspark-e8907718a6c811ede565f3faf78c49d8ea083a95.tar.gz
spark-e8907718a6c811ede565f3faf78c49d8ea083a95.tar.bz2
spark-e8907718a6c811ede565f3faf78c49d8ea083a95.zip
[Fix #79] Replace Breakable For Loops By While Loops
Author: Sandeep <sandeep@techaddict.me> Closes #503 from techaddict/fix-79 and squashes the following commits: e3f6746 [Sandeep] Style changes 07a4f6b [Sandeep] for loop to While loop 0a6d8e9 [Sandeep] Breakable for loop to While loop (cherry picked from commit bb68f47745eec2954814d3da277a672d5cf89980) Signed-off-by: Reynold Xin <rxin@apache.org>
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala60
1 files changed, 31 insertions, 29 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 3019447ce4..f68076f426 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,8 +17,6 @@
package org.apache.spark.mllib.tree
-import scala.util.control.Breaks._
-
import org.apache.spark.annotation.Experimental
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.SparkContext._
@@ -82,31 +80,34 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* still survived the filters of the parent nodes.
*/
- // TODO: Convert for loop to while loop
- breakable {
- for (level <- 0 until maxDepth) {
-
- logDebug("#####################################")
- logDebug("level = " + level)
- logDebug("#####################################")
-
- // Find best split for all nodes at a level.
- val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
- level, filters, splits, bins)
-
- for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
- // Extract info for nodes at the current level.
- extractNodeInfo(nodeSplitStats, level, index, nodes)
- // Extract info for nodes at the next lower level.
- extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
- filters)
- logDebug("final best split = " + nodeSplitStats._1)
- }
- require(scala.math.pow(2, level) == splitsStatsForLevel.length)
- // Check whether all the nodes at the current level at leaves.
- val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
- logDebug("all leaf = " + allLeaf)
- if (allLeaf) break // no more tree construction
+ var level = 0
+ var break = false
+ while (level < maxDepth && !break) {
+
+ logDebug("#####################################")
+ logDebug("level = " + level)
+ logDebug("#####################################")
+
+ // Find best split for all nodes at a level.
+ val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
+ level, filters, splits, bins)
+
+ for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+ // Extract info for nodes at the current level.
+ extractNodeInfo(nodeSplitStats, level, index, nodes)
+ // Extract info for nodes at the next lower level.
+ extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
+ filters)
+ logDebug("final best split = " + nodeSplitStats._1)
+ }
+ require(scala.math.pow(2, level) == splitsStatsForLevel.length)
+ // Check whether all the nodes at the current level at leaves.
+ val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
+ logDebug("all leaf = " + allLeaf)
+ if (allLeaf) {
+ break = true // no more tree construction
+ } else {
+ level += 1
}
}
@@ -146,8 +147,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
parentImpurities: Array[Double],
filters: Array[List[Filter]]): Unit = {
// 0 corresponds to the left child node and 1 corresponds to the right child node.
- // TODO: Convert to while loop
- for (i <- 0 to 1) {
+ var i = 0
+ while (i <= 1) {
// Calculate the index of the node from the node level and the index at the current level.
val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
if (level < maxDepth - 1) {
@@ -166,6 +167,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("Filter = " + filter)
}
}
+ i += 1
}
}
}