aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorSandeep <sandeep@techaddict.me>2014-04-23 22:47:59 -0700
committerReynold Xin <rxin@apache.org>2014-04-23 22:47:59 -0700
commitbb68f47745eec2954814d3da277a672d5cf89980 (patch)
treee26139e6ffc9c5b341ac08748c51cfd9d4f4ed50 /mllib
parent6ab7578067e3bb78b64f99fd67c97e9607050ffe (diff)
downloadspark-bb68f47745eec2954814d3da277a672d5cf89980.tar.gz
spark-bb68f47745eec2954814d3da277a672d5cf89980.tar.bz2
spark-bb68f47745eec2954814d3da277a672d5cf89980.zip
[Fix #79] Replace Breakable For Loops By While Loops
Author: Sandeep <sandeep@techaddict.me> Closes #503 from techaddict/fix-79 and squashes the following commits: e3f6746 [Sandeep] Style changes 07a4f6b [Sandeep] for loop to While loop 0a6d8e9 [Sandeep] Breakable for loop to While loop
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala60
1 files changed, 31 insertions, 29 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 3019447ce4..f68076f426 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,8 +17,6 @@
package org.apache.spark.mllib.tree
-import scala.util.control.Breaks._
-
import org.apache.spark.annotation.Experimental
import org.apache.spark.{Logging, SparkContext}
import org.apache.spark.SparkContext._
@@ -82,31 +80,34 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
* still survived the filters of the parent nodes.
*/
- // TODO: Convert for loop to while loop
- breakable {
- for (level <- 0 until maxDepth) {
-
- logDebug("#####################################")
- logDebug("level = " + level)
- logDebug("#####################################")
-
- // Find best split for all nodes at a level.
- val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
- level, filters, splits, bins)
-
- for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
- // Extract info for nodes at the current level.
- extractNodeInfo(nodeSplitStats, level, index, nodes)
- // Extract info for nodes at the next lower level.
- extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
- filters)
- logDebug("final best split = " + nodeSplitStats._1)
- }
- require(scala.math.pow(2, level) == splitsStatsForLevel.length)
- // Check whether all the nodes at the current level at leaves.
- val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
- logDebug("all leaf = " + allLeaf)
- if (allLeaf) break // no more tree construction
+ var level = 0
+ var break = false
+ while (level < maxDepth && !break) {
+
+ logDebug("#####################################")
+ logDebug("level = " + level)
+ logDebug("#####################################")
+
+ // Find best split for all nodes at a level.
+ val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy,
+ level, filters, splits, bins)
+
+ for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+ // Extract info for nodes at the current level.
+ extractNodeInfo(nodeSplitStats, level, index, nodes)
+ // Extract info for nodes at the next lower level.
+ extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
+ filters)
+ logDebug("final best split = " + nodeSplitStats._1)
+ }
+ require(scala.math.pow(2, level) == splitsStatsForLevel.length)
+ // Check whether all the nodes at the current level at leaves.
+ val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0)
+ logDebug("all leaf = " + allLeaf)
+ if (allLeaf) {
+ break = true // no more tree construction
+ } else {
+ level += 1
}
}
@@ -146,8 +147,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
parentImpurities: Array[Double],
filters: Array[List[Filter]]): Unit = {
// 0 corresponds to the left child node and 1 corresponds to the right child node.
- // TODO: Convert to while loop
- for (i <- 0 to 1) {
+ var i = 0
+ while (i <= 1) {
// Calculate the index of the node from the node level and the index at the current level.
val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i
if (level < maxDepth - 1) {
@@ -166,6 +167,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("Filter = " + filter)
}
}
+ i += 1
}
}
}