author    sethah <seth.hendrickson16@gmail.com>    2016-10-10 17:04:11 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2016-10-10 17:04:11 -0700
commit    03c40202f36ea9fc93071b79fed21ed3f2190ba1 (patch)
tree      98637ce4a1d323fec0ac849095a41ac3690f4c64 /mllib/src/main
parent    29f186bfdf929b1e8ffd8e33ee37b76d5dc5af53 (diff)
[SPARK-14610][ML] Remove superfluous split for continuous features in decision tree training
## What changes were proposed in this pull request?

A nonsensical split is produced by the method `findSplitsForContinuousFeature` for decision trees. This PR removes the superfluous split and updates the unit tests accordingly. Additionally, the assertion that the number of found splits is `> 0` is removed; features with zero possible splits are instead ignored.

## How was this patch tested?

A unit test was added to check that finding splits for a constant feature produces an empty array.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #12374 from sethah/SPARK-14610.
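As a rough illustration of the corrected behaviour (a minimal standalone sketch, not the Spark method itself; the names `candidateSplits` and `maxSplits` are invented for this example): with n distinct values a continuous feature admits only n - 1 usable thresholds, so the largest value is dropped, and a constant feature now yields an empty array that the caller simply skips.

    // Minimal sketch of the corrected split counting, assuming unweighted samples.
    object SplitSketch {
      def candidateSplits(featureSamples: Seq[Double], maxSplits: Int): Array[Double] = {
        if (featureSamples.isEmpty) {
          // No samples: no splits; the caller ignores the feature instead of asserting.
          Array.empty[Double]
        } else {
          // Distinct values, ascending.
          val distinct = featureSamples.distinct.sorted
          // Only n - 1 thresholds can separate n distinct values, so the largest
          // value itself is never a useful split point.
          val possibleSplits = distinct.length - 1
          if (possibleSplits <= maxSplits) {
            distinct.init.toArray // drop the superfluous last value
          } else {
            // The real implementation picks thresholds by stride over the value
            // counts; truncating here just keeps the sketch short.
            distinct.init.take(maxSplits).toArray
          }
        }
      }

      def main(args: Array[String]): Unit = {
        println(candidateSplits(Seq(1.0, 2.0, 2.0, 3.0), maxSplits = 10).mkString(", ")) // 1.0, 2.0
        println(candidateSplits(Seq(5.0, 5.0, 5.0), maxSplits = 10).length)              // 0
      }
    }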
Diffstat (limited to 'mllib/src/main')
-rw-r--r--    mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala    31
1 file changed, 15 insertions(+), 16 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 0b7ad92b3c..b504f411d2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -705,14 +705,17 @@ private[spark] object RandomForest extends Logging {
node.stats
}
+ val validFeatureSplits =
+ Range(0, binAggregates.metadata.numFeaturesPerNode).view.map { featureIndexIdx =>
+ featuresForNode.map(features => (featureIndexIdx, features(featureIndexIdx)))
+ .getOrElse((featureIndexIdx, featureIndexIdx))
+ }.withFilter { case (_, featureIndex) =>
+ binAggregates.metadata.numSplits(featureIndex) != 0
+ }
+
// For each (feature, split), calculate the gain, and select the best (feature, split).
val (bestSplit, bestSplitStats) =
- Range(0, binAggregates.metadata.numFeaturesPerNode).map { featureIndexIdx =>
- val featureIndex = if (featuresForNode.nonEmpty) {
- featuresForNode.get.apply(featureIndexIdx)
- } else {
- featureIndexIdx
- }
+ validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
// Cumulative sum (scanLeft) of bin statistics.
@@ -966,7 +969,7 @@ private[spark] object RandomForest extends Logging {
* NOTE: `metadata.numbins` will be changed accordingly
* if there are not enough splits to be found
* @param featureIndex feature index to find splits
- * @return array of splits
+ * @return array of split thresholds
*/
private[tree] def findSplitsForContinuousFeature(
featureSamples: Iterable[Double],
@@ -975,7 +978,9 @@ private[spark] object RandomForest extends Logging {
require(metadata.isContinuous(featureIndex),
"findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
- val splits = {
+ val splits = if (featureSamples.isEmpty) {
+ Array.empty[Double]
+ } else {
val numSplits = metadata.numSplits(featureIndex)
// get count for each distinct value
@@ -987,9 +992,9 @@ private[spark] object RandomForest extends Logging {
val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
// if possible splits is not enough or just enough, just return all possible splits
- val possibleSplits = valueCounts.length
+ val possibleSplits = valueCounts.length - 1
if (possibleSplits <= numSplits) {
- valueCounts.map(_._1)
+ valueCounts.map(_._1).init
} else {
// stride between splits
val stride: Double = numSamples.toDouble / (numSplits + 1)
@@ -1023,12 +1028,6 @@ private[spark] object RandomForest extends Logging {
splitsBuilder.result()
}
}
-
- // TODO: Do not fail; just ignore the useless feature.
- assert(splits.length > 0,
- s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
- " Please remove this feature and then try again.")
-
splits
}
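A side note on the new `validFeatureSplits` construction above: it builds a lazy view over the per-node feature indices and filters out features whose split count is zero, so constant features are skipped rather than tripping the removed assertion. A rough standalone illustration of the same view-plus-withFilter pattern (the `numSplits` map below is invented for the example):

    // Illustrative only: skip features with zero candidate splits before the
    // (expensive) per-feature best-split search.
    object SkipZeroSplitFeatures {
      def main(args: Array[String]): Unit = {
        // Hypothetical per-feature split counts; feature 1 is constant (0 splits).
        val numSplits = Map(0 -> 3, 1 -> 0, 2 -> 5)

        // Lazy view plus withFilter: no intermediate collection is built, and
        // zero-split features never reach the mapping stage below.
        val validFeatures = (0 until numSplits.size).view
          .map(i => (i, i)) // (featureIndexIdx, featureIndex)
          .withFilter { case (_, f) => numSplits(f) != 0 }

        validFeatures.map { case (idx, f) => s"feature $f (local index $idx) considered" }
          .foreach(println)
      }
    }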