-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala | 25
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala | 114
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala | 22
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala | 12
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala | 204
-rw-r--r--  mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala | 69
6 files changed, 405 insertions(+), 41 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index f98730366b..49751a3049 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -62,7 +62,10 @@ object DecisionTreeRunner {
minInfoGain: Double = 0.0,
numTrees: Int = 1,
featureSubsetStrategy: String = "auto",
- fracTest: Double = 0.2) extends AbstractParams[Params]
+ fracTest: Double = 0.2,
+ useNodeIdCache: Boolean = false,
+ checkpointDir: Option[String] = None,
+ checkpointInterval: Int = 10) extends AbstractParams[Params]
def main(args: Array[String]) {
val defaultParams = Params()
@@ -102,6 +105,21 @@ object DecisionTreeRunner {
.text(s"fraction of data to hold out for testing. If given option testInput, " +
s"this option is ignored. default: ${defaultParams.fracTest}")
.action((x, c) => c.copy(fracTest = x))
+ opt[Boolean]("useNodeIdCache")
+ .text(s"whether to use node Id cache during training, " +
+ s"default: ${defaultParams.useNodeIdCache}")
+ .action((x, c) => c.copy(useNodeIdCache = x))
+ opt[String]("checkpointDir")
+ .text(s"checkpoint directory where intermediate node Id caches will be stored, " +
+ s"default: ${defaultParams.checkpointDir match {
+ case Some(strVal) => strVal
+ case None => "None"
+ }}")
+ .action((x, c) => c.copy(checkpointDir = Some(x)))
+ opt[Int]("checkpointInterval")
+ .text(s"how often to checkpoint the node Id cache, " +
+ s"default: ${defaultParams.checkpointInterval}")
+ .action((x, c) => c.copy(checkpointInterval = x))
opt[String]("testInput")
.text(s"input path to test dataset. If given, option fracTest is ignored." +
s" default: ${defaultParams.testInput}")
@@ -236,7 +254,10 @@ object DecisionTreeRunner {
maxBins = params.maxBins,
numClassesForClassification = numClasses,
minInstancesPerNode = params.minInstancesPerNode,
- minInfoGain = params.minInfoGain)
+ minInfoGain = params.minInfoGain,
+ useNodeIdCache = params.useNodeIdCache,
+ checkpointDir = params.checkpointDir,
+ checkpointInterval = params.checkpointInterval)
if (params.numTrees == 1) {
val startTime = System.nanoTime()
val model = DecisionTree.train(training, strategy)
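
The three new options flow straight into the Strategy that the runner constructs above. For reference, a minimal programmatic sketch of the same configuration; the training RDD, depth, and checkpoint path here are illustrative, not taken from this change:

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.tree.DecisionTree
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini
    import org.apache.spark.rdd.RDD

    // Illustrative only: the training RDD is assumed to be prepared elsewhere.
    def trainWithNodeIdCache(training: RDD[LabeledPoint]) = {
      val strategy = new Strategy(
        algo = Classification,
        impurity = Gini,
        maxDepth = 5,                                // illustrative value
        numClassesForClassification = 2,
        useNodeIdCache = true,                       // keep per-row node Ids in an RDD
        checkpointDir = Some("/tmp/node-id-cache"),  // hypothetical checkpoint location
        checkpointInterval = 10)                     // checkpoint the cache every 10 updates
      DecisionTree.train(training, strategy)
    }
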
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 752ed59a03..78acc17f90 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -437,6 +437,11 @@ object DecisionTree extends Serializable with Logging {
* @param bins possible bins for all features, indexed (numFeatures)(numBins)
* @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
* Updated with new non-leaf nodes which are created.
+ * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
+ * each value in the array is the data point's node Id
+ * for the corresponding tree. This avoids the need to
+ * pass the entire tree to the executors during
+ * the node stat aggregation phase.
*/
private[tree] def findBestSplits(
input: RDD[BaggedPoint[TreePoint]],
@@ -447,7 +452,8 @@ object DecisionTree extends Serializable with Logging {
splits: Array[Array[Split]],
bins: Array[Array[Bin]],
nodeQueue: mutable.Queue[(Int, Node)],
- timer: TimeTracker = new TimeTracker): Unit = {
+ timer: TimeTracker = new TimeTracker,
+ nodeIdCache: Option[NodeIdCache] = None): Unit = {
/*
* The high-level descriptions of the best split optimizations are noted here.
@@ -479,6 +485,37 @@ object DecisionTree extends Serializable with Logging {
logDebug("isMulticlass = " + metadata.isMulticlass)
logDebug("isMulticlassWithCategoricalFeatures = " +
metadata.isMulticlassWithCategoricalFeatures)
+ logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
+
+ /**
+ * Performs a sequential aggregation over a partition for a particular tree and node.
+ *
+ * For each feature, the aggregate sufficient statistics are updated for the relevant
+ * bins.
+ *
+ * @param treeIndex Index of the tree that we want to perform aggregation for.
+ * @param nodeInfo The node info for the tree node.
+ * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+ * for each (node, feature, bin).
+ * @param baggedPoint Data point being aggregated.
+ */
+ def nodeBinSeqOp(
+ treeIndex: Int,
+ nodeInfo: RandomForest.NodeIndexInfo,
+ agg: Array[DTStatsAggregator],
+ baggedPoint: BaggedPoint[TreePoint]): Unit = {
+ if (nodeInfo != null) {
+ val aggNodeIndex = nodeInfo.nodeIndexInGroup
+ val featuresForNode = nodeInfo.featureSubset
+ val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
+ if (metadata.unorderedFeatures.isEmpty) {
+ orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
+ } else {
+ mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
+ instanceWeight, featuresForNode)
+ }
+ }
+ }
/**
* Performs a sequential aggregation over a partition.
@@ -497,20 +534,25 @@ object DecisionTree extends Serializable with Logging {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures,
bins, metadata.unorderedFeatures)
- val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null)
- // If the example does not reach a node in this group, then nodeIndex = null.
- if (nodeInfo != null) {
- val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val featuresForNode = nodeInfo.featureSubset
- val instanceWeight = baggedPoint.subsampleWeights(treeIndex)
- if (metadata.unorderedFeatures.isEmpty) {
- orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode)
- } else {
- mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures,
- instanceWeight, featuresForNode)
- }
- }
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
+ }
+
+ agg
+ }
+
+ /**
+ * Does the same thing as binSeqOp, but uses the node Id cache to look up
+ * each point's node index instead of predicting it by traversing the tree.
+ */
+ def binSeqOpWithNodeIdCache(
+ agg: Array[DTStatsAggregator],
+ dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = {
+ treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
+ val baggedPoint = dataPoint._1
+ val nodeIds = dataPoint._2 // avoid shadowing the outer nodeIdCache parameter
+ val nodeIndex = nodeIds(treeIndex)
+ nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
+
agg
}
@@ -553,7 +595,26 @@ object DecisionTree extends Serializable with Logging {
// Finally, only best Splits for nodes are collected to driver to construct decision tree.
val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
- val nodeToBestSplits =
+
+ val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+ input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
+ // Construct a nodeStatsAggregators array to hold node aggregate stats,
+ // each node will have a nodeStatsAggregator
+ val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+ val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+ Some(nodeToFeatures(nodeIndex))
+ }
+ new DTStatsAggregator(metadata, featuresForNode)
+ }
+
+ // iterate over all instances in the current partition and update aggregate stats
+ points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
+
+ // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+ // which can be combined with those from other partitions using `reduceByKey`
+ nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+ }
+ } else {
input.mapPartitions { points =>
// Construct a nodeStatsAggregators array to hold node aggregate stats,
// each node will have a nodeStatsAggregator
@@ -570,7 +631,10 @@ object DecisionTree extends Serializable with Logging {
// transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
// which can be combined with other partition using `reduceByKey`
nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
- }.reduceByKey((a, b) => a.merge(b))
+ }
+ }
+
+ val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
.map { case (nodeIndex, aggStats) =>
val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
Some(nodeToFeatures(nodeIndex))
@@ -584,6 +648,13 @@ object DecisionTree extends Serializable with Logging {
timer.stop("chooseSplits")
+ val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
+ Array.fill[mutable.Map[Int, NodeIndexUpdater]](
+ metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
+ } else {
+ null
+ }
+
// Iterate over all nodes in this group.
nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
nodesForTree.foreach { node =>
@@ -613,6 +684,13 @@ object DecisionTree extends Serializable with Logging {
node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+ if (nodeIdCache.nonEmpty) {
+ val nodeIndexUpdater = NodeIndexUpdater(
+ split = split,
+ nodeIndex = nodeIndex)
+ nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
+ }
+
// enqueue left child and right child if they are not leaves
if (!leftChildIsLeaf) {
nodeQueue.enqueue((treeIndex, node.leftNode.get))
@@ -629,6 +707,10 @@ object DecisionTree extends Serializable with Logging {
}
}
+ if (nodeIdCache.nonEmpty) {
+ // Update the cache if needed.
+ nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
+ }
}
/**
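
To see the shape of the cached-Id aggregation path above in isolation, here is a self-contained sketch of the same zip + mapPartitions + reduceByKey pattern. The element types and the per-node aggregator are simplified stand-ins, not the real BaggedPoint[TreePoint] or DTStatsAggregator:

    import scala.collection.mutable

    import org.apache.spark.SparkContext._ // pair-RDD ops on older Spark versions
    import org.apache.spark.{SparkConf, SparkContext}

    object ZipAggregateSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("zip-sketch").setMaster("local[2]"))
        val points = sc.parallelize(Seq(1.0, 2.0, 3.0, 4.0), 2)  // stand-in for the bagged points
        val nodeIds = sc.parallelize(Seq(1, 1, 2, 2), 2)         // stand-in for nodeIdsForInstances
        // Zip each row with its cached node Id, aggregate per partition,
        // then merge the per-partition stats with reduceByKey.
        val perNodeSums = points.zip(nodeIds)
          .mapPartitions { iter =>
            val agg = mutable.Map[Int, Double]()                 // stand-in for DTStatsAggregator
            iter.foreach { case (x, nodeId) =>
              agg(nodeId) = agg.getOrElse(nodeId, 0.0) + x
            }
            agg.iterator
          }
          .reduceByKey(_ + _)
        perNodeSums.collect().sorted.foreach(println)            // (1,3.0) and (2,7.0)
        sc.stop()
      }
    }
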
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 1dcaf91438..9683916d9b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache}
import org.apache.spark.mllib.tree.impurity.Impurities
import org.apache.spark.mllib.tree.model._
import org.apache.spark.rdd.RDD
@@ -160,6 +160,19 @@ private class RandomForest (
* in lower levels).
*/
+ // Create the node Id cache, an RDD holding each row's current node Id per tree.
+ // At first, all the rows belong to the root nodes (node Id == 1).
+ val nodeIdCache = if (strategy.useNodeIdCache) {
+ Some(NodeIdCache.init(
+ data = baggedInput,
+ numTrees = numTrees,
+ checkpointDir = strategy.checkpointDir,
+ checkpointInterval = strategy.checkpointInterval,
+ initVal = 1))
+ } else {
+ None
+ }
+
// FIFO queue of nodes to train: (treeIndex, node)
val nodeQueue = new mutable.Queue[(Int, Node)]()
@@ -182,7 +195,7 @@ private class RandomForest (
// Choose node splits, and enqueue new nodes as needed.
timer.start("findBestSplits")
DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
- treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+ treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
timer.stop("findBestSplits")
}
@@ -193,6 +206,11 @@ private class RandomForest (
logInfo("Internal timing for DecisionTree:")
logInfo(s"$timer")
+ // Delete any remaining checkpoints used for the node Id cache.
+ if (nodeIdCache.nonEmpty) {
+ nodeIdCache.get.deleteAllCheckpoints()
+ }
+
val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
val treeWeights = Array.fill[Double](numTrees)(1.0)
new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
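
Condensed from the hunks above, the cache's lifecycle inside RandomForest.run is: initialize once, update after every group of splits, clean up at the end. An outline only (not runnable on its own, since NodeIdCache is private[tree] and the surrounding training loop is elided):

    // 1. Initialize: every row starts at its tree's root (node Id == 1).
    val nodeIdCache = if (strategy.useNodeIdCache) {
      Some(NodeIdCache.init(baggedInput, numTrees,
        strategy.checkpointDir, strategy.checkpointInterval, initVal = 1))
    } else {
      None
    }

    // 2. Per group of nodes: findBestSplits aggregates stats using the cached Ids,
    //    then calls nodeIdCache.get.updateNodeIndices(...) so rows move to children.
    DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
      treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)

    // 3. After training: remove any checkpoint files still on disk.
    nodeIdCache.foreach(_.deleteAllCheckpoints())
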
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 2ed63cf002..d09295c507 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -60,6 +60,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
* 256 MB.
* @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param useNodeIdCache If true, instead of passing the trees to the executors,
+ * the algorithm maintains a separate RDD caching each row's current node Id per tree.
+ * @param checkpointDir If the node Id cache is used, the cache is checkpointed
+ * periodically; this is the directory used for those checkpoints.
+ * @param checkpointInterval How often to checkpoint the node Id cache as it gets updated.
+ * E.g. 10 means the cache is checkpointed every 10 updates.
*/
@Experimental
class Strategy (
@@ -73,7 +80,10 @@ class Strategy (
@BeanProperty var minInstancesPerNode: Int = 1,
@BeanProperty var minInfoGain: Double = 0.0,
@BeanProperty var maxMemoryInMB: Int = 256,
- @BeanProperty var subsamplingRate: Double = 1) extends Serializable {
+ @BeanProperty var subsamplingRate: Double = 1,
+ @BeanProperty var useNodeIdCache: Boolean = false,
+ @BeanProperty var checkpointDir: Option[String] = None,
+ @BeanProperty var checkpointInterval: Int = 10) extends Serializable {
if (algo == Classification) {
require(numClassesForClassification >= 2)
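
Since the new fields are @BeanProperty vars like the existing ones, they also come with getter/setter pairs, which keeps the Java API usable. A short sketch of mutating an existing Strategy instance (the values are illustrative):

    // Assuming `strategy` is an already-constructed Strategy.
    strategy.setUseNodeIdCache(true)
    strategy.setCheckpointDir(Some("/tmp/node-id-cache")) // hypothetical directory
    strategy.setCheckpointInterval(5)                     // checkpoint every 5 updates
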
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
new file mode 100644
index 0000000000..83011b48b7
--- /dev/null
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -0,0 +1,204 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.mllib.tree.impl
+
+import scala.collection.mutable
+
+import org.apache.hadoop.fs.{FileSystem, Path}
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.tree.configuration.FeatureType._
+import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
+import org.apache.spark.rdd.RDD
+import org.apache.spark.storage.StorageLevel
+
+/**
+ * :: DeveloperApi ::
+ * This is used by the node Id cache to find the child Id that a data point would belong to.
+ * @param split Split information.
+ * @param nodeIndex The current node index of a data point that this will update.
+ */
+@DeveloperApi
+private[tree] case class NodeIndexUpdater(
+ split: Split,
+ nodeIndex: Int) {
+ /**
+ * Determine a child node index based on the feature value and the split.
+ * @param binnedFeatures Binned feature values.
+ * @param bins Bin information to convert the bin indices to approximate feature values.
+ * @return Child node index to update to.
+ */
+ def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
+ if (split.featureType == Continuous) {
+ val featureIndex = split.feature
+ val binIndex = binnedFeatures(featureIndex)
+ val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
+ if (featureValueUpperBound <= split.threshold) {
+ Node.leftChildIndex(nodeIndex)
+ } else {
+ Node.rightChildIndex(nodeIndex)
+ }
+ } else {
+ if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
+ Node.leftChildIndex(nodeIndex)
+ } else {
+ Node.rightChildIndex(nodeIndex)
+ }
+ }
+ }
+}
+
+/**
+ * :: DeveloperApi ::
+ * A given TreePoint would belong to a particular node per tree.
+ * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
+ * in each tree. Initially, all values should be 1 for the root node.
+ * The nodeIdsForInstances RDD needs to be updated at each iteration.
+ * @param nodeIdsForInstances The initial values in the cache
+ * (should be an Array of all 1's (meaning the root nodes)).
+ * @param checkpointDir The checkpoint directory where
+ * the checkpointed files will be stored.
+ * @param checkpointInterval The checkpointing interval
+ * (how often the cache should be checkpointed).
+ */
+@DeveloperApi
+private[tree] class NodeIdCache(
+ var nodeIdsForInstances: RDD[Array[Int]],
+ val checkpointDir: Option[String],
+ val checkpointInterval: Int) {
+
+ // Keep a reference to the previous node Ids for instances.
+ // Because we will keep on re-persisting updated node Ids,
+ // we want to unpersist the previous RDD.
+ private var prevNodeIdsForInstances: RDD[Array[Int]] = null
+
+ // To keep track of the past checkpointed RDDs.
+ private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
+ private var rddUpdateCount = 0
+
+ // If a checkpoint directory is given and none has been set yet,
+ // set the checkpoint directory to the given one.
+ if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
+ nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
+ }
+
+ /**
+ * Update the node index values in the cache.
+ * This updates the RDD and its lineage.
+ * TODO: Passing bin information to executors seems unnecessary and costly.
+ * @param data The RDD of training rows.
+ * @param nodeIdUpdaters One map of node index updaters per tree,
+ * keyed by the indices of the nodes to update.
+ * @param bins Bin information needed to find child node indices.
+ */
+ def updateNodeIndices(
+ data: RDD[BaggedPoint[TreePoint]],
+ nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
+ bins: Array[Array[Bin]]): Unit = {
+ if (prevNodeIdsForInstances != null) {
+ // Unpersist the previous one if one exists.
+ prevNodeIdsForInstances.unpersist()
+ }
+
+ prevNodeIdsForInstances = nodeIdsForInstances
+ nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
+ dataPoint => {
+ var treeId = 0
+ while (treeId < nodeIdUpdaters.length) {
+ val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null)
+ if (nodeIdUpdater != null) {
+ val newNodeIndex = nodeIdUpdater.updateNodeIndex(
+ binnedFeatures = dataPoint._1.datum.binnedFeatures,
+ bins = bins)
+ dataPoint._2(treeId) = newNodeIndex
+ }
+
+ treeId += 1
+ }
+
+ dataPoint._2
+ }
+ }
+
+ // Persist the newly updated node Ids.
+ nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
+ rddUpdateCount += 1
+
+ // Handle checkpointing if the directory is not None.
+ if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
+ (rddUpdateCount % checkpointInterval) == 0) {
+ // Let's see if we can delete previous checkpoints.
+ var canDelete = true
+ while (checkpointQueue.size > 1 && canDelete) {
+ // We can delete the oldest checkpoint iff
+ // the next checkpoint actually exists in the file system.
+ if (checkpointQueue(1).getCheckpointFile.isDefined) {
+ val old = checkpointQueue.dequeue()
+
+ // Since the old checkpoint is not deleted by Spark,
+ // we'll manually delete it here.
+ val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
+ fs.delete(new Path(old.getCheckpointFile.get), true)
+ } else {
+ canDelete = false
+ }
+ }
+
+ nodeIdsForInstances.checkpoint()
+ checkpointQueue.enqueue(nodeIdsForInstances)
+ }
+ }
+
+ /**
+ * Call this after training is finished to delete any remaining checkpoints.
+ */
+ def deleteAllCheckpoints(): Unit = {
+ while (checkpointQueue.nonEmpty) {
+ val old = checkpointQueue.dequeue()
+ if (old.getCheckpointFile.isDefined) {
+ val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
+ fs.delete(new Path(old.getCheckpointFile.get), true)
+ }
+ }
+ }
+}
+
+@DeveloperApi
+private[tree] object NodeIdCache {
+ /**
+ * Initialize the node Id cache with initial node Id values.
+ * @param data The RDD of training rows.
+ * @param numTrees The number of trees that we want to create cache for.
+ * @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
+ * @param checkpointInterval The checkpointing interval
+ * (how often the cache should be checkpointed).
+ * @param initVal The initial values in the cache.
+ * @return A node Id cache containing an RDD of initial root node indices.
+ */
+ def init(
+ data: RDD[BaggedPoint[TreePoint]],
+ numTrees: Int,
+ checkpointDir: Option[String],
+ checkpointInterval: Int,
+ initVal: Int = 1): NodeIdCache = {
+ new NodeIdCache(
+ data.map(_ => Array.fill[Int](numTrees)(initVal)),
+ checkpointDir,
+ checkpointInterval)
+ }
+}
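
NodeIndexUpdater.updateNodeIndex leans on the heap-style, 1-based node numbering behind Node.leftChildIndex and Node.rightChildIndex. A standalone sketch of that arithmetic, using local stand-ins for the Node companion methods (the bit-shift encoding is assumed here for illustration):

    object ChildIndexSketch extends App {
      // Local stand-ins for Node.leftChildIndex / Node.rightChildIndex.
      def leftChildIndex(nodeIndex: Int): Int = nodeIndex << 1         // 2 * i
      def rightChildIndex(nodeIndex: Int): Int = (nodeIndex << 1) + 1  // 2 * i + 1

      // The root is node 1, so its children are 2 and 3; node 2's are 4 and 5.
      assert(leftChildIndex(1) == 2 && rightChildIndex(1) == 3)
      assert(leftChildIndex(2) == 4 && rightChildIndex(2) == 5)
    }
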
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 10c046e07f..73c4393c35 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -34,18 +34,11 @@ import org.apache.spark.mllib.util.LocalSparkContext
* Test suite for [[RandomForest]].
*/
class RandomForestSuite extends FunSuite with LocalSparkContext {
-
- test("Binary classification with continuous features:" +
- " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
-
+ def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
- val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 1
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
-
val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
assert(rf.weakHypotheses.size === 1)
@@ -60,18 +53,27 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
assert(rfTree.toString == dt.toString)
}
- test("Regression with continuous features:" +
+ test("Binary classification with continuous features:" +
" comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ binaryClassificationTestWithContinuousFeatures(strategy)
+ }
+ test("Binary classification with continuous features and node Id cache :" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ useNodeIdCache = true)
+ binaryClassificationTestWithContinuousFeatures(strategy)
+ }
+
+ def regressionTestWithContinuousFeatures(strategy: Strategy) {
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
val rdd = sc.parallelize(arr)
- val categoricalFeaturesInfo = Map.empty[Int, Int]
val numTrees = 1
- val strategy = new Strategy(algo = Regression, impurity = Variance,
- maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
- categoricalFeaturesInfo = categoricalFeaturesInfo)
-
val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
featureSubsetStrategy = "auto", seed = 123)
assert(rf.weakHypotheses.size === 1)
@@ -86,14 +88,28 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
assert(rfTree.toString == dt.toString)
}
- test("Binary classification with continuous features: subsampling features") {
+ test("Regression with continuous features:" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Regression, impurity = Variance,
+ maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+ categoricalFeaturesInfo = categoricalFeaturesInfo)
+ regressionTestWithContinuousFeatures(strategy)
+ }
+
+ test("Regression with continuous features and node Id cache :" +
+ " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Regression, impurity = Variance,
+ maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+ categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+ regressionTestWithContinuousFeatures(strategy)
+ }
+
+ def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) {
val numFeatures = 50
val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
val rdd = sc.parallelize(arr)
- val categoricalFeaturesInfo = Map.empty[Int, Int]
-
- val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
- numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
// Select feature subset for top nodes. Return true if OK.
def checkFeatureSubsetStrategy(
@@ -149,6 +165,20 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
}
+ test("Binary classification with continuous features: subsampling features") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+ binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+ }
+
+ test("Binary classification with continuous features and node Id cache: subsampling features") {
+ val categoricalFeaturesInfo = Map.empty[Int, Int]
+ val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+ numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo,
+ useNodeIdCache = true)
+ binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+ }
+
test("alternating categorical and continuous features with multiclass labels to test indexing") {
val arr = new Array[LabeledPoint](4)
arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
@@ -164,7 +194,6 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
featureSubsetStrategy = "sqrt", seed = 12345)
EnsembleTestHelper.validateClassifier(model, arr, 1.0)
}
-
}