Diffstat (limited to 'mllib/src')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala        164
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala     18
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala     6
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala    4
4 files changed, 97 insertions(+), 95 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 4a77d4adcd..53d6482f80 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -19,7 +19,6 @@ package org.apache.spark.mllib.tree
import scala.collection.JavaConverters._
import scala.collection.mutable
-import scala.collection.mutable.ArrayBuilder
import org.apache.spark.Logging
import org.apache.spark.annotation.{Experimental, Since}
@@ -643,8 +642,8 @@ object DecisionTree extends Serializable with Logging {
val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
.map { case (nodeIndex, aggStats) =>
- val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
- Some(nodeToFeatures(nodeIndex))
+ val featuresForNode = nodeToFeaturesBc.value.map { nodeToFeatures =>
+ nodeToFeatures(nodeIndex)
}
// find best split for each node
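Aside on the hunk above: a flatMap whose body always wraps the result in Some is equivalent to a plain map on Option. A minimal sketch of the equivalence, with an illustrative element type rather than the broadcast value used in the patch:

    val nodeToFeatures: Option[Map[Int, Seq[Int]]] = Some(Map(0 -> Seq(1, 2)))
    // Before: wrap in Some, then flatten -- the Some/flatten pair is redundant.
    val viaFlatMap = nodeToFeatures.flatMap(m => Some(m(0)))   // Some(Seq(1, 2))
    // After: map already returns an Option of the transformed value.
    val viaMap = nodeToFeatures.map(m => m(0))                 // Some(Seq(1, 2))
    assert(viaFlatMap == viaMap)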
@@ -976,8 +975,8 @@ object DecisionTree extends Serializable with Logging {
val numFeatures = metadata.numFeatures
// Sample the input only if there are continuous features.
- val hasContinuousFeatures = Range(0, numFeatures).exists(metadata.isContinuous)
- val sampledInput = if (hasContinuousFeatures) {
+ val continuousFeatures = Range(0, numFeatures).filter(metadata.isContinuous)
+ val sampledInput = if (continuousFeatures.nonEmpty) {
// Calculate the number of samples for approximate quantile calculation.
val requiredSamples = math.max(metadata.maxBins * metadata.maxBins, 10000)
val fraction = if (requiredSamples < metadata.numExamples) {
@@ -986,81 +985,14 @@ object DecisionTree extends Serializable with Logging {
1.0
}
logDebug("fraction of data used for calculating quantiles = " + fraction)
- input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt()).collect()
+ input.sample(withReplacement = false, fraction, new XORShiftRandom().nextInt())
} else {
- new Array[LabeledPoint](0)
+ input.sparkContext.emptyRDD[LabeledPoint]
}
metadata.quantileStrategy match {
case Sort =>
- val splits = new Array[Array[Split]](numFeatures)
- val bins = new Array[Array[Bin]](numFeatures)
-
- // Find all splits.
- // Iterate over all features.
- var featureIndex = 0
- while (featureIndex < numFeatures) {
- if (metadata.isContinuous(featureIndex)) {
- val featureSamples = sampledInput.map(lp => lp.features(featureIndex))
- val featureSplits = findSplitsForContinuousFeature(featureSamples,
- metadata, featureIndex)
-
- val numSplits = featureSplits.length
- val numBins = numSplits + 1
- logDebug(s"featureIndex = $featureIndex, numSplits = $numSplits")
- splits(featureIndex) = new Array[Split](numSplits)
- bins(featureIndex) = new Array[Bin](numBins)
-
- var splitIndex = 0
- while (splitIndex < numSplits) {
- val threshold = featureSplits(splitIndex)
- splits(featureIndex)(splitIndex) =
- new Split(featureIndex, threshold, Continuous, List())
- splitIndex += 1
- }
- bins(featureIndex)(0) = new Bin(new DummyLowSplit(featureIndex, Continuous),
- splits(featureIndex)(0), Continuous, Double.MinValue)
-
- splitIndex = 1
- while (splitIndex < numSplits) {
- bins(featureIndex)(splitIndex) =
- new Bin(splits(featureIndex)(splitIndex - 1), splits(featureIndex)(splitIndex),
- Continuous, Double.MinValue)
- splitIndex += 1
- }
- bins(featureIndex)(numSplits) = new Bin(splits(featureIndex)(numSplits - 1),
- new DummyHighSplit(featureIndex, Continuous), Continuous, Double.MinValue)
- } else {
- val numSplits = metadata.numSplits(featureIndex)
- val numBins = metadata.numBins(featureIndex)
- // Categorical feature
- val featureArity = metadata.featureArity(featureIndex)
- if (metadata.isUnordered(featureIndex)) {
- // Unordered features
- // 2^(maxFeatureValue - 1) - 1 combinations
- splits(featureIndex) = new Array[Split](numSplits)
- var splitIndex = 0
- while (splitIndex < numSplits) {
- val categories: List[Double] =
- extractMultiClassCategories(splitIndex + 1, featureArity)
- splits(featureIndex)(splitIndex) =
- new Split(featureIndex, Double.MinValue, Categorical, categories)
- splitIndex += 1
- }
- } else {
- // Ordered features
- // Bins correspond to feature values, so we do not need to compute splits or bins
- // beforehand. Splits are constructed as needed during training.
- splits(featureIndex) = new Array[Split](0)
- }
- // For ordered features, bins correspond to feature values.
- // For unordered categorical features, there is no need to construct the bins.
- // since there is a one-to-one correspondence between the splits and the bins.
- bins(featureIndex) = new Array[Bin](0)
- }
- featureIndex += 1
- }
- (splits, bins)
+ findSplitsBinsBySorting(sampledInput, metadata, continuousFeatures)
case MinMax =>
throw new UnsupportedOperationException("minmax not supported yet.")
case ApproxHist =>
@@ -1068,6 +1000,82 @@ object DecisionTree extends Serializable with Logging {
}
}
+ private def findSplitsBinsBySorting(
+ input: RDD[LabeledPoint],
+ metadata: DecisionTreeMetadata,
+ continuousFeatures: IndexedSeq[Int]): (Array[Array[Split]], Array[Array[Bin]]) = {
+ def findSplits(
+ featureIndex: Int,
+ featureSamples: Iterable[Double]): (Int, (Array[Split], Array[Bin])) = {
+ val splits = {
+ val featureSplits = findSplitsForContinuousFeature(
+ featureSamples.toArray,
+ metadata,
+ featureIndex)
+ logDebug(s"featureIndex = $featureIndex, numSplits = ${featureSplits.length}")
+
+ featureSplits.map(threshold => new Split(featureIndex, threshold, Continuous, Nil))
+ }
+
+ val bins = {
+ val lowSplit = new DummyLowSplit(featureIndex, Continuous)
+ val highSplit = new DummyHighSplit(featureIndex, Continuous)
+
+ // tack the dummy splits on either side of the computed splits
+ val allSplits = lowSplit +: splits.toSeq :+ highSplit
+
+ // slide across the split points pairwise to allocate the bins
+ allSplits.sliding(2).map {
+ case Seq(left, right) => new Bin(left, right, Continuous, Double.MinValue)
+ }.toArray
+ }
+
+ (featureIndex, (splits, bins))
+ }
+
+ val continuousSplits = {
+ // reduce the parallelism for split computations when there are fewer
+ // continuous features than input partitions. this prevents tasks from
+ // being spun up that will definitely do no work.
+ val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
+
+ input
+ .flatMap(point => continuousFeatures.map(idx => (idx, point.features(idx))))
+ .groupByKey(numPartitions)
+ .map { case (k, v) => findSplits(k, v) }
+ .collectAsMap()
+ }
+
+ val numFeatures = metadata.numFeatures
+ val (splits, bins) = Range(0, numFeatures).unzip {
+ case i if metadata.isContinuous(i) =>
+ val (split, bin) = continuousSplits(i)
+ metadata.setNumSplits(i, split.length)
+ (split, bin)
+
+ case i if metadata.isCategorical(i) && metadata.isUnordered(i) =>
+ // Unordered features
+ // 2^(maxFeatureValue - 1) - 1 combinations
+ val featureArity = metadata.featureArity(i)
+ val split = Range(0, metadata.numSplits(i)).map { splitIndex =>
+ val categories = extractMultiClassCategories(splitIndex + 1, featureArity)
+ new Split(i, Double.MinValue, Categorical, categories)
+ }
+
+ // For unordered categorical features, there is no need to construct the bins,
+ // since there is a one-to-one correspondence between the splits and the bins.
+ (split.toArray, Array.empty[Bin])
+
+ case i if metadata.isCategorical(i) =>
+ // Ordered features
+ // Bins correspond to feature values, so we do not need to compute splits or bins
+ // beforehand. Splits are constructed as needed during training.
+ (Array.empty[Split], Array.empty[Bin])
+ }
+
+ (splits.toArray, bins.toArray)
+ }
+
/**
* Nested method to extract list of eligible categories given an index. It extracts the
* position of ones in a binary representation of the input. If binary
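The new findSplitsBinsBySorting builds bins by sliding pairwise over the sentinel-padded split points instead of the old index-juggling while loops. A self-contained sketch of the same pattern, using plain doubles as stand-ins for Spark's Split and Bin classes (the values are illustrative):

    // Thresholds as they would come back from the quantile computation.
    val splits: Seq[Double] = Seq(1.5, 3.0, 7.25)
    // Pad with sentinels playing the role of DummyLowSplit / DummyHighSplit.
    val allSplits = Double.MinValue +: splits :+ Double.MaxValue
    // Each adjacent pair (left, right) becomes one bin.
    val bins: Seq[(Double, Double)] = allSplits.sliding(2).map {
      case Seq(left, right) => (left, right)
    }.toSeq
    assert(bins.length == splits.length + 1)   // numBins = numSplits + 1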
@@ -1131,7 +1139,7 @@ object DecisionTree extends Serializable with Logging {
logDebug("stride = " + stride)
// iterate `valueCount` to find splits
- val splitsBuilder = ArrayBuilder.make[Double]
+ val splitsBuilder = Array.newBuilder[Double]
var index = 1
// currentCount: sum of counts of values that have been visited
var currentCount = valueCounts(0)._2
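The builder change above is purely cosmetic: Array.newBuilder[Double] is the standard-library spelling of ArrayBuilder.make[Double], and both accumulate primitive doubles without boxing. A minimal sketch:

    val splitsBuilder = Array.newBuilder[Double]
    splitsBuilder += 1.5
    splitsBuilder += 3.0
    val result: Array[Double] = splitsBuilder.result()   // Array(1.5, 3.0)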
@@ -1163,8 +1171,8 @@ object DecisionTree extends Serializable with Logging {
assert(splits.length > 0,
s"DecisionTree could not handle feature $featureIndex since it had only 1 unique value." +
" Please remove this feature and then try again.")
- // set number of splits accordingly
- metadata.setNumSplits(featureIndex, splits.length)
+
+ // the split metadata must be updated on the driver
splits
}
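The other structural change in DecisionTree.scala is that continuous-feature split computation now runs on the cluster: each value is keyed by its feature index, grouped per feature, and the quantile-based split search runs once per key, so only the resulting thresholds return to the driver. A sketch of that pattern under assumed names (splitsPerFeature and the (label, features) pair are illustrative stand-ins, not the patch's API):

    import org.apache.spark.rdd.RDD

    def splitsPerFeature(
        input: RDD[(Double, Array[Double])],            // (label, features) stand-in
        continuousFeatures: Seq[Int],
        findSplits: (Int, Iterable[Double]) => Array[Double]): Map[Int, Array[Double]] = {
      // Cap parallelism at the number of features so no task is guaranteed idle.
      val numPartitions = math.min(continuousFeatures.length, input.partitions.length)
      input
        .flatMap { case (_, features) => continuousFeatures.map(i => (i, features(i))) }
        .groupByKey(numPartitions)
        .map { case (featureIndex, samples) => (featureIndex, findSplits(featureIndex, samples)) }
        .collectAsMap()
        .toMap
    }

Because only split thresholds travel back to the driver, the raw samples never need to be collected, which is what lets the collect() call in the old code be dropped.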
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
index 0abed54111..1c611976a9 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -108,21 +108,21 @@ private[spark] class NodeIdCache(
prevNodeIdsForInstances = nodeIdsForInstances
nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
- dataPoint => {
+ case (point, node) => {
var treeId = 0
while (treeId < nodeIdUpdaters.length) {
- val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null)
+ val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null)
if (nodeIdUpdater != null) {
val newNodeIndex = nodeIdUpdater.updateNodeIndex(
- binnedFeatures = dataPoint._1.datum.binnedFeatures,
+ binnedFeatures = point.datum.binnedFeatures,
bins = bins)
- dataPoint._2(treeId) = newNodeIndex
+ node(treeId) = newNodeIndex
}
treeId += 1
}
- dataPoint._2
+ node
}
}
@@ -138,7 +138,7 @@ private[spark] class NodeIdCache(
while (checkpointQueue.size > 1 && canDelete) {
// We can delete the oldest checkpoint iff
// the next checkpoint actually exists in the file system.
- if (checkpointQueue.get(1).get.getCheckpointFile != None) {
+ if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
val old = checkpointQueue.dequeue()
// Since the old checkpoint is not deleted by Spark,
@@ -159,11 +159,11 @@ private[spark] class NodeIdCache(
* Call this after training is finished to delete any remaining checkpoints.
*/
def deleteAllCheckpoints(): Unit = {
- while (checkpointQueue.size > 0) {
+ while (checkpointQueue.nonEmpty) {
val old = checkpointQueue.dequeue()
- if (old.getCheckpointFile != None) {
+ for (checkpointFile <- old.getCheckpointFile) {
val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
- fs.delete(new Path(old.getCheckpointFile.get), true)
+ fs.delete(new Path(checkpointFile), true)
}
}
if (prevNodeIdsForInstances != null) {
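The NodeIdCache cleanup swaps explicit None comparisons and .get calls for the usual Option idioms: isDefined for the existence check, and a for-comprehension whose body runs only when the value is present. A minimal sketch of the for-comprehension form, with an illustrative path:

    val checkpointFileOpt: Option[String] = Some("/tmp/checkpoint-0")   // illustrative value

    // Before: compare against None, then call .get.
    if (checkpointFileOpt != None) {
      println(s"deleting ${checkpointFileOpt.get}")
    }

    // After: the body runs only when the Option is defined, and there is no .get.
    for (checkpointFile <- checkpointFileOpt) {
      println(s"deleting $checkpointFile")
    }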
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
index 356d957f15..1a4299db4e 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala
@@ -135,8 +135,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val featureSamples = Array(1, 1, 2, 2, 2, 2, 2, 2, 2, 2, 3).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 3)
- assert(fakeMetadata.numSplits(0) === 3)
- assert(fakeMetadata.numBins(0) === 4)
// check returned splits are distinct
assert(splits.distinct.length === splits.length)
}
@@ -151,8 +149,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val featureSamples = Array(2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 3, 4, 5).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 2)
- assert(fakeMetadata.numSplits(0) === 2)
- assert(fakeMetadata.numBins(0) === 3)
assert(splits(0) === 2.0)
assert(splits(1) === 3.0)
}
@@ -167,8 +163,6 @@ class DecisionTreeSuite extends SparkFunSuite with MLlibTestSparkContext {
val featureSamples = Array(0, 1, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2).map(_.toDouble)
val splits = DecisionTree.findSplitsForContinuousFeature(featureSamples, fakeMetadata, 0)
assert(splits.length === 1)
- assert(fakeMetadata.numSplits(0) === 1)
- assert(fakeMetadata.numBins(0) === 2)
assert(splits(0) === 1.0)
}
}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
index 334bf3790f..3d3f80063f 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/EnsembleTestHelper.scala
@@ -69,8 +69,8 @@ object EnsembleTestHelper {
required: Double,
metricName: String = "mse") {
val predictions = input.map(x => model.predict(x.features))
- val errors = predictions.zip(input.map(_.label)).map { case (prediction, label) =>
- label - prediction
+ val errors = predictions.zip(input).map { case (prediction, point) =>
+ point.label - prediction
}
val metric = metricName match {
case "mse" =>