about | summary | refs | log | tree | commit | diff
path: root/mllib
diff options
context:
space:
mode:
author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com> 2014-08-01 15:52:21 -0700
committer: Xiangrui Meng <meng@databricks.com> 2014-08-01 15:52:21 -0700
commit: 7058a5393bccc2f917189fa9b4cf7f314410b0de (patch)
tree: 1ff8433835c2e61780102356d992d68461e82972 /mllib
parent: d88e69561367d65e1a2b94527b80a1f65a2cba90 (diff)
download: spark-7058a5393bccc2f917189fa9b4cf7f314410b0de.tar.gz
          spark-7058a5393bccc2f917189fa9b4cf7f314410b0de.tar.bz2
          spark-7058a5393bccc2f917189fa9b4cf7f314410b0de.zip
[SPARK-2796] [mllib] DecisionTree bug fix: ordered categorical features
Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature.

Added new test to DecisionTreeSuite to catch this: "regression stump with categorical variables of arity 2"

Bug fix: Modified upper bound discussed above.

Also: Small improvements to coding style in DecisionTree.

CC mengxr manishamde

Author: Joseph K. Bradley <joseph.kurata.bradley@gmail.com>

Closes #1720 from jkbradley/decisiontree-bugfix2 and squashes the following commits:

225822f [Joseph K. Bradley] Bug: In DecisionTree, the method sequentialBinSearchForOrderedCategoricalFeatureInClassification() indexed bins from 0 to (math.pow(2, featureCategories.toInt - 1) - 1). This upper bound is the bound for unordered categorical features, not ordered ones. The upper bound should be the arity (i.e., max value) of the feature.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 45
-rw-r--r-- mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 29
2 files changed, 56 insertions, 18 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 7d123dd6ae..382e76a9b7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -498,7 +498,7 @@ object DecisionTree extends Serializable with Logging {
val bin = binForFeatures(mid)
val lowThreshold = bin.lowSplit.threshold
val highThreshold = bin.highSplit.threshold
- if ((lowThreshold < feature) && (highThreshold >= feature)){
+ if ((lowThreshold < feature) && (highThreshold >= feature)) {
return mid
}
else if (lowThreshold >= feature) {
@@ -522,28 +522,36 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Sequential search helper method to find bin for categorical feature.
+ * Sequential search helper method to find bin for categorical feature
+ * (for classification and regression).
*/
- def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
+ def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+ val featureValue = labeledPoint.features(featureIndex)
var binIndex = 0
- while (binIndex < numCategoricalBins) {
+ while (binIndex < featureCategories) {
val bin = bins(featureIndex)(binIndex)
val categories = bin.highSplit.categories
- val features = labeledPoint.features
- if (categories.contains(features(featureIndex))) {
+ if (categories.contains(featureValue)) {
return binIndex
}
binIndex += 1
}
+ if (featureValue < 0 || featureValue >= featureCategories) {
+ throw new IllegalArgumentException(
+ s"DecisionTree given invalid data:" +
+ s" Feature $featureIndex is categorical with values in" +
+ s" {0,...,${featureCategories - 1}," +
+ s" but a data point gives it value $featureValue.\n" +
+ " Bad data point: " + labeledPoint.toString)
+ }
-1
}
if (isFeatureContinuous) {
// Perform binary search for finding bin for continuous features.
val binIndex = binarySearchForBins()
- if (binIndex == -1){
+ if (binIndex == -1) {
throw new UnknownError("no bin was found for continuous variable.")
}
binIndex
@@ -555,10 +563,10 @@ object DecisionTree extends Serializable with Logging {
if (isUnorderedFeature) {
sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
} else {
- sequentialBinSearchForOrderedCategoricalFeatureInClassification()
+ sequentialBinSearchForOrderedCategoricalFeature()
}
}
- if (binIndex == -1){
+ if (binIndex == -1) {
throw new UnknownError("no bin was found for categorical variable.")
}
binIndex
@@ -642,11 +650,12 @@ object DecisionTree extends Serializable with Logging {
val arrShift = 1 + numFeatures * nodeIndex
val arrIndex = arrShift + featureIndex
// Update the left or right count for one bin.
- val aggShift = numClasses * numBins * numFeatures * nodeIndex
- val aggIndex
- = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
- val labelInt = label.toInt
- agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
+ val aggIndex =
+ numClasses * numBins * numFeatures * nodeIndex +
+ numClasses * numBins * featureIndex +
+ numClasses * arr(arrIndex).toInt +
+ label.toInt
+ agg(aggIndex) += 1
}
/**
@@ -1127,7 +1136,7 @@ object DecisionTree extends Serializable with Logging {
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
var featureIndex = 0
while (featureIndex < numFeatures) {
- if (isMulticlassClassificationWithCategoricalFeatures){
+ if (isMulticlassClassificationWithCategoricalFeatures) {
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
@@ -1393,7 +1402,7 @@ object DecisionTree extends Serializable with Logging {
// Iterate over all features.
var featureIndex = 0
- while (featureIndex < numFeatures){
+ while (featureIndex < numFeatures) {
// Check whether the feature is continuous.
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
@@ -1513,7 +1522,7 @@ object DecisionTree extends Serializable with Logging {
if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)
- for (index <- 1 until numBins - 1){
+ for (index <- 1 until numBins - 1) {
val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
Continuous, Double.MinValue)
bins(featureIndex)(index) = bin
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 10462db700..546a132559 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -42,6 +42,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(accuracy >= requiredAccuracy)
}
+ def validateRegressor(
+ model: DecisionTreeModel,
+ input: Seq[LabeledPoint],
+ requiredMSE: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+ (prediction - expected.label) * (prediction - expected.label)
+ }.sum
+ val mse = squaredError / input.length
+ assert(mse <= requiredMSE)
+ }
+
test("split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
@@ -454,6 +466,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.impurity > 0.2)
}
+ test("regression stump with categorical variables of arity 2") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Regression,
+ Variance,
+ maxDepth = 2,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 2, 1-> 2))
+
+ val model = DecisionTree.train(rdd, strategy)
+ validateRegressor(model, arr, 0.0)
+ assert(model.numNodes === 3)
+ assert(model.depth === 1)
+ }
+
test("stump with fixed label 0 for Gini") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)