author     sethah <seth.hendrickson16@gmail.com>      2016-03-20 12:31:28 -0700
committer  Joseph K. Bradley <joseph@databricks.com>  2016-03-20 12:31:28 -0700
commit     811a5247227b5c68e6cd74c0a88d809862184507 (patch)
tree       2c70b9bda4bc59c5b2acbfc7bf915330c8d25566
parent     d630a203d696cb9457ca05dd40db3faa81f0ad64 (diff)
[SPARK-12182][ML] Distributed binning for trees in spark.ml
This PR changes the `findSplits` method in spark.ml to perform split calculations on the workers, mirroring [PR-8246](https://github.com/apache/spark/pull/8246), which added the same feature to MLlib.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #10231 from sethah/SPARK-12182.
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 110
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala   |  11
2 files changed, 60 insertions(+), 61 deletions(-)
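The heart of the change is moving split computation off the driver: instead of collecting a sample and looping over features locally, every sampled (featureIndex, value) pair is keyed by feature, grouped onto workers, and reduced to splits there. A minimal sketch of that shape on plain Scala collections so it runs without a cluster (Point and the toy split rule are hypothetical stand-ins, not Spark API; on an RDD the same pipeline becomes flatMap -> groupByKey(numPartitions) -> map -> collectAsMap):

```scala
object DistributedBinningSketch {
  // Hypothetical stand-in for LabeledPoint: just a feature vector.
  final case class Point(features: Array[Double])

  def findContinuousSplits(
      points: Seq[Point],
      continuousFeatures: Seq[Int],
      findSplits: (Int, Iterable[Double]) => Array[Double]): Map[Int, Array[Double]] = {
    points
      .flatMap(p => continuousFeatures.map(idx => (idx, p.features(idx)))) // key values by feature
      .groupBy(_._1)                                                       // one group per feature
      .map { case (idx, pairs) => (idx, findSplits(idx, pairs.map(_._2))) }
  }

  def main(args: Array[String]): Unit = {
    val pts = Seq(Point(Array(1.0, 10.0)), Point(Array(2.0, 10.0)), Point(Array(3.0, 20.0)))
    // Toy split rule: every distinct value except the largest becomes a threshold.
    val splits = findContinuousSplits(pts, Seq(0, 1),
      (_, xs) => xs.toArray.distinct.sorted.dropRight(1))
    splits.toSeq.sortBy(_._1).foreach { case (i, s) =>
      println(s"feature $i -> splits ${s.mkString(", ")}")
    }
  }
}
```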
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index dd9a5f261f..afbb9d974d 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -477,8 +477,8 @@ private[ml] object RandomForest extends Logging {
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
- val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
- Some(nodeToFeatures(nodeIndex))
+ val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+ nodeToFeatures(nodeIndex)
}
new DTStatsAggregator(metadata, featuresForNode)
}
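The first hunk is a pure simplification: for any Option, `opt.flatMap(x => Some(f(x)))` is equivalent to `opt.map(f)`. A quick self-contained check:

```scala
// flatMap + Some is just map on Option; both forms agree on Some and None.
val some: Option[Int] = Some(3)
val none: Option[Int] = None
assert(some.flatMap(x => Some(x + 1)) == some.map(_ + 1)) // Some(4) both ways
assert(none.flatMap(x => Some(x + 1)) == none.map(_ + 1)) // None both ways
```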
@@ -827,8 +827,8 @@ private[ml] object RandomForest extends Logging {
val numFeatures = metadata.numFeatures
// Sample the input only if there are continuous features.
- val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
- val sampledInput = if (hasContinuousFeatures) {
+ val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+ val sampledInput = if (continuousFeatures.nonEmpty) {
// Calculate the number of samples for approximate quantile calculation.
val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
val fraction = if (requiredSamples < metadata.numExamples) {
@@ -837,58 +837,57 @@ private[ml] object RandomForest extends Logging {
1.0
}
logDebug("fraction of data used for calculating quantiles = " + fraction)
- input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
+ input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt())
} else {
- new Array[LabeledPoint](0)
+ input.sparkContext.emptyRDD[LabeledPoint]
}
- val splits = new Array[Array[Split]](numFeatures)
-
- // Find all splits.
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- if (metadata.isContinuous(featureIndex)) {
- val featureSamples = sampledInput.map(_.features(featureIndex))
- val featureSplits = findSplitsForContinuousFeature(featureSamples, metadata, featureIndex)
+ findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
+ }
- val numSplits = featureSplits.length
- logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
- splits(featureIndex) = new Array[Split](numSplits)
+ private def findSplitsBinsBySorting(
+ input: RDD[LabeledPoint],
+ metadata: DecisionTreeMetadata,
+ continuousFeatures: IndexedSeq[Int]): Array[Array[Split]] = {
+
+ val continuousSplits: scala.collection.Map[Int, Array[Split]] = {
+ // Reduce the parallelism for split computations when there are fewer
+ // continuous features than input partitions. This prevents tasks from
+ // being spun up that will definitely do no work.
+ val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+ input
+ .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
+ .groupByKey(numPartitions)
+ .map { case (idx, samples) =>
+ val thresholds = findSplitsForContinuousFeature(samples, metadata, idx)
+ val splits: Array[Split] = thresholds.map(thresh => new ContinuousSplit(idx, thresh))
+ logDebug(s"featureIndex = $idx, numSplits = ${splits.length}")
+ (idx, splits)
+ }.collectAsMap()
+ }
- var splitIndex = 0
- while (splitIndex < numSplits) {
- val threshold = featureSplits(splitIndex)
- splits(featureIndex)(splitIndex) = new ContinuousSplit(featureIndex, threshold)
- splitIndex += 1
- }
- } else {
- // Categorical feature
- if (metadata.isUnordered(featureIndex)) {
- val numSplits = metadata.numSplits(featureIndex)
- val featureArity = metadata.featureArity(featureIndex)
- // TODO: Use an implicit representation mapping each category to a subset of indices.
- // I.e., track indices such that we can calculate the set of bins for which
- // feature value x splits to the left.
- // Unordered features
- // 2^(maxFeatureValue - 1) - 1 combinations
- splits(featureIndex) = new Array[Split](numSplits)
- var splitIndex = 0
- while (splitIndex < numSplits) {
- val categories: List[Double] =
- extractMultiClassCategories(splitIndex + 1, featureArity)
- splits(featureIndex)(splitIndex) =
- new CategoricalSplit(featureIndex, categories.toArray, featureArity)
- splitIndex += 1
- }
- } else {
- // Ordered features
- // Bins correspond to feature values, so we do not need to compute splits or bins
- // beforehand. Splits are constructed as needed during training.
- splits(featureIndex) = new Array[Split](0)
+ val numFeatures = metadata.numFeatures
+ val splits: Array[Array[Split]] = Array.tabulate(numFeatures) {
+ case i if metadata.isContinuous(i) =>
+ val split = continuousSplits(i)
+ metadata.setNumSplits(i, split.length)
+ split
+
+ case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+ // Unordered features
+ // 2^(maxFeatureValue - 1) - 1 combinations
+ val featureArity = metadata.featureArity(i)
+ Array.tabulate[Split](metadata.numSplits(i)) { splitIndex =>
+ val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+ new CategoricalSplit(i, categories.toArray, featureArity)
}
- }
- featureIndex += 1
+
+ case i if metadata.isCategorical(i) =>
+ // Ordered features
+ // Bins correspond to feature values, so we do not need to compute splits or bins
+ // beforehand. Splits are constructed as needed during training.
+ Array.empty[Split]
}
splits
}
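The sampling logic near the top of this hunk bounds the cost of quantile estimation: roughly max(maxBins^2, 10000) samples suffice, so larger inputs are sampled down to that size and smaller inputs are used whole. A worked example with made-up metadata values:

```scala
// Hypothetical metadata values, for illustration only.
val maxBins = 32
val numExamples = 1000000L
val requiredSamples = math.max(maxBins * maxBins, 10000)  // max(1024, 10000) = 10000
val fraction =
  if (requiredSamples < numExamples) requiredSamples.toDouble / numExamples else 1.0
println(f"fraction = $fraction%.3f")                      // 0.010: sample about 1% of the data
```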
@@ -930,7 +929,7 @@ private[ml] object RandomForest extends Logging {
* @return array of splits
*/
private[tree] def findSplitsForContinuousFeature(
- featureSamples: Array[Double],
+ featureSamples: Iterable[Double],
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
require(metadata.isContinuous(featureIndex),
@@ -940,8 +939,9 @@ private[ml] object RandomForest extends Logging {
val numSplits = metadata.numSplits(featureIndex)
// get count for each distinct value
- val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
- m + ((x, m.getOrElse(x, 0) + 1))
+ val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+ case ((m, cnt), x) =>
+ (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
}
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
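Because featureSamples is now an Iterable (the values side of groupByKey), calling `.length` on it would force a second traversal; the reworked fold threads a counter through the same pass that builds the count map. A standalone check of the fold:

```scala
// One traversal of an Iterable yields both the per-value counts and the size.
val featureSamples: Iterable[Double] = Iterable(1.0, 2.0, 1.0, 3.0, 2.0, 1.0)
val (valueCountMap, numSamples) =
  featureSamples.foldLeft((Map.empty[Double, Int], 0)) { case ((m, cnt), x) =>
    (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
  }
assert(numSamples == 6)
assert(valueCountMap == Map(1.0 -> 3, 2.0 -> 2, 3.0 -> 1))
```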
@@ -952,7 +952,7 @@ private[ml] object RandomForest extends Logging {
valueCounts.map(_._1)
} else {
// stride between splits
- val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+ val stride: Double = numSamples.toDouble / (numSplits + 1)
logDebug("stride = " + stride)
// iterate `valueCount` to find splits
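Below this elided context, splits are chosen by walking the sorted (value, count) pairs and emitting a split each time the running count reaches the next multiple of stride. A simplified sketch of that idea, not a copy of the committed code (the real implementation uses a nearest-to-target comparison rather than this plain threshold):

```scala
import scala.collection.mutable.ArrayBuffer

// Simplified stride walk over sorted (value, count) pairs.
def pickSplits(valueCounts: Array[(Double, Int)], numSplits: Int, numSamples: Int): Array[Double] = {
  val stride = numSamples.toDouble / (numSplits + 1)
  val splits = ArrayBuffer.empty[Double]
  var target = stride
  var cumulative = 0
  valueCounts.dropRight(1).foreach { case (value, count) =>
    cumulative += count
    if (cumulative >= target) {  // running count reached the next quantile target
      splits += value            // place a split just above this value
      target += stride
    }
  }
  splits.toArray
}

// e.g. pickSplits(Array((1.0, 3), (2.0, 2), (3.0, 1)), numSplits = 2, numSamples = 6)
// gives Array(1.0, 2.0): stride = 2.0, targets crossed at cumulative counts 3 and 5.
```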
@@ -988,8 +988,6 @@ private[ml] object RandomForest extends Logging {
assert(splits.length > 0,
s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
" Please remove this feature and then try again.")
- // set number of splits accordingly
- metadata.setNumSplits(featureIndex, splits.length)
splits
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index c0934d241f..8f02e098ac 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -1010,7 +1010,7 @@ object DecisionTree extends Serializable with Logging {
featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
val splits = {
val featureSplits = findSplitsForContinuousFeature(
- featureSamples.toArray,
+ featureSamples,
metadata,
featureIndex)
logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")
@@ -1115,7 +1115,7 @@ object DecisionTree extends Serializable with Logging {
* @return Array of splits.
*/
private[tree] def findSplitsForContinuousFeature(
- featureSamples: Array[Double],
+ featureSamples: Iterable[Double],
metadata: DecisionTreeMetadata,
featureIndex: Int): Array[Double] = {
require(metadata.isContinuous(featureIndex),
@@ -1125,8 +1125,9 @@ object DecisionTree extends Serializable with Logging {
val numSplits = metadata.numSplits(featureIndex)
// get count for each distinct value
- val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
- m + ((x, m.getOrElse(x, 0) + 1))
+ val (valueCountMap, numSamples) = featureSamples.foldLeft((Map.empty[Double, Int], 0)) {
+ case ((m, cnt), x) =>
+ (m + ((x, m.getOrElse(x, 0) + 1)), cnt + 1)
}
// sort distinct values
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
@@ -1137,7 +1138,7 @@ object DecisionTree extends Serializable with Logging {
valueCounts.map(_._1)
} else {
// stride between splits
- val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+ val stride: Double = numSamples.toDouble / (numSplits + 1)
logDebug("stride = " + stride)
// iterate `valueCount` to find splits