Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala      45
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala  29
2 files changed, 56 insertions(+), 18 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 7d123dd6ae..382e76a9b7 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -498,7 +498,7 @@ object DecisionTree extends Serializable with Logging {
val bin = binForFeatures(mid)
val lowThreshold = bin.lowSplit.threshold
val highThreshold = bin.highSplit.threshold
- if ((lowThreshold < feature) && (highThreshold >= feature)){
+ if ((lowThreshold < feature) && (highThreshold >= feature)) {
return mid
}
else if (lowThreshold >= feature) {
@@ -522,28 +522,36 @@ object DecisionTree extends Serializable with Logging {
}
/**
- * Sequential search helper method to find bin for categorical feature.
+ * Sequential search helper method to find bin for categorical feature
+ * (for classification and regression).
*/
- def sequentialBinSearchForOrderedCategoricalFeatureInClassification(): Int = {
+ def sequentialBinSearchForOrderedCategoricalFeature(): Int = {
val featureCategories = strategy.categoricalFeaturesInfo(featureIndex)
- val numCategoricalBins = math.pow(2.0, featureCategories - 1).toInt - 1
+ val featureValue = labeledPoint.features(featureIndex)
var binIndex = 0
- while (binIndex < numCategoricalBins) {
+ while (binIndex < featureCategories) {
val bin = bins(featureIndex)(binIndex)
val categories = bin.highSplit.categories
- val features = labeledPoint.features
- if (categories.contains(features(featureIndex))) {
+ if (categories.contains(featureValue)) {
return binIndex
}
binIndex += 1
}
+ if (featureValue < 0 || featureValue >= featureCategories) {
+ throw new IllegalArgumentException(
+ s"DecisionTree given invalid data:" +
+ s" Feature $featureIndex is categorical with values in" +
+ s" {0,...,${featureCategories - 1}," +
+ s" but a data point gives it value $featureValue.\n" +
+ " Bad data point: " + labeledPoint.toString)
+ }
-1
}
if (isFeatureContinuous) {
// Perform binary search for finding bin for continuous features.
val binIndex = binarySearchForBins()
- if (binIndex == -1){
+ if (binIndex == -1) {
throw new UnknownError("no bin was found for continuous variable.")
}
binIndex
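The hunk above renames the ordered categorical bin search so it serves both classification and regression, scans one bin per category (the loop bound becomes featureCategories), and fails fast on out-of-range category values instead of silently returning -1. A minimal standalone sketch of that scan, using a hypothetical SimpleBin stand-in rather than MLlib's internal Bin/Split types:

// Hypothetical stand-in for the categories held by a bin's high split.
final case class SimpleBin(categories: List[Double])

object OrderedCategoricalBinSearchSketch {
  // Ordered categorical features get one bin per category, so a linear scan
  // over the first `featureCategories` bins is sufficient.
  def findBin(featureValue: Double, featureCategories: Int, bins: Array[SimpleBin]): Int = {
    var binIndex = 0
    while (binIndex < featureCategories) {
      if (bins(binIndex).categories.contains(featureValue)) {
        return binIndex
      }
      binIndex += 1
    }
    // Values outside {0,...,featureCategories - 1} are rejected explicitly.
    if (featureValue < 0 || featureValue >= featureCategories) {
      throw new IllegalArgumentException(
        s"Feature value $featureValue is outside {0,...,${featureCategories - 1}}.")
    }
    -1
  }

  def main(args: Array[String]): Unit = {
    // Arity-3 feature: bin i holds category i.
    val bins = Array(SimpleBin(List(0.0)), SimpleBin(List(1.0)), SimpleBin(List(2.0)))
    println(findBin(1.0, 3, bins))  // prints 1
    // findBin(5.0, 3, bins) would throw IllegalArgumentException.
  }
}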
@@ -555,10 +563,10 @@ object DecisionTree extends Serializable with Logging {
if (isUnorderedFeature) {
sequentialBinSearchForUnorderedCategoricalFeatureInClassification()
} else {
- sequentialBinSearchForOrderedCategoricalFeatureInClassification()
+ sequentialBinSearchForOrderedCategoricalFeature()
}
}
- if (binIndex == -1){
+ if (binIndex == -1) {
throw new UnknownError("no bin was found for categorical variable.")
}
binIndex
@@ -642,11 +650,12 @@ object DecisionTree extends Serializable with Logging {
val arrShift = 1 + numFeatures * nodeIndex
val arrIndex = arrShift + featureIndex
// Update the left or right count for one bin.
- val aggShift = numClasses * numBins * numFeatures * nodeIndex
- val aggIndex
- = aggShift + numClasses * featureIndex * numBins + arr(arrIndex).toInt * numClasses
- val labelInt = label.toInt
- agg(aggIndex + labelInt) = agg(aggIndex + labelInt) + 1
+ val aggIndex =
+ numClasses * numBins * numFeatures * nodeIndex +
+ numClasses * numBins * featureIndex +
+ numClasses * arr(arrIndex).toInt +
+ label.toInt
+ agg(aggIndex) += 1
}
/**
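The hunk above collapses the aggregate offset into a single expression: the per-node class histograms live in one flat array, indexed node-major, then by feature, bin, and class label. A minimal standalone sketch of that offset arithmetic on hypothetical sizes:

object FlatAggIndexSketch {
  // Offset into a flat histogram of size numNodes * numFeatures * numBins * numClasses.
  def aggIndex(nodeIndex: Int, featureIndex: Int, binIndex: Int, label: Int,
      numFeatures: Int, numBins: Int, numClasses: Int): Int = {
    numClasses * numBins * numFeatures * nodeIndex +
      numClasses * numBins * featureIndex +
      numClasses * binIndex +
      label
  }

  def main(args: Array[String]): Unit = {
    val numFeatures = 2
    val numBins = 4
    val numClasses = 3
    val agg = new Array[Double](1 * numFeatures * numBins * numClasses)
    // Count one point at node 0 whose feature 0 fell into bin 1 and whose label is 2:
    // offset = 3*4*2*0 + 3*4*0 + 3*1 + 2 = 5.
    agg(aggIndex(0, 0, 1, 2, numFeatures, numBins, numClasses)) += 1
    println(agg(5))  // prints 1.0
  }
}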
@@ -1127,7 +1136,7 @@ object DecisionTree extends Serializable with Logging {
val rightNodeAgg = Array.ofDim[Double](numFeatures, numBins - 1, numClasses)
var featureIndex = 0
while (featureIndex < numFeatures) {
- if (isMulticlassClassificationWithCategoricalFeatures){
+ if (isMulticlassClassificationWithCategoricalFeatures) {
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
findAggForOrderedFeatureClassification(leftNodeAgg, rightNodeAgg, featureIndex)
@@ -1393,7 +1402,7 @@ object DecisionTree extends Serializable with Logging {
// Iterate over all features.
var featureIndex = 0
- while (featureIndex < numFeatures){
+ while (featureIndex < numFeatures) {
// Check whether the feature is continuous.
val isFeatureContinuous = strategy.categoricalFeaturesInfo.get(featureIndex).isEmpty
if (isFeatureContinuous) {
@@ -1513,7 +1522,7 @@ object DecisionTree extends Serializable with Logging {
if (isFeatureContinuous) { // Bins for categorical variables are already assigned.
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)
- for (index <- 1 until numBins - 1){
+ for (index <- 1 until numBins - 1) {
val bin = new Bin(splits(featureIndex)(index-1), splits(featureIndex)(index),
Continuous, Double.MinValue)
bins(featureIndex)(index) = bin
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 10462db700..546a132559 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -42,6 +42,18 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(accuracy >= requiredAccuracy)
}
+ def validateRegressor(
+ model: DecisionTreeModel,
+ input: Seq[LabeledPoint],
+ requiredMSE: Double) {
+ val predictions = input.map(x => model.predict(x.features))
+ val squaredError = predictions.zip(input).map { case (prediction, expected) =>
+ (prediction - expected.label) * (prediction - expected.label)
+ }.sum
+ val mse = squaredError / input.length
+ assert(mse <= requiredMSE)
+ }
+
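validateRegressor above asserts that the model's mean squared error over the given points stays at or below requiredMSE. A minimal Spark-free sketch of the same check on plain sequences; requiredMSE = 0.0, as used by the regression stump test added further down, demands that every prediction match its label exactly:

object MseSketch {
  // Mean squared error of predictions against labels.
  def mse(predictions: Seq[Double], labels: Seq[Double]): Double = {
    require(predictions.length == labels.length, "predictions and labels must align")
    val squaredError = predictions.zip(labels).map { case (p, y) => (p - y) * (p - y) }.sum
    squaredError / labels.length
  }

  def main(args: Array[String]): Unit = {
    println(mse(Seq(1.0, 0.0, 1.0), Seq(1.0, 0.0, 1.0)))  // 0.0, satisfies requiredMSE = 0.0
    println(mse(Seq(1.0, 0.0, 1.0), Seq(1.0, 1.0, 1.0)))  // ~0.333, fails requiredMSE = 0.0
  }
}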
test("split and bin calculation") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel1()
assert(arr.length === 1000)
@@ -454,6 +466,23 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(stats.impurity > 0.2)
}
+ test("regression stump with categorical variables of arity 2") {
+ val arr = DecisionTreeSuite.generateCategoricalDataPoints()
+ assert(arr.length === 1000)
+ val rdd = sc.parallelize(arr)
+ val strategy = new Strategy(
+ Regression,
+ Variance,
+ maxDepth = 2,
+ maxBins = 100,
+ categoricalFeaturesInfo = Map(0 -> 2, 1 -> 2))
+
+ val model = DecisionTree.train(rdd, strategy)
+ validateRegressor(model, arr, 0.0)
+ assert(model.numNodes === 3)
+ assert(model.depth === 1)
+ }
+
test("stump with fixed label 0 for Gini") {
val arr = DecisionTreeSuite.generateOrderedLabeledPointsWithLabel0()
assert(arr.length === 1000)