author    Ilya Matiach <ilmat@microsoft.com>    2017-01-24 10:25:12 -0800
committer Joseph K. Bradley <joseph@databricks.com>    2017-01-24 10:25:12 -0800
commit    d9783380ff0a6440117348dee3205826d0f9687e (patch)
tree      cee3600462a2955cdb957bb18663c6672de41a92 /mllib/src
parent    59c184e028d79286ef490a448ae7f2536d8753d6 (diff)
[SPARK-18036][ML][MLLIB] Fixing decision trees handling edge cases
## What changes were proposed in this pull request?

Decision trees/GBT/RF do not handle edge cases such as constant features or
empty feature vectors. In the case of constant features, we now choose an
arbitrary split instead of failing with a cryptic error message. In the case
of an empty features vector, we now fail with a clearer error message:

    DecisionTree requires number of features > 0, but was given an empty features vector

instead of the cryptic:

    java.lang.UnsupportedOperationException: empty.max

## How was this patch tested?

Unit tests are added in the patch for:
- DecisionTreeRegressor
- GBTRegressor
- Random Forest Regressor

Author: Ilya Matiach <ilmat@microsoft.com>

Closes #16377 from imatiach-msft/ilmat/fix-decision-tree.
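As a quick illustration of the user-facing change, here is a minimal sketch
reproducing the empty-features case through the public ML API. The local
SparkSession setup and column names are illustrative assumptions, not part of
the patch:

    import org.apache.spark.ml.linalg.Vectors
    import org.apache.spark.ml.regression.DecisionTreeRegressor
    import org.apache.spark.sql.SparkSession

    val spark = SparkSession.builder()
      .master("local[2]").appName("sketch").getOrCreate()
    import spark.implicits._

    // Five rows whose feature vectors are zero-dimensional.
    val df = Seq.fill(5)((1.0, Vectors.dense(Array.empty[Double])))
      .toDF("label", "features")

    // Before this patch: java.lang.UnsupportedOperationException: empty.max
    // After this patch:  java.lang.IllegalArgumentException("requirement
    //   failed: DecisionTree requires number of features > 0, but was given
    //   an empty features vector")
    new DecisionTreeRegressor().setMaxDepth(2).fit(df)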
Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala |  2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala         | 22
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala    | 33
3 files changed, 51 insertions(+), 6 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
index bc3c86a57c..8a9dcb486b 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala
@@ -113,6 +113,8 @@ private[spark] object DecisionTreeMetadata extends Logging {
throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " +
s"but was given by empty one.")
}
+ require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
+ s"but was given an empty features vector")
val numExamples = input.count()
val numClasses = strategy.algo match {
case Classification => strategy.numClasses
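For reference, the new check above is scala.Predef.require, which throws an
IllegalArgumentException when its condition is false; a standalone sketch
(not part of the patch) of what a caller sees:

    // scala.Predef.require(cond, msg) throws
    // java.lang.IllegalArgumentException("requirement failed: " + msg)
    // when cond is false, which is how the new check rejects empty vectors.
    val numFeatures = 0
    require(numFeatures > 0, s"DecisionTree requires number of features > 0, " +
      s"but was given an empty features vector")
    // => java.lang.IllegalArgumentException: requirement failed:
    //    DecisionTree requires number of features > 0, ...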
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index a61ea374cb..008dd19c24 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -714,7 +714,7 @@ private[spark] object RandomForest extends Logging {
}
// For each (feature, split), calculate the gain, and select the best (feature, split).
- val (bestSplit, bestSplitStats) =
+ val splitsAndImpurityInfo =
validFeatureSplits.map { case (featureIndexIdx, featureIndex) =>
val numSplits = binAggregates.metadata.numSplits(featureIndex)
if (binAggregates.metadata.isContinuous(featureIndex)) {
@@ -828,8 +828,26 @@ private[spark] object RandomForest extends Logging {
new CategoricalSplit(featureIndex, categoriesForSplit.toArray, numCategories)
(bestFeatureSplit, bestFeatureGainStats)
}
- }.maxBy(_._2.gain)
+ }
+ val (bestSplit, bestSplitStats) =
+ if (splitsAndImpurityInfo.isEmpty) {
+ // If no valid splits for features, then this split is invalid,
+ // return invalid information gain stats. Take any split and continue.
+ // Splits is empty, so arbitrarily choose to split on any threshold
+ val dummyFeatureIndex = featuresForNode.map(_.head).getOrElse(0)
+ val parentImpurityCalculator = binAggregates.getParentImpurityCalculator()
+ if (binAggregates.metadata.isContinuous(dummyFeatureIndex)) {
+ (new ContinuousSplit(dummyFeatureIndex, 0),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ } else {
+ val numCategories = binAggregates.metadata.featureArity(dummyFeatureIndex)
+ (new CategoricalSplit(dummyFeatureIndex, Array(), numCategories),
+ ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator))
+ }
+ } else {
+ splitsAndImpurityInfo.maxBy(_._2.gain)
+ }
(bestSplit, bestSplitStats)
}
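The guard above is needed because Scala's max/maxBy throw on empty
collections; a standalone illustration (not part of the patch) of the failure
mode this change removes:

    // With constant features, every candidate split can be filtered out of
    // validFeatureSplits, so the old unconditional maxBy(_._2.gain) crashed.
    // max throws the cryptic error quoted in the commit message, and maxBy
    // fails the same way (with "empty.maxBy").
    Seq.empty[Double].max
    // => java.lang.UnsupportedOperationException: empty.max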
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
index 3bded9c017..e1ab7c2d65 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala
@@ -26,9 +26,8 @@ import org.apache.spark.ml.linalg.{Vector, Vectors}
import org.apache.spark.ml.tree._
import org.apache.spark.ml.util.TestingUtils._
import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper}
-import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy,
- Strategy => OldStrategy}
-import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator}
+import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy}
+import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator, Variance}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.util.collection.OpenHashMap
@@ -161,6 +160,21 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
}
}
+ test("train with empty arrays") {
+ val lp = LabeledPoint(1.0, Vectors.dense(Array.empty[Double]))
+ val data = Array.fill(5)(lp)
+ val rdd = sc.parallelize(data)
+
+ val strategy = new OldStrategy(OldAlgo.Regression, Gini, maxDepth = 2,
+ maxBins = 5)
+ withClue("DecisionTree requires number of features > 0," +
+ " but was given an empty features vector") {
+ intercept[IllegalArgumentException] {
+ RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
+ }
+ }
+ }
+
test("train with constant features") {
val lp = LabeledPoint(1.0, Vectors.dense(0.0, 0.0, 0.0))
val data = Array.fill(5)(lp)
@@ -170,12 +184,23 @@ class RandomForestSuite extends SparkFunSuite with MLlibTestSparkContext {
Gini,
maxDepth = 2,
numClasses = 2,
- maxBins = 100,
+ maxBins = 5,
categoricalFeaturesInfo = Map(0 -> 1, 1 -> 5))
val Array(tree) = RandomForest.run(rdd, strategy, 1, "all", 42L, instr = None)
assert(tree.rootNode.impurity === -1.0)
assert(tree.depth === 0)
assert(tree.rootNode.prediction === lp.label)
+
+ // Test with no categorical features
+ val strategy2 = new OldStrategy(
+ OldAlgo.Regression,
+ Variance,
+ maxDepth = 2,
+ maxBins = 5)
+ val Array(tree2) = RandomForest.run(rdd, strategy2, 1, "all", 42L, instr = None)
+ assert(tree2.rootNode.impurity === -1.0)
+ assert(tree2.depth === 0)
+ assert(tree2.rootNode.prediction === lp.label)
}
test("Multiclass classification with unordered categorical features: split calculations") {