path: root/mllib
author    Joseph K. Bradley <joseph.kurata.bradley@gmail.com>  2014-08-15 14:50:10 -0700
committer Xiangrui Meng <meng@databricks.com>                  2014-08-15 14:50:10 -0700
commit    c7032290a3f0f5545aa4f0a9a144c62571344dc8 (patch)
tree      4e9da3e875eda32ef0e430f4928f3ab6e2d31e3c /mllib
parent    0afe5cb65a195d2f14e8dfcefdbec5dac023651f (diff)
download  spark-c7032290a3f0f5545aa4f0a9a144c62571344dc8.tar.gz
          spark-c7032290a3f0f5545aa4f0a9a144c62571344dc8.tar.bz2
          spark-c7032290a3f0f5545aa4f0a9a144c62571344dc8.zip
[SPARK-3022] [SPARK-3041] [mllib] Call findBins once per level + unordered feature bug fix
DecisionTree improvements:
(1) TreePoint representation to avoid binning multiple times
(2) Bug fix: isSampleValid indexed bins incorrectly for unordered categorical features
(3) Timing for DecisionTree internals

Details:

(1) TreePoint representation to avoid binning multiple times
[https://issues.apache.org/jira/browse/SPARK-3022]
Added a private[tree] TreePoint class for representing binned feature values. The input RDD of LabeledPoint is converted to the TreePoint representation up front and then cached. This avoids the previous problem of re-computing bins multiple times.

(2) Bug fix: isSampleValid indexed bins incorrectly for unordered categorical features
[https://issues.apache.org/jira/browse/SPARK-3041]
isSampleValid treated unordered categorical features incorrectly: it indexed the bins by feature values, rather than by subsets of values/categories.
* Exhibited for unordered features (multiclass classification with categorical features of low arity).
* Fix: index bins correctly for unordered categorical features.

(3) Timing for DecisionTree internals
Added the tree/impl/TimeTracker.scala class (private[tree] for now) for timing key parts of the DecisionTree code. Prints timing info via logDebug.

CC: mengxr manishamde chouqin This is a very similar update, with one bug fix. Many apologies for the conflicting update, but I hope that a few more optimizations I have on the way (which depend on this update) will prove valuable to you: SPARK-3042 and SPARK-3043.

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1950 from jkbradley/dt-opt1 and squashes the following commits:

5f2dec2 [Joseph K. Bradley] Fixed scalastyle issue in TreePoint
6b5651e [Joseph K. Bradley] Updates based on code review. One major change: persisting to memory + disk, not just memory.
2d2aaaf [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
430d782 [Joseph K. Bradley] Added more debug info on binning error. Added some docs.
d036089 [Joseph K. Bradley] Print timing info to logDebug.
e66f1b1 [Joseph K. Bradley] TreePoint: updated doc; made some methods private
8464a6e [Joseph K. Bradley] Moved TimeTracker to tree/impl/ in its own file, and cleaned it up. Removed debugging println calls from DecisionTree. Made TreePoint extend Serializable
a87e08f [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-opt1
0f676e2 [Joseph K. Bradley] Optimizations + bug fix for DecisionTree
3211f02 [Joseph K. Bradley] Optimizing DecisionTree: added TreePoint representation to avoid calling findBin multiple times (not working yet, but debugging)
f61e9d2 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
bcf874a [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
511ec85 [Joseph K. Bradley] Merge remote-tracking branch 'upstream/master' into dt-timing
a95bc22 [Joseph K. Bradley] Timing for DecisionTree internals
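For orientation, here is a minimal sketch of the flow this patch sets up: compute splits/bins once, bin each LabeledPoint once into a TreePoint, and cache the result for all level-wise passes. Sketch only: findSplitsBins, TreePoint, and convertToTreeRDD are protected[tree]/private[tree], so this compiles only inside the org.apache.spark.mllib.tree package.

import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impl.TreePoint
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Bin the input once up front, then reuse the cached TreePoint RDD for
// every level of training instead of re-running findBin per level.
def binOnce(input: RDD[LabeledPoint], strategy: Strategy): RDD[TreePoint] = {
  val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
  TreePoint.convertToTreeRDD(input, strategy, bins)
    .persist(StorageLevel.MEMORY_AND_DISK)  // memory + disk, per review note
}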
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala            289
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala   43
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala         73
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala          201
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala        50
5 files changed, 449 insertions(+), 207 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index bb50f07be5..2a3107a13e 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -17,22 +17,24 @@
package org.apache.spark.mllib.tree
-import org.apache.spark.api.java.JavaRDD
-
import scala.collection.JavaConverters._
import org.apache.spark.annotation.Experimental
+import org.apache.spark.api.java.JavaRDD
import org.apache.spark.Logging
import org.apache.spark.mllib.regression.LabeledPoint
-import org.apache.spark.mllib.tree.configuration.{Algo, Strategy}
+import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
-import org.apache.spark.mllib.tree.impurity.{Impurities, Gini, Entropy, Impurity}
+import org.apache.spark.mllib.tree.impl.{TimeTracker, TreePoint}
+import org.apache.spark.mllib.tree.impurity.{Impurities, Impurity}
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.XORShiftRandom
+
/**
* :: Experimental ::
* A class which implements a decision tree learning algorithm for classification and regression.
@@ -53,16 +55,27 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
*/
def train(input: RDD[LabeledPoint]): DecisionTreeModel = {
- // Cache input RDD for speedup during multiple passes.
- val retaggedInput = input.retag(classOf[LabeledPoint]).cache()
+ val timer = new TimeTracker()
+
+ timer.start("total")
+
+ timer.start("init")
+
+ val retaggedInput = input.retag(classOf[LabeledPoint])
logDebug("algo = " + strategy.algo)
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
+ timer.start("findSplitsBins")
val (splits, bins) = DecisionTree.findSplitsBins(retaggedInput, strategy)
val numBins = bins(0).length
+ timer.stop("findSplitsBins")
logDebug("numBins = " + numBins)
+ // Cache input RDD for speedup during multiple passes.
+ val treeInput = TreePoint.convertToTreeRDD(retaggedInput, strategy, bins)
+ .persist(StorageLevel.MEMORY_AND_DISK)
+
// depth of the decision tree
val maxDepth = strategy.maxDepth
// the max number of nodes possible given the depth of the tree
@@ -76,7 +89,7 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// dummy value for top node (updated during first split calculation)
val nodes = new Array[Node](maxNumNodes)
// num features
- val numFeatures = retaggedInput.take(1)(0).features.size
+ val numFeatures = treeInput.take(1)(0).binnedFeatures.size
// Calculate level for single group construction
@@ -96,6 +109,8 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
(math.log(maxNumberOfNodesPerGroup) / math.log(2)).floor.toInt, 0)
logDebug("max level for single group = " + maxLevelForSingleGroup)
+ timer.stop("init")
+
/*
* The main idea here is to perform level-wise training of the decision tree nodes thus
* reducing the passes over the data from l to log2(l) where l is the total number of nodes.
@@ -113,15 +128,21 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
logDebug("#####################################")
// Find best split for all nodes at a level.
- val splitsStatsForLevel = DecisionTree.findBestSplits(retaggedInput, parentImpurities,
- strategy, level, filters, splits, bins, maxLevelForSingleGroup)
+ timer.start("findBestSplits")
+ val splitsStatsForLevel = DecisionTree.findBestSplits(treeInput, parentImpurities,
+ strategy, level, filters, splits, bins, maxLevelForSingleGroup, timer)
+ timer.stop("findBestSplits")
for ((nodeSplitStats, index) <- splitsStatsForLevel.view.zipWithIndex) {
+ timer.start("extractNodeInfo")
// Extract info for nodes at the current level.
extractNodeInfo(nodeSplitStats, level, index, nodes)
+ timer.stop("extractNodeInfo")
+ timer.start("extractInfoForLowerLevels")
// Extract info for nodes at the next lower level.
extractInfoForLowerLevels(level, index, maxDepth, nodeSplitStats, parentImpurities,
filters)
+ timer.stop("extractInfoForLowerLevels")
logDebug("final best split = " + nodeSplitStats._1)
}
require(math.pow(2, level) == splitsStatsForLevel.length)
@@ -144,6 +165,11 @@ class DecisionTree (private val strategy: Strategy) extends Serializable with Lo
// Build the full tree using the node info calculated in the level-wise best split calculations.
topNode.build(nodes)
+ timer.stop("total")
+
+ logInfo("Internal timing for DecisionTree:")
+ logInfo(s"$timer")
+
new DecisionTreeModel(topNode, strategy.algo)
}
@@ -406,7 +432,7 @@ object DecisionTree extends Serializable with Logging {
* Returns an array of optimal splits for all nodes at a given level. Splits the task into
* multiple groups if the level-wise training task could lead to memory overflow.
*
- * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
* @param parentImpurities Impurities for all parent nodes for the current level
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
* parameters for constructing the DecisionTree
@@ -415,44 +441,45 @@ object DecisionTree extends Serializable with Logging {
* @param splits possible splits for all features
* @param bins possible bins for all features
* @param maxLevelForSingleGroup the deepest level for single-group level-wise computation.
- * @return array of splits with best splits for all nodes at a given level.
+ * @return array (over nodes) of splits with best split for each node at a given level.
*/
protected[tree] def findBestSplits(
- input: RDD[LabeledPoint],
+ input: RDD[TreePoint],
parentImpurities: Array[Double],
strategy: Strategy,
level: Int,
filters: Array[List[Filter]],
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
- maxLevelForSingleGroup: Int): Array[(Split, InformationGainStats)] = {
+ maxLevelForSingleGroup: Int,
+ timer: TimeTracker = new TimeTracker): Array[(Split, InformationGainStats)] = {
// split into groups to avoid memory overflow during aggregation
if (level > maxLevelForSingleGroup) {
// When information for all nodes at a given level cannot be stored in memory,
// the nodes are divided into multiple groups at each level with the number of groups
// increasing exponentially per level. For example, if maxLevelForSingleGroup is 10,
// numGroups is equal to 2 at level 11 and 4 at level 12, respectively.
- val numGroups = math.pow(2, (level - maxLevelForSingleGroup)).toInt
+ val numGroups = math.pow(2, level - maxLevelForSingleGroup).toInt
logDebug("numGroups = " + numGroups)
var bestSplits = new Array[(Split, InformationGainStats)](0)
// Iterate over each group of nodes at a level.
var groupIndex = 0
while (groupIndex < numGroups) {
val bestSplitsForGroup = findBestSplitsPerGroup(input, parentImpurities, strategy, level,
- filters, splits, bins, numGroups, groupIndex)
+ filters, splits, bins, timer, numGroups, groupIndex)
bestSplits = Array.concat(bestSplits, bestSplitsForGroup)
groupIndex += 1
}
bestSplits
} else {
- findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins)
+ findBestSplitsPerGroup(input, parentImpurities, strategy, level, filters, splits, bins, timer)
}
}
/**
* Returns an array of optimal splits for a group of nodes at a given level
*
- * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
+ * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]]
* @param parentImpurities Impurities for all parent nodes for the current level
* @param strategy [[org.apache.spark.mllib.tree.configuration.Strategy]] instance containing
* parameters for constructing the DecisionTree
@@ -465,13 +492,14 @@ object DecisionTree extends Serializable with Logging {
* @return array of splits with best splits for all nodes at a given level.
*/
private def findBestSplitsPerGroup(
- input: RDD[LabeledPoint],
+ input: RDD[TreePoint],
parentImpurities: Array[Double],
strategy: Strategy,
level: Int,
filters: Array[List[Filter]],
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
+ timer: TimeTracker,
numGroups: Int = 1,
groupIndex: Int = 0): Array[(Split, InformationGainStats)] = {
@@ -507,7 +535,7 @@ object DecisionTree extends Serializable with Logging {
logDebug("numNodes = " + numNodes)
// Find the number of features by looking at the first sample.
- val numFeatures = input.first().features.size
+ val numFeatures = input.first().binnedFeatures.size
logDebug("numFeatures = " + numFeatures)
// numBins: Number of bins = 1 + number of possible splits
@@ -542,33 +570,43 @@ object DecisionTree extends Serializable with Logging {
* Find whether the sample is valid input for the current node, i.e., whether it passes through
* all the filters for the current node.
*/
- def isSampleValid(parentFilters: List[Filter], labeledPoint: LabeledPoint): Boolean = {
+ def isSampleValid(parentFilters: List[Filter], treePoint: TreePoint): Boolean = {
// leaf
if ((level > 0) && (parentFilters.length == 0)) {
return false
}
// Apply each filter and check sample validity. Return false when invalid condition found.
- for (filter <- parentFilters) {
- val features = labeledPoint.features
+ parentFilters.foreach { filter =>
val featureIndex = filter.split.feature
- val threshold = filter.split.threshold
val comparison = filter.comparison
- val categories = filter.split.categories
val isFeatureContinuous = filter.split.featureType == Continuous
- val feature = features(featureIndex)
if (isFeatureContinuous) {
+ val binId = treePoint.binnedFeatures(featureIndex)
+ val bin = bins(featureIndex)(binId)
+ val featureValue = bin.highSplit.threshold
+ val threshold = filter.split.threshold
comparison match {
- case -1 => if (feature > threshold) return false
- case 1 => if (feature <= threshold) return false
+ case -1 => if (featureValue > threshold) return false
+ case 1 => if (featureValue <= threshold) return false
}
} else {
- val containsFeature = categories.contains(feature)
+ val numFeatureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+ val isSpaceSufficientForAllCategoricalSplits =
+ numBins > math.pow(2, numFeatureCategories.toInt - 1) - 1
+ val isUnorderedFeature =
+ isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
+ val featureValue = if (isUnorderedFeature) {
+ treePoint.binnedFeatures(featureIndex)
+ } else {
+ val binId = treePoint.binnedFeatures(featureIndex)
+ bins(featureIndex)(binId).category
+ }
+ val containsFeature = filter.split.categories.contains(featureValue)
comparison match {
case -1 => if (!containsFeature) return false
case 1 => if (containsFeature) return false
}
-
}
}
@@ -577,103 +615,6 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Find bin for one (labeledPoint, feature).
- */
- def findBin(
- featureIndex: Int,
- labeledPoint: LabeledPoint,
- isFeatureContinuous: Boolean,
- isSpaceSufficientForAllCategoricalSplits: Boolean): Int = {
- val binForFeatures = bins(featureIndex)
- val feature = labeledPoint.features(featureIndex)
-
- /**
- * Binary search helper method for continuous feature.
- */
- def binarySearchForBins(): Int = {
- var left = 0
- var right = binForFeatures.length - 1
- while (left <= right) {
- val mid = left + (right - left) / 2
- val bin = binForFeatures(mid)
- val lowThreshold = bin.lowSplit.threshold
- val highThreshold = bin.highSplit.threshold
- if ((lowThreshold < feature) && (highThreshold >= feature)) {
- return mid
- }
- else if (lowThreshold >= feature) {
- right = mid - 1
- }
- else {
- left = mid + 1
- }
- }
- -1
- }
-
- /**
- * Sequential search helper method to find bin for categorical feature in multiclass
- * classification. The category is returned since each category can belong to multiple
- * splits. The actual left/right child allocation per split is performed in the
- * sequential phase of the bin aggregate operation.
- */
- def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
- labeledPoint.features(featureIndex).toInt
- }
-
- /**
- * Sequential search helper method to find bin for categorical feature
- * (for classification and regression).
- */
- def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
- val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val featureValue = labeledPoint.features(featureIndex)
- var binIndex = 0
- while (binIndex < featureCategories) {
- val bin = bins(featureIndex)(binIndex)
- val categories = bin.highSplit.categories
- if (categories.contains(featureValue)) {
- return binIndex
- }
- binIndex += 1
- }
- if (featureValue < 0 || featureValue >= featureCategories) {
- throw new IllegalArgumentException(
- s"DecisionTree given invalid data:" +
- s" Feature $featureIndex is categorical with values in" +
- s" {0,...,${featureCategories - 1}," +
- s" but a data point gives it value $featureValue.\n" +
- " Bad data point: " + labeledPoint.toString)
- }
- -1
- }
-
- if (isFeatureContinuous) {
- // Perform binary search for finding bin for continuous features.
- val binIndex = binarySearchForBins()
- if (binIndex == -1) {
- throw new UnknownError("no bin was found for continuous variable.")
- }
- binIndex
- } else {
- // Perform sequential search to find bin for categorical features.
- val binIndex = {
- val isUnorderedFeature =
- isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
- if (isUnorderedFeature) {
- sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
- } else {
- sequentialBinSearchForOrderedCategoricalFeature()
- }
- }
- if (binIndex == -1) {
- throw new UnknownError("no bin was found for categorical variable.")
- }
- binIndex
- }
- }
-
- /**
* Finds bins for all nodes (and all features) at a given level.
* For l nodes, k features the storage is as follows:
* label, b_11, b_12, .. , b_1k, b_21, b_22, .. , b_2k, b_l1, b_l2, .. , b_lk,
@@ -689,17 +630,17 @@ object DecisionTree extends Serializable with Logging {
* bin index for this labeledPoint
* (or InvalidBinIndex if labeledPoint is not handled by this node)
*/
- def findBinsForLevel(labeledPoint: LabeledPoint): Array[Double] = {
+ def findBinsForLevel(treePoint: TreePoint): Array[Double] = {
// Calculate bin index and label per feature per node.
val arr = new Array[Double](1 + (numFeatures * numNodes))
// First element of the array is the label of the instance.
- arr(0) = labeledPoint.label
+ arr(0) = treePoint.label
// Iterate over nodes.
var nodeIndex = 0
while (nodeIndex < numNodes) {
val parentFilters = findParentFilters(nodeIndex)
// Find out whether the sample qualifies for the particular node.
- val sampleValid = isSampleValid(parentFilters, labeledPoint)
+ val sampleValid = isSampleValid(parentFilters, treePoint)
val shift = 1 + numFeatures * nodeIndex
if (!sampleValid) {
// Mark one bin as -1 is sufficient.
@@ -707,19 +648,7 @@ object DecisionTree extends Serializable with Logging {
} else {
var featureIndex = 0
while (featureIndex < numFeatures) {
- val featureInfo = strategy.categoricalFeaturesInfo.get(featureIndex)
- val isFeatureContinuous = featureInfo.isEmpty
- if (isFeatureContinuous) {
- arr(shift + featureIndex)
- = findBin(featureIndex, labeledPoint, isFeatureContinuous, false)
- } else {
- val featureCategories = featureInfo.get
- val isSpaceSufficientForAllCategoricalSplits
- = numBins > math.pow(2, featureCategories.toInt - 1) - 1
- arr(shift + featureIndex)
- = findBin(featureIndex, labeledPoint, isFeatureContinuous,
- isSpaceSufficientForAllCategoricalSplits)
- }
+ arr(shift + featureIndex) = treePoint.binnedFeatures(featureIndex)
featureIndex += 1
}
}
@@ -728,7 +657,8 @@ object DecisionTree extends Serializable with Logging {
arr
}
- // Find feature bins for all nodes at a level.
+ // Find feature bins for all nodes at a level.
+ timer.start("aggregation")
val binMappedRDD = input.map(x => findBinsForLevel(x))
/**
@@ -830,6 +760,8 @@ object DecisionTree extends Serializable with Logging {
}
}
+ val rightChildShift = numClasses * numBins * numFeatures * numNodes
+
/**
* Helper for binSeqOp.
*
@@ -853,7 +785,6 @@ object DecisionTree extends Serializable with Logging {
val validSignalIndex = 1 + numFeatures * nodeIndex
val isSampleValidForNode = arr(validSignalIndex) != InvalidBinIndex
if (isSampleValidForNode) {
- val rightChildShift = numClasses * numBins * numFeatures * numNodes
// actual class label
val label = arr(0)
// Iterate over all features.
@@ -912,7 +843,7 @@ object DecisionTree extends Serializable with Logging {
val aggIndex = aggShift + 3 * featureIndex * numBins + arr(arrIndex).toInt * 3
agg(aggIndex) = agg(aggIndex) + 1
agg(aggIndex + 1) = agg(aggIndex + 1) + label
- agg(aggIndex + 2) = agg(aggIndex + 2) + label*label
+ agg(aggIndex + 2) = agg(aggIndex + 2) + label * label
featureIndex += 1
}
}
@@ -977,6 +908,7 @@ object DecisionTree extends Serializable with Logging {
val binAggregates = {
binMappedRDD.aggregate(Array.fill[Double](binAggregateLength)(0))(binSeqOp,binCombOp)
}
+ timer.stop("aggregation")
logDebug("binAggregates.length = " + binAggregates.length)
/**
@@ -1031,10 +963,17 @@ object DecisionTree extends Serializable with Logging {
def indexOfLargestArrayElement(array: Array[Double]): Int = {
val result = array.foldLeft(-1, Double.MinValue, 0) {
case ((maxIndex, maxValue, currentIndex), currentValue) =>
- if(currentValue > maxValue) (currentIndex, currentValue, currentIndex + 1)
- else (maxIndex, maxValue, currentIndex + 1)
+ if (currentValue > maxValue) {
+ (currentIndex, currentValue, currentIndex + 1)
+ } else {
+ (maxIndex, maxValue, currentIndex + 1)
+ }
+ }
+ if (result._1 < 0) {
+ throw new RuntimeException("DecisionTree internal error:" +
+ " calculateGainForSplit failed in indexOfLargestArrayElement")
}
- if (result._1 < 0) 0 else result._1
+ result._1
}
val predict = indexOfLargestArrayElement(leftRightCounts)
@@ -1057,6 +996,7 @@ object DecisionTree extends Serializable with Logging {
val gain = impurity - leftWeight * leftImpurity - rightWeight * rightImpurity
new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, predict, prob)
+
case Regression =>
val leftCount = leftNodeAgg(featureIndex)(splitIndex)(0)
val leftSum = leftNodeAgg(featureIndex)(splitIndex)(1)
@@ -1280,16 +1220,42 @@ object DecisionTree extends Serializable with Logging {
nodeImpurity: Double): Array[Array[InformationGainStats]] = {
val gains = Array.ofDim[InformationGainStats](numFeatures, numBins - 1)
- for (featureIndex <- 0 until numFeatures) {
- for (splitIndex <- 0 until numBins - 1) {
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
+ var splitIndex = 0
+ while (splitIndex < numSplitsForFeature) {
gains(featureIndex)(splitIndex) = calculateGainForSplit(leftNodeAgg, featureIndex,
splitIndex, rightNodeAgg, nodeImpurity)
+ splitIndex += 1
}
+ featureIndex += 1
}
gains
}
/**
+ * Get the number of splits for a feature.
+ */
+ def getNumSplitsForFeature(featureIndex: Int): Int = {
+ val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
+ if (isFeatureContinuous) {
+ numBins - 1
+ } else {
+ // Categorical feature
+ val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
+ val isSpaceSufficientForAllCategoricalSplits =
+ numBins > math.pow(2, featureCategories.toInt - 1) - 1
+ if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
+ math.pow(2.0, featureCategories - 1).toInt - 1
+ } else {
+ // Ordered features
+ featureCategories
+ }
+ }
+ }
+
+ /**
* Find the best split for a node.
* @param binData Bin data slice for this node, given by getBinDataForNode.
* @param nodeImpurity impurity of the top node
@@ -1307,7 +1273,7 @@ object DecisionTree extends Serializable with Logging {
// Calculate gains for all splits.
val gains = calculateGainsForAllNodeSplits(leftNodeAgg, rightNodeAgg, nodeImpurity)
- val (bestFeatureIndex,bestSplitIndex, gainStats) = {
+ val (bestFeatureIndex, bestSplitIndex, gainStats) = {
// Initialize with infeasible values.
var bestFeatureIndex = Int.MinValue
var bestSplitIndex = Int.MinValue
@@ -1317,22 +1283,8 @@ object DecisionTree extends Serializable with Logging {
while (featureIndex < numFeatures) {
// Iterate over all splits.
var splitIndex = 0
- val maxSplitIndex: Double = {
- val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
- if (isFeatureContinuous) {
- numBins - 1
- } else { // Categorical feature
- val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val isSpaceSufficientForAllCategoricalSplits
- = numBins > math.pow(2, featureCategories.toInt - 1) - 1
- if (isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits) {
- math.pow(2.0, featureCategories - 1).toInt - 1
- } else { // Binary classification
- featureCategories
- }
- }
- }
- while (splitIndex < maxSplitIndex) {
+ val numSplitsForFeature = getNumSplitsForFeature(featureIndex)
+ while (splitIndex < numSplitsForFeature) {
val gainStats = gains(featureIndex)(splitIndex)
if (gainStats.gain > bestGainStats.gain) {
bestGainStats = gainStats
@@ -1383,6 +1335,7 @@ object DecisionTree extends Serializable with Logging {
}
// Calculate best splits for all nodes at a given level
+ timer.start("chooseSplits")
val bestSplits = new Array[(Split, InformationGainStats)](numNodes)
// Iterating over all nodes at this level
var node = 0
@@ -1395,6 +1348,8 @@ object DecisionTree extends Serializable with Logging {
bestSplits(node) = binsToBestSplit(binsForNode, parentNodeImpurity)
node += 1
}
+ timer.stop("chooseSplits")
+
bestSplits
}
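A toy illustration of the SPARK-3041 point (illustrative only; the subset encoding below is hypothetical, while the patch's real encoding lives in the Bin/Split construction): for an unordered categorical feature, bin indices enumerate candidate left-child subsets of categories, not single category values, so a filter must test membership in split.categories rather than map a bin back to one category.

// A 3-category unordered feature has 2^(3-1) - 1 = 3 candidate left-child
// subsets, so bin indices 0..2 name subsets of categories, not categories.
val arity = 3
val numUnorderedBins = math.pow(2, arity - 1).toInt - 1  // = 3
val leftSubsets = (1 to numUnorderedBins).map { bits =>
  (0 until arity).filter(c => ((bits >> c) & 1) == 1)
}
// leftSubsets == Vector(Vector(0), Vector(1), Vector(0, 1))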
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index f31a503608..cfc8192a85 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -27,22 +27,30 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
/**
* :: Experimental ::
* Stores all the configuration options for tree construction
- * @param algo classification or regression
- * @param impurity criterion used for information gain calculation
+ * @param algo Learning goal. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Classification]],
+ * [[org.apache.spark.mllib.tree.configuration.Algo.Regression]]
+ * @param impurity Criterion used for information gain calculation.
+ * Supported for Classification: [[org.apache.spark.mllib.tree.impurity.Gini]],
+ * [[org.apache.spark.mllib.tree.impurity.Entropy]].
+ * Supported for Regression: [[org.apache.spark.mllib.tree.impurity.Variance]].
* @param maxDepth Maximum depth of the tree.
* E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * @param numClassesForClassification number of classes for classification. Default value is 2
- * leads to binary classification
- * @param maxBins maximum number of bins used for splitting features
- * @param quantileCalculationStrategy algorithm for calculating quantiles
+ * @param numClassesForClassification Number of classes for classification.
+ * (Ignored for regression.)
+ * Default value is 2 (binary classification).
+ * @param maxBins Maximum number of bins used for discretizing continuous features and
+ * for choosing how to split on features at each node.
+ * More bins give higher granularity.
+ * @param quantileCalculationStrategy Algorithm for calculating quantiles. Supported:
+ * [[org.apache.spark.mllib.tree.configuration.QuantileStrategy.Sort]]
* @param categoricalFeaturesInfo A map storing information about the categorical variables and the
* number of discrete values they take. For example, an entry (n ->
* k) implies the feature n is categorical with k categories 0,
* 1, 2, ... , k-1. It's important to note that features are
* zero-indexed.
- * @param maxMemoryInMB maximum memory in MB allocated to histogram aggregation. Default value is
+ * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 128 MB.
- *
*/
@Experimental
class Strategy (
@@ -64,20 +72,7 @@ class Strategy (
= isMulticlassClassification && (categoricalFeaturesInfo.size > 0)
/**
- * Java-friendly constructor.
- *
- * @param algo classification or regression
- * @param impurity criterion used for information gain calculation
- * @param maxDepth Maximum depth of the tree.
- * E.g., depth 0 means 1 leaf node; depth 1 means 1 internal node + 2 leaf nodes.
- * @param numClassesForClassification number of classes for classification. Default value is 2
- * leads to binary classification
- * @param maxBins maximum number of bins used for splitting features
- * @param categoricalFeaturesInfo A map storing information about the categorical variables and
- * the number of discrete values they take. For example, an entry
- * (n -> k) implies the feature n is categorical with k categories
- * 0, 1, 2, ... , k-1. It's important to note that features are
- * zero-indexed.
+ * Java-friendly constructor for [[org.apache.spark.mllib.tree.configuration.Strategy]]
*/
def this(
algo: Algo,
@@ -90,6 +85,10 @@ class Strategy (
categoricalFeaturesInfo.asInstanceOf[java.util.Map[Int, Int]].asScala.toMap)
}
+ /**
+ * Check validity of parameters.
+ * Throws exception if invalid.
+ */
private[tree] def assertValid(): Unit = {
algo match {
case Classification =>
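As a usage sketch of the documented parameters, mirroring the test suite's settings rather than recommended defaults:

import org.apache.spark.mllib.tree.configuration.Algo.Classification
import org.apache.spark.mllib.tree.configuration.Strategy
import org.apache.spark.mllib.tree.impurity.Gini

// Multiclass classification with two ternary categorical features.
val strategy = new Strategy(
  algo = Classification,
  impurity = Gini,
  maxDepth = 4,
  numClassesForClassification = 3,
  maxBins = 100,
  categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))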
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
new file mode 100644
index 0000000000..d215d68c42
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala
@@ -0,0 +1,73 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable.{HashMap => MutableHashMap}
+
+import org.apache.spark.annotation.Experimental
+
+/**
+ * Time tracker implementation which holds labeled timers.
+ */
+@Experimental
+private[tree] class TimeTracker extends Serializable {
+
+ private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
+
+ private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]()
+
+ /**
+ * Starts a new timer, or re-starts a stopped timer.
+ */
+ def start(timerLabel: String): Unit = {
+ val currentTime = System.nanoTime()
+ if (starts.contains(timerLabel)) {
+ throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" +
+ s" timerLabel = $timerLabel before that timer was stopped.")
+ }
+ starts(timerLabel) = currentTime
+ }
+
+ /**
+ * Stops a timer and returns the elapsed time in seconds.
+ */
+ def stop(timerLabel: String): Double = {
+ val currentTime = System.nanoTime()
+ if (!starts.contains(timerLabel)) {
+ throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" +
+ s" timerLabel = $timerLabel, but that timer was not started.")
+ }
+ val elapsed = currentTime - starts(timerLabel)
+ starts.remove(timerLabel)
+ if (totals.contains(timerLabel)) {
+ totals(timerLabel) += elapsed
+ } else {
+ totals(timerLabel) = elapsed
+ }
+ elapsed / 1e9
+ }
+
+ /**
+ * Print all timing results in seconds.
+ */
+ override def toString: String = {
+ totals.map { case (label, elapsed) =>
+ s" $label: ${elapsed / 1e9}"
+ }.mkString("\n")
+ }
+}
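A usage sketch for the new tracker (private[tree], so callable only within the tree package): stop returns the elapsed seconds for one start/stop pair, while toString reports accumulated totals per label.

val timer = new TimeTracker()
timer.start("total")
timer.start("findBestSplits")
// ... level-wise work ...
val seconds: Double = timer.stop("findBestSplits")  // elapsed for this pair
timer.stop("total")
println(timer)  // one "  label: totalSeconds" line per label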
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
new file mode 100644
index 0000000000..ccac1031fd
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -0,0 +1,201 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import org.apache.spark.mllib.regression.LabeledPoint
+import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.model.Bin
+import org.apache.spark.rdd.RDD
+
+
+/**
+ * Internal representation of LabeledPoint for DecisionTree.
+ * This bins feature values based on a subsample of the data as follows:
+ * (a) Continuous features are binned into ranges.
+ * (b) Unordered categorical features are binned based on subsets of feature values.
+ * "Unordered categorical features" are categorical features with low arity used in
+ * multiclass classification.
+ * (c) Ordered categorical features are binned based on feature values.
+ * "Ordered categorical features" are categorical features with high arity,
+ * or any categorical feature used in regression or binary classification.
+ *
+ * @param label Label from LabeledPoint
+ * @param binnedFeatures Binned feature values.
+ * Same length as LabeledPoint.features, but values are bin indices.
+ */
+private[tree] class TreePoint(val label: Double, val binnedFeatures: Array[Int])
+ extends Serializable {
+}
+
+private[tree] object TreePoint {
+
+ /**
+ * Convert an input dataset into its TreePoint representation,
+ * binning feature values in preparation for DecisionTree training.
+ * @param input Input dataset.
+ * @param strategy DecisionTree training info, used for dataset metadata.
+ * @param bins Bins for features, of size (numFeatures, numBins).
+ * @return TreePoint dataset representation
+ */
+ def convertToTreeRDD(
+ input: RDD[LabeledPoint],
+ strategy: Strategy,
+ bins: Array[Array[Bin]]): RDD[TreePoint] = {
+ input.map { x =>
+ TreePoint.labeledPointToTreePoint(x, strategy.isMulticlassClassification, bins,
+ strategy.categoricalFeaturesInfo)
+ }
+ }
+
+ /**
+ * Convert one LabeledPoint into its TreePoint representation.
+ * @param bins Bins for features, of size (numFeatures, numBins).
+ * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
+ */
+ private def labeledPointToTreePoint(
+ labeledPoint: LabeledPoint,
+ isMulticlassClassification: Boolean,
+ bins: Array[Array[Bin]],
+ categoricalFeaturesInfo: Map[Int, Int]): TreePoint = {
+
+ val numFeatures = labeledPoint.features.size
+ val numBins = bins(0).size
+ val arr = new Array[Int](numFeatures)
+ var featureIndex = 0
+ while (featureIndex < numFeatures) {
+ val featureInfo = categoricalFeaturesInfo.get(featureIndex)
+ val isFeatureContinuous = featureInfo.isEmpty
+ if (isFeatureContinuous) {
+ arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous, false,
+ bins, categoricalFeaturesInfo)
+ } else {
+ val featureCategories = featureInfo.get
+ val isSpaceSufficientForAllCategoricalSplits
+ = numBins > math.pow(2, featureCategories.toInt - 1) - 1
+ val isUnorderedFeature =
+ isMulticlassClassification && isSpaceSufficientForAllCategoricalSplits
+ arr(featureIndex) = findBin(featureIndex, labeledPoint, isFeatureContinuous,
+ isUnorderedFeature, bins, categoricalFeaturesInfo)
+ }
+ featureIndex += 1
+ }
+
+ new TreePoint(labeledPoint.label, arr)
+ }
+
+ /**
+ * Find bin for one (labeledPoint, feature).
+ *
+ * @param isUnorderedFeature (only applies if feature is categorical)
+ * @param bins Bins for features, of size (numFeatures, numBins).
+ * @param categoricalFeaturesInfo Map over categorical features: feature index --> feature arity
+ */
+ private def findBin(
+ featureIndex: Int,
+ labeledPoint: LabeledPoint,
+ isFeatureContinuous: Boolean,
+ isUnorderedFeature: Boolean,
+ bins: Array[Array[Bin]],
+ categoricalFeaturesInfo: Map[Int, Int]): Int = {
+
+ /**
+ * Binary search helper method for continuous feature.
+ */
+ def binarySearchForBins(): Int = {
+ val binForFeatures = bins(featureIndex)
+ val feature = labeledPoint.features(featureIndex)
+ var left = 0
+ var right = binForFeatures.length - 1
+ while (left <= right) {
+ val mid = left + (right - left) / 2
+ val bin = binForFeatures(mid)
+ val lowThreshold = bin.lowSplit.threshold
+ val highThreshold = bin.highSplit.threshold
+ if ((lowThreshold < feature) && (highThreshold >= feature)) {
+ return mid
+ } else if (lowThreshold >= feature) {
+ right = mid - 1
+ } else {
+ left = mid + 1
+ }
+ }
+ -1
+ }
+
+ /**
+ * Sequential search helper method to find bin for categorical feature in multiclass
+ * classification. The category is returned since each category can belong to multiple
+ * splits. The actual left/right child allocation per split is performed in the
+ * sequential phase of the bin aggregate operation.
+ */
+ def sequentialBinSearchForUnorderedCategoricalFeatureInClassification(): Int = {
+ labeledPoint.features(featureIndex).toInt
+ }
+
+ /**
+ * Sequential search helper method to find bin for categorical feature
+ * (for classification and regression).
+ */
+ def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
+ val featureCategories = categoricalFeaturesInfo(featureIndex)
+ val featureValue = labeledPoint.features(featureIndex)
+ var binIndex = 0
+ while (binIndex < featureCategories) {
+ val bin = bins(featureIndex)(binIndex)
+ val categories = bin.highSplit.categories
+ if (categories.contains(featureValue)) {
+ return binIndex
+ }
+ binIndex += 1
+ }
+ if (featureValue < 0 || featureValue >= featureCategories) {
+ throw new IllegalArgumentException(
+ s"DecisionTree given invalid data:" +
+ s" Feature $featureIndex is categorical with values in" +
+ s" {0,...,${featureCategories - 1}," +
+ s" but a data point gives it value $featureValue.\n" +
+ " Bad data point: " + labeledPoint.toString)
+ }
+ -1
+ }
+
+ if (isFeatureContinuous) {
+ // Perform binary search for finding bin for continuous features.
+ val binIndex = binarySearchForBins()
+ if (binIndex == -1) {
+ throw new RuntimeException("No bin was found for continuous feature." +
+ " This error can occur when given invalid data values (such as NaN)." +
+ s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
+ }
+ binIndex
+ } else {
+ // Perform sequential search to find bin for categorical features.
+ val binIndex = if (isUnorderedFeature) {
+ sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
+ } else {
+ sequentialBinSearchForOrderedCategoricalFeature()
+ }
+ if (binIndex == -1) {
+ throw new RuntimeException("No bin was found for categorical feature." +
+ " This error can occur when given invalid data values (such as NaN)." +
+ s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}")
+ }
+ binIndex
+ }
+ }
+}
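To make the representation concrete, here is a hand-built TreePoint (normally produced by convertToTreeRDD, and only constructible inside the tree package since the class is private[tree]); the per-feature readings in the comments follow the findBin logic above:

import org.apache.spark.mllib.tree.impl.TreePoint

// Label 1.0 with three already-binned features:
//   binnedFeatures(0) = 7 -> continuous: range-bin index 7 (binary search)
//   binnedFeatures(1) = 2 -> ordered categorical: bin matching category 2
//   binnedFeatures(2) = 1 -> unordered categorical: the raw category 1;
//                            left/right subset allocation is deferred to
//                            the bin aggregation phase
val tp = new TreePoint(1.0, Array(7, 2, 1))
assert(tp.label == 1.0 && tp.binnedFeatures.length == 3)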
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 70ca7c8a26..a5c49a38dc 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -21,11 +21,12 @@ import scala.collection.JavaConverters._
import org.scalatest.FunSuite
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
-import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
-import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.tree.configuration.{FeatureType, Strategy}
+import org.apache.spark.mllib.tree.impl.TreePoint
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
+import org.apache.spark.mllib.tree.model.{DecisionTreeModel, Filter, Split}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.util.LocalSparkContext
import org.apache.spark.mllib.regression.LabeledPoint
@@ -41,7 +42,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
prediction != expected.label
}
val accuracy = (input.length - numOffPredictions).toDouble / input.length
- assert(accuracy >= requiredAccuracy)
+ assert(accuracy >= requiredAccuracy,
+ s"validateClassifier calculated accuracy $accuracy but required $requiredAccuracy.")
}
def validateRegressor(
@@ -54,7 +56,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
err * err
}.sum
val mse = squaredError / input.length
- assert(mse <= requiredMSE)
+ assert(mse <= requiredMSE, s"validateRegressor calculated MSE $mse but required $requiredMSE.")
}
test("split and bin calculation") {
@@ -427,7 +429,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd, strategy)
- val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
val split = bestSplits(0)._1
@@ -454,7 +457,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
maxBins = 100,
categoricalFeaturesInfo = Map(0 -> 3, 1-> 3))
val (splits, bins) = DecisionTree.findSplitsBins(rdd,strategy)
- val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
val split = bestSplits(0)._1
@@ -499,7 +503,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val bestSplits = DecisionTree.findBestSplits(rdd, new Array(7), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(7), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -521,7 +526,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -544,7 +550,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -567,7 +574,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(splits(0).length === 99)
assert(bins(0).length === 100)
- val bestSplits = DecisionTree.findBestSplits(rdd, Array(0.0), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, Array(0.0), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
assert(bestSplits(0)._1.feature === 0)
@@ -596,7 +604,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
val parentImpurities = Array(0.5, 0.5, 0.5)
// Single group second level tree construction.
- val bestSplits = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1, filters,
+ val treeInput = TreePoint.convertToTreeRDD(rdd, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1, filters,
splits, bins, 10)
assert(bestSplits.length === 2)
assert(bestSplits(0)._2.gain > 0)
@@ -604,7 +613,7 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
// maxLevelForSingleGroup parameter is set to 0 to force splitting into groups for second
// level tree construction.
- val bestSplitsWithGroups = DecisionTree.findBestSplits(rdd, parentImpurities, strategy, 1,
+ val bestSplitsWithGroups = DecisionTree.findBestSplits(treeInput, parentImpurities, strategy, 1,
filters, splits, bins, 0)
assert(bestSplitsWithGroups.length === 2)
assert(bestSplitsWithGroups(0)._2.gain > 0)
@@ -630,7 +639,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 3, 1 -> 3))
assert(strategy.isMulticlassClassification)
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -689,7 +699,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(model.depth === 1)
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -714,7 +725,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
validateClassifier(model, arr, 0.9)
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -738,7 +750,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
validateClassifier(model, arr, 0.9)
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)
@@ -757,7 +770,8 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
numClassesForClassification = 3, categoricalFeaturesInfo = Map(0 -> 10, 1 -> 10))
assert(strategy.isMulticlassClassification)
val (splits, bins) = DecisionTree.findSplitsBins(input, strategy)
- val bestSplits = DecisionTree.findBestSplits(input, new Array(31), strategy, 0,
+ val treeInput = TreePoint.convertToTreeRDD(input, strategy, bins)
+ val bestSplits = DecisionTree.findBestSplits(treeInput, new Array(31), strategy, 0,
Array[List[Filter]](), splits, bins, 10)
assert(bestSplits.length === 1)