From bb68f47745eec2954814d3da277a672d5cf89980 Mon Sep 17 00:00:00 2001 From: Sandeep Date: Wed, 23 Apr 2014 22:47:59 -0700 Subject: [Fix #79] Replace Breakable For Loops By While Loops Author: Sandeep Closes #503 from techaddict/fix-79 and squashes the following commits: e3f6746 [Sandeep] Style changes 07a4f6b [Sandeep] for loop to While loop 0a6d8e9 [Sandeep] Breakable for loop to While loop --- .../org/apache/spark/mllib/tree/DecisionTree.scala | 60 +++++++++++----------- 1 file changed, 31 insertions(+), 29 deletions(-) (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala index 3019447ce4..f68076f426 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala @@ -17,8 +17,6 @@ package org.apache.spark.mllib.tree -import scala.util.control.Breaks._ - import org.apache.spark.annotation.Experimental import org.apache.spark.{Logging, SparkContext} import org.apache.spark.SparkContext._ @@ -82,31 +80,34 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo * still survived the filters of the parent nodes. */ - // TODO: Convert for loop to while loop - breakable { - for (level <- 0 until maxDepth) { - - logDebug("#####################################") - logDebug("level = " + level) - logDebug("#####################################") - - // Find best split for all nodes at a level. - val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, - level, filters, splits, bins) - - for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { - // Extract info for nodes at the current level. - extractNodeInfo(nodeSplitStats, level, index, nodes) - // Extract info for nodes at the next lower level. - extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, - filters) - logDebug("final best split = " + nodeSplitStats._1) - } - require(scala.math.pow(2, level) == splitsStatsForLevel.length) - // Check whether all the nodes at the current level at leaves. - val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) - logDebug("all leaf = " + allLeaf) - if (allLeaf) break // no more tree construction + var level = 0 + var break = false + while (level < maxDepth && !break) { + + logDebug("#####################################") + logDebug("level = " + level) + logDebug("#####################################") + + // Find best split for all nodes at a level. + val splitsStatsForLevel = DecisionTree.findBestSplits(input, parentImpurities, strategy, + level, filters, splits, bins) + + for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) { + // Extract info for nodes at the current level. + extractNodeInfo(nodeSplitStats, level, index, nodes) + // Extract info for nodes at the next lower level. + extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities, + filters) + logDebug("final best split = " + nodeSplitStats._1) + } + require(scala.math.pow(2, level) == splitsStatsForLevel.length) + // Check whether all the nodes at the current level at leaves. + val allLeaf = splitsStatsForLevel.forall(_._2.gain <= 0) + logDebug("all leaf = " + allLeaf) + if (allLeaf) { + break = true // no more tree construction + } else { + level += 1 } } @@ -146,8 +147,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo parentImpurities: Array[Double], filters: Array[List[Filter]]): Unit = { // 0 corresponds to the left child node and 1 corresponds to the right child node. - // TODO: Convert to while loop - for (i <- 0 to 1) { + var i = 0 + while (i <= 1) { // Calculate the index of the node from the node level and the index at the current level. val nodeIndex = scala.math.pow(2, level + 1).toInt - 1 + 2 * index + i if (level < maxDepth - 1) { @@ -166,6 +167,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo logDebug("Filter = " + filter) } } + i += 1 } } } -- cgit v1.2.3