diff options
author | sethah <seth.hendrickson16@gmail.com> | 2016-10-10 17:04:11 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-10-10 17:04:11 -0700 |
commit | 03c40202f36ea9fc93071b79fed21ed3f2190ba1 (patch) | |
tree | 98637ce4a1d323fec0ac849095a41ac3690f4c64 /mllib/src/main | |
parent | 29f186bfdf929b1e8ffd8e33ee37b76d5dc5af53 (diff) | |
download | spark-03c40202f36ea9fc93071b79fed21ed3f2190ba1.tar.gz spark-03c40202f36ea9fc93071b79fed21ed3f2190ba1.tar.bz2 spark-03c40202f36ea9fc93071b79fed21ed3f2190ba1.zip |
[SPARK-14610][ML] Remove superfluous split for continuous features in decision tree training
## What changes were proposed in this pull request?
A nonsensical split is produced from method `findSplitsForContinuousFeature` for decision trees. This PR removes the superfluous split and updates unit tests accordingly. Additionally, an assertion to check that the number of found splits is `> 0` is removed, and instead features with zero possible splits are ignored.
## How was this patch tested?
A unit test was added to check that finding splits for a constant feature produces an empty array.
Author: sethah <seth.hendrickson16@gmail.com>
Closes #12374 from sethah/SPARK-14610.
Diffstat (limited to 'mllib/src/main')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 31 |
1 files changed, 15 insertions, 16 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 0b7ad92b3c..b504f411d2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -705,14 +705,17 @@ private[spark] object RandomForest extends Logging { node.stats } + val validFeatureSplits = + Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx => + featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx))) + .getOrElse((featureIndexIdx, featureIndexIdx)) + }.withFilter { case (_, featureIndex) => + binAggregates.metadata.numSplits(featureIndex) != 0 + } + // For each (feature, split), calculate the gain, and select the best (feature, split). val (bestSplit, bestSplitStats) = - Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx => - val featureIndex = if (featuresForNode.nonEmpty) { - featuresForNode.get.apply(featureIndexIdx) - } else { - featureIndexIdx - } + validFeatureSplits.map { case (featureIndexIdx, featureIndex) => val numSplits = binAggregates.metadata.numSplits(featureIndex) if (binAggregates.metadata.isContinuous(featureIndex)) { // Cumulative sum (scanLeft) of bin statistics. @@ -966,7 +969,7 @@ private[spark] object RandomForest extends Logging { * NOTE: `metadata.numbins` will be changed accordingly * if there are not enough splits to be found * @param featureIndex feature index to find splits - * @return array of splits + * @return array of split thresholds */ private[tree] def findSplitsForContinuousFeature( featureSamples: Iterable[Double], @@ -975,7 +978,9 @@ private[spark] object RandomForest extends Logging { require(metadata.isContinuous(featureIndex), "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.") - val splits = { + val splits = if (featureSamples.isEmpty) { + Array.empty[Double] + } else { val numSplits = metadata.numSplits(featureIndex) // get count for each distinct value @@ -987,9 +992,9 @@ private[spark] object RandomForest extends Logging { val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray // if possible splits is not enough or just enough, just return all possible splits - val possibleSplits = valueCounts.length + val possibleSplits = valueCounts.length - 1 if (possibleSplits <= numSplits) { - valueCounts.map(_._1) + valueCounts.map(_._1).init } else { // stride between splits val stride: Double = numSamples.toDouble / (numSplits + 1) @@ -1023,12 +1028,6 @@ private[spark] object RandomForest extends Logging { splitsBuilder.result() } } - - // TODO: Do not fail; just ignore the useless feature. - assert(splits.length > 0, - s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." + - " Please remove this feature and then try again.") - splits } |