Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala               104
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala   11
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala           68
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala            5
4 files changed, 174 insertions(+), 14 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 03eeaa7077..6737a2f417 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,6 +19,8 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer
+
import org.apache.spark.annotation.Experimental
import org.apache.spark.api.java.JavaRDD
@@ -909,32 +911,39 @@ object DecisionTree extends Serializable with Logging {
// Iterate over all features.
var featureIndex = 0
while (featureIndex < numFeatures) {
- val numSplits = metadata.numSplits(featureIndex)
- val numBins = metadata.numBins(featureIndex)
if (metadata.isContinuous(featureIndex)) {
- val numSamples = sampledInput.length
+ val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
+ val featureSplits = findSplitsForContinuousFeature(featureSamples,
+ metadata, featureIndex)
+
+ val numSplits = featureSplits.length
+ val numBins = numSplits + 1
+ logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
splits(featureIndex) = new Array[Split](numSplits)
bins(featureIndex) = new Array[Bin](numBins)
- val featureSamples = sampledInput.map(lp => lp.features(featureIndex)).sorted
- val stride: Double = numSamples.toDouble / metadata.numBins(featureIndex)
- logDebug("stride = " + stride)
- for (splitIndex <- 0 until numSplits) {
- val sampleIndex = splitIndex * stride.toInt
- // Set threshold halfway in between 2 samples.
- val threshold = (featureSamples(sampleIndex) + featureSamples(sampleIndex + 1)) / 2.0
+
+ var splitIndex = 0
+ while (splitIndex < numSplits) {
+ val threshold = featureSplits(splitIndex)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, threshold, Continuous, List())
+ splitIndex += 1
}
bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
splits(featureIndex)(0), Continuous, Double.MinValue)
- for (splitIndex <- 1 until numSplits) {
+
+ splitIndex = 1
+ while (splitIndex < numSplits) {
bins(featureIndex)(splitIndex) =
new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
Continuous, Double.MinValue)
+ splitIndex += 1
}
bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
} else {
+ val numSplits = metadata.numSplits(featureIndex)
+ val numBins = metadata.numBins(featureIndex)
// Categorical feature
val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
@@ -1011,4 +1020,77 @@ object DecisionTree extends Serializable with Logging {
categories
}
+ /**
+ * Find splits for a continuous feature.
+ * NOTE: The number of splits returned is based on `featureSamples` and
+ * may differ from the specified `numSplits`.
+ * The `numSplits` attribute in the `DecisionTreeMetadata` class is updated accordingly.
+ * @param featureSamples feature values of each sample
+ * @param metadata decision tree metadata
+ * NOTE: `metadata.numBins` will be updated accordingly
+ * if there are not enough splits to be found
+ * @param featureIndex feature index to find splits
+ * @return array of splits
+ */
+ private[tree] def findSplitsForContinuousFeature(
+ featureSamples: Array[Double],
+ metadata: DecisionTreeMetadata,
+ featureIndex: Int): Array[Double] = {
+ require(metadata.isContinuous(featureIndex),
+ "findSplitsForContinuousFeature can only be used to find splits for a continuous feature.")
+
+ val splits = {
+ val numSplits = metadata.numSplits(featureIndex)
+
+ // get count for each distinct value
+ val valueCountMap = featureSamples.foldLeft(Map.empty[Double, Int]) { (m, x) =>
+ m + ((x, m.getOrElse(x, 0) + 1))
+ }
+ // sort distinct values
+ val valueCounts = valueCountMap.toSeq.sortBy(_._1).toArray
+
+ // if there are not enough (or exactly enough) possible splits, return all of them
+ val possibleSplits = valueCounts.length
+ if (possibleSplits <= numSplits) {
+ valueCounts.map(_._1)
+ } else {
+ // stride between splits
+ val stride: Double = featureSamples.length.toDouble / (numSplits + 1)
+ logDebug("stride = " + stride)
+
+ // iterate over `valueCounts` to find splits
+ val splits = new ArrayBuffer[Double]
+ var index = 1
+ // currentCount: sum of counts of values that have been visited
+ var currentCount = valueCounts(0)._2
+ // targetCount: target value for `currentCount`.
+ // When `currentCount` is the closest to `targetCount`,
+ // the current value is chosen as a split threshold.
+ // After a split threshold is found, `targetCount` is increased by `stride`.
+ var targetCount = stride
+ while (index < valueCounts.length) {
+ val previousCount = currentCount
+ currentCount += valueCounts(index)._2
+ val previousGap = math.abs(previousCount - targetCount)
+ val currentGap = math.abs(currentCount - targetCount)
+ // If adding the count of the current value to currentCount
+ // makes the gap between currentCount and targetCount larger,
+ // the previous value was closest to the target and is a split threshold.
+ if (previousGap < currentGap) {
+ splits.append(valueCounts(index - 1)._1)
+ targetCount += stride
+ }
+ index += 1
+ }
+
+ splits.toArray
+ }
+ }
+
+ assert(splits.length > 0)
+ // set number of splits accordingly
+ metadata.setNumSplits(featureIndex, splits.length)
+
+ splits
+ }
}
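
For reference, the core of the new quantile-based split finding can be read
as the following standalone sketch over plain Scala collections (the names
SplitSketch and approxQuantileSplits are illustrative, not part of this
patch). Distinct values are counted and sorted, and the previous value is
emitted as a threshold whenever the running sample count sat closer to the
current stride target than the count including the present value:

import scala.collection.mutable.ArrayBuffer

object SplitSketch {
  // Pick at most `numSplits` thresholds from `samples` so that the bins
  // between consecutive thresholds hold roughly equal numbers of samples.
  def approxQuantileSplits(samples: Array[Double], numSplits: Int): Array[Double] = {
    // Count occurrences of each distinct value, then sort by value.
    val valueCounts = samples.groupBy(identity).map { case (v, vs) => (v, vs.length) }
      .toArray.sortBy(_._1)

    if (valueCounts.length <= numSplits) {
      // Not enough distinct values: every distinct value is a split candidate.
      valueCounts.map(_._1)
    } else {
      val stride = samples.length.toDouble / (numSplits + 1) // ideal samples per bin
      val splits = new ArrayBuffer[Double]
      var currentCount = valueCounts(0)._2 // cumulative count of visited values
      var targetCount = stride             // running quantile target
      var i = 1
      while (i < valueCounts.length) {
        val previousCount = currentCount
        currentCount += valueCounts(i)._2
        // The previous value becomes a threshold when the cumulative count
        // was closer to the target before the current value's count was added.
        if (math.abs(previousCount - targetCount) < math.abs(currentCount - targetCount)) {
          splits += valueCounts(i - 1)._1
          targetCount += stride
        }
        i += 1
      }
      splits.toArray
    }
  }

  def main(args: Array[String]): Unit = {
    // Mirrors the "most samples close to the minimum" fixture in the suite below:
    // 15 samples, 2 splits requested; prints "2.0, 3.0".
    val samples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
    println(approxQuantileSplits(samples, numSplits = 2).mkString(", "))
  }
}

Given thresholds t1 < t2, the patched loop above then builds numSplits + 1
bins, (-inf, t1], (t1, t2], (t2, +inf), using DummyLowSplit and DummyHighSplit
for the open ends.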
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
index 772c02670e..5bc0f2635c 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala
@@ -76,6 +76,17 @@ private[tree] class DecisionTreeMetadata(
numBins(featureIndex) - 1
}
+
+ /**
+ * Set the number of splits for a continuous feature.
+ * For a continuous feature, the number of bins is the number of splits plus 1.
+ */
+ def setNumSplits(featureIndex: Int, numSplits: Int) {
+ require(isContinuous(featureIndex),
+ "Only the number of splits for a continuous feature can be set.")
+ numBins(featureIndex) = numSplits + 1
+ }
+
/**
* Indicates if feature subsampling is being used.
*/
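
A hedged usage sketch of the new setter (assuming `metadata` is a
DecisionTreeMetadata instance whose feature 0 is continuous and was
allotted maxBins = 6, so numBins(0) == 6 and numSplits(0) == 5 initially;
the values are illustrative only):

// If the sampled data yields just 3 usable thresholds, shrink the
// metadata so that numBins == numSplits + 1 still holds for the feature.
metadata.setNumSplits(0, 3)
assert(metadata.numSplits(0) == 3) // numSplits(f) is derived as numBins(f) - 1
assert(metadata.numBins(0) == 4)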
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 98a72b0c4d..8fc5e111bb 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -26,7 +26,7 @@ import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.configuration.Strategy
+import org.apache.spark.mllib.tree.configuration.{QuantileStrategy, Strategy}
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, TreePoint}
import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance}
import org.apache.spark.mllib.tree.model.{InformationGainStats, DecisionTreeModel, Node}
@@ -102,6 +102,72 @@ class DecisionTreeSuite extends FunSuite with LocalSparkContext {
assert(List(3.0, 2.0, 0.0).toSeq === l.toSeq)
}
+ test("find splits for a continuous feature") {
+ // find splits for normal case
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(6), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array.fill(200000)(math.random)
+ val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 5)
+ assert(fakeMetadata.numSplits(0) === 5)
+ assert(fakeMetadata.numBins(0) === 6)
+ // check returned splits are distinct
+ assert(splits.distinct.length === splits.length)
+ }
+
+ // findSplitsForContinuousFeature should not return identical splits:
+ // when there are not enough split candidates, reduce the number of splits in metadata
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(5), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
+ val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 3)
+ assert(fakeMetadata.numSplits(0) === 3)
+ assert(fakeMetadata.numBins(0) === 4)
+ // check returned splits are distinct
+ assert(splits.distinct.length === splits.length)
+ }
+
+ // find splits when most samples are close to the minimum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(3), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
+ val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 2)
+ assert(fakeMetadata.numSplits(0) === 2)
+ assert(fakeMetadata.numBins(0) === 3)
+ assert(splits(0) === 2.0)
+ assert(splits(1) === 3.0)
+ }
+
+ // find splits when most samples are close to the maximum
+ {
+ val fakeMetadata = new DecisionTreeMetadata(1, 0, 0, 0,
+ Map(), Set(),
+ Array(3), Gini, QuantileStrategy.Sort,
+ 0, 0, 0.0, 0, 0
+ )
+ val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
+ val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
+ assert(splits.length === 1)
+ assert(fakeMetadata.numSplits(0) === 1)
+ assert(fakeMetadata.numBins(0) === 2)
+ assert(splits(0) === 1.0)
+ }
+ }
+
test("Multiclass classification with unordered categorical features:" +
" split and bin calculations") {
val arr = DecisionTreeSuite.generateCategoricalDataPoints()
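
As a worked check of the second case above: the fixture
Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3) has distinct value counts
(1.0 -> 2, 2.0 -> 8, 3.0 -> 1). With maxBins = 5 the requested number of
splits is 4, but only 3 distinct values exist, so the "not enough possible
splits" branch returns all of them, and the metadata is shrunk to
numSplits(0) == 3 and numBins(0) == 4, exactly what the assertions expect.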
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index fb44ceb0f5..6b13765b98 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -93,8 +93,9 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 1
- val strategy = new Strategy(algo = Regression, impurity = Variance, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ val strategy = new Strategy(algo = Regression, impurity = Variance,
+ maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+ categoricalFeaturesInfo = categoricalFeaturesInfo)
val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)