Diffstat (limited to 'mllib')
 mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala      | 37
 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala    | 14
 mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala | 37
 3 files changed, 15 insertions(+), 73 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index f1f85994e6..b9d0c56dd1 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -327,14 +327,14 @@ object DecisionTree extends Serializable with Logging {
* @param agg Array storing aggregate calculation, with a set of sufficient statistics for
* each (feature, bin).
* @param treePoint Data point being aggregated.
- * @param bins possible bins for all features, indexed (numFeatures)(numBins)
+ * @param splits possible splits indexed (numFeatures)(numSplits)
* @param unorderedFeatures Set of indices of unordered features.
* @param instanceWeight Weight (importance) of instance in dataset.
*/
private def mixedBinSeqOp(
agg: DTStatsAggregator,
treePoint: TreePoint,
- bins: Array[Array[Bin]],
+ splits: Array[Array[Split]],
unorderedFeatures: Set[Int],
instanceWeight: Double,
featuresForNode: Option[Array[Int]]): Unit = {
@@ -362,7 +362,7 @@ object DecisionTree extends Serializable with Logging {
val numSplits = agg.metadata.numSplits(featureIndex)
var splitIndex = 0
while (splitIndex < numSplits) {
- if (bins(featureIndex)(splitIndex).highSplit.categories.contains(featureValue)) {
+ if (splits(featureIndex)(splitIndex).categories.contains(featureValue)) {
agg.featureUpdate(leftNodeFeatureOffset, splitIndex, treePoint.label,
instanceWeight)
} else {
@@ -506,8 +506,8 @@ object DecisionTree extends Serializable with Logging {
if (metadata.unorderedFeatures.isEmpty) {
orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
} else {
- mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
- instanceWeight, featuresForNode)
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, splits,
+ metadata.unorderedFeatures, instanceWeight, featuresForNode)
}
}
}
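The two hunks above carry the core of the change: each unordered split now stores its own category subset, so the aggregation tests splits(featureIndex)(splitIndex).categories directly instead of reaching through a bin's highSplit. A minimal standalone sketch of that loop, with hypothetical leftStats/rightStats arrays standing in for the DTStatsAggregator's left/right (node, feature) offsets (all names here are illustrative, not Spark's API):

// Sketch only: categorySubsets(s) plays the role of
// splits(featureIndex)(s).categories.
def aggregateUnordered(
    featureValue: Double,
    instanceWeight: Double,
    categorySubsets: Array[Set[Double]],
    leftStats: Array[Double],
    rightStats: Array[Double]): Unit = {
  var splitIndex = 0
  while (splitIndex < categorySubsets.length) {
    if (categorySubsets(splitIndex).contains(featureValue)) {
      leftStats(splitIndex) += instanceWeight   // instance falls on this split's left side
    } else {
      rightStats(splitIndex) += instanceWeight  // otherwise it falls on the right side
    }
    splitIndex += 1
  }
}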
@@ -1024,35 +1024,15 @@ object DecisionTree extends Serializable with Logging {
// Categorical feature
val featureArity = metadata.featureArity(featureIndex)
if (metadata.isUnordered(featureIndex)) {
- // TODO: The second half of the bins are unused. Actually, we could just use
- // splits and not build bins for unordered features. That should be part of
- // a later PR since it will require changing other code (using splits instead
- // of bins in a few places).
// Unordered features
- // 2^(maxFeatureValue - 1) - 1 combinations
+ // 2^(maxFeatureValue - 1) - 1 combinations
splits(featureIndex) = new Array[Split](numSplits)
- bins(featureIndex) = new Array[Bin](numBins)
var splitIndex = 0
while (splitIndex < numSplits) {
val categories: List[Double] =
extractMultiClassCategories(splitIndex + 1, featureArity)
splits(featureIndex)(splitIndex) =
new Split(featureIndex, Double.MinValue, Categorical, categories)
- bins(featureIndex)(splitIndex) = {
- if (splitIndex == 0) {
- new Bin(
- new DummyCategoricalSplit(featureIndex, Categorical),
- splits(featureIndex)(0),
- Categorical,
- Double.MinValue)
- } else {
- new Bin(
- splits(featureIndex)(splitIndex - 1),
- splits(featureIndex)(splitIndex),
- Categorical,
- Double.MinValue)
- }
- }
splitIndex += 1
}
} else {
@@ -1060,8 +1040,11 @@ object DecisionTree extends Serializable with Logging {
// Bins correspond to feature values, so we do not need to compute splits or bins
// beforehand. Splits are constructed as needed during training.
splits(featureIndex) = new Array[Split](0)
- bins(featureIndex) = new Array[Bin](0)
}
+ // For ordered features, bins correspond to feature values.
+ // For unordered categorical features, there is no need to construct the bins,
+ // since there is a one-to-one correspondence between the splits and the bins.
+ bins(featureIndex) = new Array[Bin](0)
}
featureIndex += 1
}
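What makes the bins redundant is visible in the loop that survives: extractMultiClassCategories(splitIndex + 1, featureArity) derives each split's category subset from the bits of splitIndex + 1, so the Split already carries everything the old Bin pair duplicated. A hedged re-derivation of that decoding (a standalone sketch, not the actual helper in DecisionTree.scala):

// Bit j of `input` being set means category j belongs to the split's
// left-side subset.
def extractCategories(input: Int, featureArity: Int): List[Double] = {
  var categories = List.empty[Double]
  var bits = input
  var j = 0
  while (j < featureArity) {
    if ((bits & 1) != 0) {
      categories = j.toDouble :: categories
    }
    bits >>= 1
    j += 1
  }
  categories
}

// For a 3-category feature there are 2^(3-1) - 1 = 3 unordered splits:
//   extractCategories(1, 3) == List(0.0)       left = {0.0}
//   extractCategories(2, 3) == List(1.0)       left = {1.0}
//   extractCategories(3, 3) == List(1.0, 0.0)  left = {0.0, 1.0}

Only 2^(featureArity - 1) - 1 subsets are enumerated because a subset and its complement induce the same partition of the data.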
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
index 35e361ae30..50b292e71b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala
@@ -55,17 +55,15 @@ private[tree] object TreePoint {
input: RDD[LabeledPoint],
bins: Array[Array[Bin]],
metadata: DecisionTreeMetadata): RDD[TreePoint] = {
- // Construct arrays for featureArity and isUnordered for efficiency in the inner loop.
+ // Construct arrays for featureArity for efficiency in the inner loop.
val featureArity: Array[Int] = new Array[Int](metadata.numFeatures)
- val isUnordered: Array[Boolean] = new Array[Boolean](metadata.numFeatures)
var featureIndex = 0
while (featureIndex < metadata.numFeatures) {
featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0)
- isUnordered(featureIndex) = metadata.isUnordered(featureIndex)
featureIndex += 1
}
input.map { x =>
- TreePoint.labeledPointToTreePoint(x, bins, featureArity, isUnordered)
+ TreePoint.labeledPointToTreePoint(x, bins, featureArity)
}
}
@@ -74,19 +72,17 @@ private[tree] object TreePoint {
* @param bins Bins for features, of size (numFeatures, numBins).
* @param featureArity Array indexed by feature, with value 0 for continuous and numCategories
* for categorical features.
- * @param isUnordered Array index by feature, with value true for unordered categorical features.
*/
private def labeledPointToTreePoint(
labeledPoint: LabeledPoint,
bins: Array[Array[Bin]],
- featureArity: Array[Int],
- isUnordered: Array[Boolean]): TreePoint = {
+ featureArity: Array[Int]): TreePoint = {
val numFeatures = labeledPoint.features.size
val arr = new Array[Int](numFeatures)
var featureIndex = 0
while (featureIndex < numFeatures) {
arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex),
- isUnordered(featureIndex), bins)
+ bins)
featureIndex += 1
}
new TreePoint(labeledPoint.label, arr)
@@ -96,14 +92,12 @@ private[tree] object TreePoint {
* Find bin for one (labeledPoint, feature).
*
* @param featureArity 0 for continuous features; number of categories for categorical features.
- * @param isUnorderedFeature (only applies if feature is categorical)
* @param bins Bins for features, of size (numFeatures, numBins).
*/
private def findBin(
featureIndex: Int,
labeledPoint: LabeledPoint,
featureArity: Int,
- isUnorderedFeature: Boolean,
bins: Array[Array[Bin]]): Int = {
/**
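With isUnordered gone from the signature, findBin treats all categorical features uniformly: the bin index is simply the category value, and only continuous features need a search over bin boundaries. A hedged sketch of the simplified dispatch (hypothetical names; the real method works against the Bin objects and uses a binary search for continuous features):

// Illustrative stand-in for the simplified findBin.
def findBinSketch(
    featureValue: Double,
    featureArity: Int,
    binUpperBounds: Array[Double]): Int = {
  if (featureArity > 0) {
    // Categorical, ordered or unordered alike: the value is the bin index.
    featureValue.toInt
  } else {
    // Continuous: first bin whose upper bound covers the value
    // (a linear stand-in for the real code's binary search).
    val i = binUpperBounds.indexWhere(featureValue <= _)
    if (i >= 0) i else binUpperBounds.length - 1
  }
}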
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 7b1aed5ffe..4c162df810 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -190,7 +190,7 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(splits.length === 2)
assert(bins.length === 2)
assert(splits(0).length === 3)
- assert(bins(0).length === 6)
+ assert(bins(0).length === 0)
// Expecting 2^2 - 1 = 3 bins/splits
assert(splits(0)(0).feature === 0)
@@ -228,41 +228,6 @@ class DecisionTreeSuite extends FunSuite with MLlibTestSparkContext {
assert(splits(1)(2).categories.contains(0.0))
assert(splits(1)(2).categories.contains(1.0))
- // Check bins.
-
- assert(bins(0)(0).category === Double.MinValue)
- assert(bins(0)(0).lowSplit.categories.length === 0)
- assert(bins(0)(0).highSplit.categories.length === 1)
- assert(bins(0)(0).highSplit.categories.contains(0.0))
- assert(bins(1)(0).category === Double.MinValue)
- assert(bins(1)(0).lowSplit.categories.length === 0)
- assert(bins(1)(0).highSplit.categories.length === 1)
- assert(bins(1)(0).highSplit.categories.contains(0.0))
-
- assert(bins(0)(1).category === Double.MinValue)
- assert(bins(0)(1).lowSplit.categories.length === 1)
- assert(bins(0)(1).lowSplit.categories.contains(0.0))
- assert(bins(0)(1).highSplit.categories.length === 1)
- assert(bins(0)(1).highSplit.categories.contains(1.0))
- assert(bins(1)(1).category === Double.MinValue)
- assert(bins(1)(1).lowSplit.categories.length === 1)
- assert(bins(1)(1).lowSplit.categories.contains(0.0))
- assert(bins(1)(1).highSplit.categories.length === 1)
- assert(bins(1)(1).highSplit.categories.contains(1.0))
-
- assert(bins(0)(2).category === Double.MinValue)
- assert(bins(0)(2).lowSplit.categories.length === 1)
- assert(bins(0)(2).lowSplit.categories.contains(1.0))
- assert(bins(0)(2).highSplit.categories.length === 2)
- assert(bins(0)(2).highSplit.categories.contains(1.0))
- assert(bins(0)(2).highSplit.categories.contains(0.0))
- assert(bins(1)(2).category === Double.MinValue)
- assert(bins(1)(2).lowSplit.categories.length === 1)
- assert(bins(1)(2).lowSplit.categories.contains(1.0))
- assert(bins(1)(2).highSplit.categories.length === 2)
- assert(bins(1)(2).highSplit.categories.contains(1.0))
- assert(bins(1)(2).highSplit.categories.contains(0.0))
-
}
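The assertions that remain are exactly what the bit decoding predicts: with arity-3 unordered features, each feature gets 2^(3-1) - 1 = 3 splits whose category lists follow the pattern sketched earlier. A hypothetical spot-check using the extractCategories sketch above:

assert(extractCategories(1, 3) == List(0.0))       // splits(f)(0): {0.0}
assert(extractCategories(2, 3) == List(1.0))       // splits(f)(1): {1.0}
assert(extractCategories(3, 3) == List(1.0, 0.0))  // splits(f)(2): {0.0, 1.0}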
test("Multiclass classification with ordered categorical features: split and bin calculations") {