From 56f2c61cde3f5d906c2a58e9af1a661222f2c679 Mon Sep 17 00:00:00 2001
From: Sung Chung
Date: Sat, 1 Nov 2014 16:58:26 -0700
Subject: [SPARK-3161][MLLIB] Adding a node Id caching mechanism for training decision trees.

jkbradley mengxr chouqin Please review this.

Author: Sung Chung

Closes #2868 from codedeft/SPARK-3161 and squashes the following commits:

5f5a156 [Sung Chung] [SPARK-3161][MLLIB] Adding a node Id caching mechanism for training decision trees.
---
 .../org/apache/spark/mllib/tree/DecisionTree.scala | 114 ++++++++++--
 .../org/apache/spark/mllib/tree/RandomForest.scala |  22 ++-
 .../spark/mllib/tree/configuration/Strategy.scala  |  12 +-
 .../apache/spark/mllib/tree/impl/NodeIdCache.scala | 204 +++++++++++++++++++++
 .../spark/mllib/tree/RandomForestSuite.scala       |  69 +++++--
 5 files changed, 382 insertions(+), 39 deletions(-)
 create mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala

diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
index 752ed59a03..78acc17f90 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/DecisionTree.scala
@@ -437,6 +437,11 @@ object DecisionTree extends Serializable with Logging {
    * @param bins possible bins for all features, indexed (numFeatures)(numBins)
    * @param nodeQueue Queue of nodes to split, with values (treeIndex, node).
    *                  Updated with new non-leaf nodes which are created.
+   * @param nodeIdCache Node Id cache containing an RDD of Array[Int] where
+   *                    each value in the array is the data point's node Id
+   *                    for a corresponding tree. This is used to avoid
+   *                    passing the entire tree to the executors during
+   *                    the node stat aggregation phase.
    */
   private[tree] def findBestSplits(
       input: RDD[BaggedPoint[TreePoint]],
@@ -447,7 +452,8 @@ object DecisionTree extends Serializable with Logging {
       splits: Array[Array[Split]],
       bins: Array[Array[Bin]],
       nodeQueue: mutable.Queue[(Int, Node)],
-      timer: TimeTracker = new TimeTracker): Unit = {
+      timer: TimeTracker = new TimeTracker,
+      nodeIdCache: Option[NodeIdCache] = None): Unit = {
 
     /*
      * The high-level descriptions of the best split optimizations are noted here.
@@ -479,6 +485,37 @@ object DecisionTree extends Serializable with Logging {
     logDebug("isMulticlass = " + metadata.isMulticlass)
     logDebug("isMulticlassWithCategoricalFeatures = " +
       metadata.isMulticlassWithCategoricalFeatures)
+    logDebug("using nodeIdCache = " + nodeIdCache.nonEmpty.toString)
+
+    /**
+     * Performs a sequential aggregation over a partition for a particular tree and node.
+     *
+     * For each feature, the aggregate sufficient statistics are updated for the relevant
+     * bins.
+     *
+     * @param treeIndex Index of the tree that we want to perform aggregation for.
+     * @param nodeInfo The node info for the tree node.
+     * @param agg Array storing aggregate calculation, with a set of sufficient statistics
+     *            for each (node, feature, bin).
+     * @param baggedPoint Data point being aggregated.
+ */ + def nodeBinSeqOp( + treeIndex: Int, + nodeInfo: RandomForest.NodeIndexInfo, + agg: Array[DTStatsAggregator], + baggedPoint: BaggedPoint[TreePoint]): Unit = { + if (nodeInfo != null) { + val aggNodeIndex = nodeInfo.nodeIndexInGroup + val featuresForNode = nodeInfo.featureSubset + val instanceWeight = baggedPoint.subsampleWeights(treeIndex) + if (metadata.unorderedFeatures.isEmpty) { + orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) + } else { + mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, + instanceWeight, featuresForNode) + } + } + } /** * Performs a sequential aggregation over a partition. @@ -497,20 +534,25 @@ object DecisionTree extends Serializable with Logging { treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => val nodeIndex = predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, bins, metadata.unorderedFeatures) - val nodeInfo = nodeIndexToInfo.getOrElse(nodeIndex, null) - // If the example does not reach a node in this group, then nodeIndex = null. - if (nodeInfo != null) { - val aggNodeIndex = nodeInfo.nodeIndexInGroup - val featuresForNode = nodeInfo.featureSubset - val instanceWeight = baggedPoint.subsampleWeights(treeIndex) - if (metadata.unorderedFeatures.isEmpty) { - orderedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, instanceWeight, featuresForNode) - } else { - mixedBinSeqOp(agg(aggNodeIndex), baggedPoint.datum, bins, metadata.unorderedFeatures, - instanceWeight, featuresForNode) - } - } + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) + } + + agg + } + + /** + * Do the same thing as binSeqOp, but with nodeIdCache. + */ + def binSeqOpWithNodeIdCache( + agg: Array[DTStatsAggregator], + dataPoint: (BaggedPoint[TreePoint], Array[Int])): Array[DTStatsAggregator] = { + treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) => + val baggedPoint = dataPoint._1 + val nodeIdCache = dataPoint._2 + val nodeIndex = nodeIdCache(treeIndex) + nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint) } + agg } @@ -553,7 +595,26 @@ object DecisionTree extends Serializable with Logging { // Finally, only best Splits for nodes are collected to driver to construct decision tree. 
    val nodeToFeatures = getNodeToFeatures(treeToNodeToIndexInfo)
     val nodeToFeaturesBc = input.sparkContext.broadcast(nodeToFeatures)
-    val nodeToBestSplits =
+
+    val partitionAggregates: RDD[(Int, DTStatsAggregator)] = if (nodeIdCache.nonEmpty) {
+      input.zip(nodeIdCache.get.nodeIdsForInstances).mapPartitions { points =>
+        // Construct a nodeStatsAggregators array to hold node aggregate stats,
+        // each node will have a nodeStatsAggregator
+        val nodeStatsAggregators = Array.tabulate(numNodes) { nodeIndex =>
+          val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
+            Some(nodeToFeatures(nodeIndex))
+          }
+          new DTStatsAggregator(metadata, featuresForNode)
+        }
+
+        // iterate over all instances in the current partition and update aggregate stats
+        points.foreach(binSeqOpWithNodeIdCache(nodeStatsAggregators, _))
+
+        // transform the nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
+        // which can be combined with other partitions using `reduceByKey`
+        nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
+      }
+    } else {
       input.mapPartitions { points =>
         // Construct a nodeStatsAggregators array to hold node aggregate stats,
         // each node will have a nodeStatsAggregator
@@ -570,7 +631,10 @@ object DecisionTree extends Serializable with Logging {
         // transform nodeStatsAggregators array to (nodeIndex, nodeAggregateStats) pairs,
         // which can be combined with other partition using `reduceByKey`
         nodeStatsAggregators.view.zipWithIndex.map(_.swap).iterator
-      }.reduceByKey((a, b) => a.merge(b))
+      }
+    }
+
+    val nodeToBestSplits = partitionAggregates.reduceByKey((a, b) => a.merge(b))
       .map { case (nodeIndex, aggStats) =>
         val featuresForNode = nodeToFeaturesBc.value.flatMap { nodeToFeatures =>
           Some(nodeToFeatures(nodeIndex))
@@ -584,6 +648,13 @@
 
     timer.stop("chooseSplits")
 
+    val nodeIdUpdaters = if (nodeIdCache.nonEmpty) {
+      Array.fill[mutable.Map[Int, NodeIndexUpdater]](
+        metadata.numTrees)(mutable.Map[Int, NodeIndexUpdater]())
+    } else {
+      null
+    }
+
     // Iterate over all nodes in this group.
     nodesForGroup.foreach { case (treeIndex, nodesForTree) =>
       nodesForTree.foreach { node =>
@@ -613,6 +684,13 @@
           node.rightNode = Some(Node(Node.rightChildIndex(nodeIndex),
             stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
 
+          if (nodeIdCache.nonEmpty) {
+            val nodeIndexUpdater = NodeIndexUpdater(
+              split = split,
+              nodeIndex = nodeIndex)
+            nodeIdUpdaters(treeIndex).put(nodeIndex, nodeIndexUpdater)
+          }
+
           // enqueue left child and right child if they are not leaves
           if (!leftChildIsLeaf) {
             nodeQueue.enqueue((treeIndex, node.leftNode.get))
@@ -629,6 +707,10 @@
       }
     }
 
+    if (nodeIdCache.nonEmpty) {
+      // Update the cache if needed.
+      nodeIdCache.get.updateNodeIndices(input, nodeIdUpdaters, bins)
+    }
   }
 
   /**
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 1dcaf91438..9683916d9b 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -28,7 +28,7 @@ import org.apache.spark.mllib.tree.configuration.Algo._
 import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 import org.apache.spark.mllib.tree.configuration.EnsembleCombiningStrategy.Average
 import org.apache.spark.mllib.tree.configuration.Strategy
-import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker}
+import org.apache.spark.mllib.tree.impl.{BaggedPoint, TreePoint, DecisionTreeMetadata, TimeTracker, NodeIdCache}
 import org.apache.spark.mllib.tree.impurity.Impurities
 import org.apache.spark.mllib.tree.model._
 import org.apache.spark.rdd.RDD
@@ -160,6 +160,19 @@ private class RandomForest (
      *       in lower levels).
      */
 
+    // Create the node Id cache RDD.
+    // At first, all the rows belong to the root nodes (node Id == 1).
+    val nodeIdCache = if (strategy.useNodeIdCache) {
+      Some(NodeIdCache.init(
+        data = baggedInput,
+        numTrees = numTrees,
+        checkpointDir = strategy.checkpointDir,
+        checkpointInterval = strategy.checkpointInterval,
+        initVal = 1))
+    } else {
+      None
+    }
+
     // FIFO queue of nodes to train: (treeIndex, node)
     val nodeQueue = new mutable.Queue[(Int, Node)]()
 
@@ -182,7 +195,7 @@ private class RandomForest (
       // Choose node splits, and enqueue new nodes as needed.
       timer.start("findBestSplits")
       DecisionTree.findBestSplits(baggedInput, metadata, topNodes, nodesForGroup,
-        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer)
+        treeToNodeToIndexInfo, splits, bins, nodeQueue, timer, nodeIdCache = nodeIdCache)
       timer.stop("findBestSplits")
     }
 
@@ -193,6 +206,11 @@ private class RandomForest (
     logInfo("Internal timing for DecisionTree:")
     logInfo(s"$timer")
 
+    // Delete any remaining checkpoints used for the node Id cache.
+    if (nodeIdCache.nonEmpty) {
+      nodeIdCache.get.deleteAllCheckpoints()
+    }
+
     val trees = topNodes.map(topNode => new DecisionTreeModel(topNode, strategy.algo))
     val treeWeights = Array.fill[Double](numTrees)(1.0)
     new WeightedEnsembleModel(trees, treeWeights, strategy.algo, Average)
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 2ed63cf002..d09295c507 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -60,6 +60,13 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
 * @param maxMemoryInMB Maximum memory in MB allocated to histogram aggregation. Default value is
 *                      256 MB.
 * @param subsamplingRate Fraction of the training data used for learning decision tree.
+ * @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm
+ *                       maintains a separate RDD of node Ids that records, for each row, its
+ *                       current node in every tree.
+ * @param checkpointDir The checkpoint directory to be used for the node Id cache, which is
+ *                      checkpointed periodically while the cache is in use.
+ * @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
+ *                           E.g.
10 means that the cache will get checkpointed every 10 updates. */ @Experimental class Strategy ( @@ -73,7 +80,10 @@ class Strategy ( @BeanProperty var minInstancesPerNode: Int = 1, @BeanProperty var minInfoGain: Double = 0.0, @BeanProperty var maxMemoryInMB: Int = 256, - @BeanProperty var subsamplingRate: Double = 1) extends Serializable { + @BeanProperty var subsamplingRate: Double = 1, + @BeanProperty var useNodeIdCache: Boolean = false, + @BeanProperty var checkpointDir: Option[String] = None, + @BeanProperty var checkpointInterval: Int = 10) extends Serializable { if (algo == Classification) { require(numClassesForClassification >= 2) diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala new file mode 100644 index 0000000000..83011b48b7 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala @@ -0,0 +1,204 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.mllib.tree.impl + +import scala.collection.mutable + +import org.apache.hadoop.fs.{Path, FileSystem} + +import org.apache.spark.rdd.RDD +import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.tree.configuration.FeatureType._ +import org.apache.spark.storage.StorageLevel +import org.apache.spark.mllib.tree.model.{Bin, Node, Split} + +/** + * :: DeveloperApi :: + * This is used by the node id cache to find the child id that a data point would belong to. + * @param split Split information. + * @param nodeIndex The current node index of a data point that this will update. + */ +@DeveloperApi +private[tree] case class NodeIndexUpdater( + split: Split, + nodeIndex: Int) { + /** + * Determine a child node index based on the feature value and the split. + * @param binnedFeatures Binned feature values. + * @param bins Bin information to convert the bin indices to approximate feature values. + * @return Child node index to update to. + */ + def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = { + if (split.featureType == Continuous) { + val featureIndex = split.feature + val binIndex = binnedFeatures(featureIndex) + val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold + if (featureValueUpperBound <= split.threshold) { + Node.leftChildIndex(nodeIndex) + } else { + Node.rightChildIndex(nodeIndex) + } + } else { + if (split.categories.contains(binnedFeatures(split.feature).toDouble)) { + Node.leftChildIndex(nodeIndex) + } else { + Node.rightChildIndex(nodeIndex) + } + } + } +} + +/** + * :: DeveloperApi :: + * A given TreePoint would belong to a particular node per tree. 
+ * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
+ * in each tree. Initially, values should all be 1 for root node.
+ * The nodeIdsForInstances RDD needs to be updated at each iteration.
+ * @param nodeIdsForInstances The initial values in the cache
+ *            (each row should be an array of all 1's, meaning all instances start at the root).
+ * @param checkpointDir The checkpoint directory where
+ *                      the checkpointed files will be stored.
+ * @param checkpointInterval The checkpointing interval
+ *                           (how often the cache should be checkpointed).
+ */
+@DeveloperApi
+private[tree] class NodeIdCache(
+  var nodeIdsForInstances: RDD[Array[Int]],
+  val checkpointDir: Option[String],
+  val checkpointInterval: Int) {
+
+  // Keep a reference to the previous node Ids for instances.
+  // Because we will keep on re-persisting updated node Ids,
+  // we want to unpersist the previous RDD.
+  private var prevNodeIdsForInstances: RDD[Array[Int]] = null
+
+  // To keep track of the past checkpointed RDDs.
+  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
+  private var rddUpdateCount = 0
+
+  // If a checkpoint directory is given, and there's no prior checkpoint directory,
+  // then set the checkpoint directory with the given one.
+  if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
+    nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
+  }
+
+  /**
+   * Update the node index values in the cache.
+   * This updates the RDD and its lineage.
+   * TODO: Passing bin information to executors seems unnecessary and costly.
+   * @param data The RDD of training rows.
+   * @param nodeIdUpdaters A map of node index updaters.
+   *                       The keys are the indices of the nodes that we want to update.
+   * @param bins Bin information needed to find child node indices.
+   */
+  def updateNodeIndices(
+      data: RDD[BaggedPoint[TreePoint]],
+      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
+      bins: Array[Array[Bin]]): Unit = {
+    if (prevNodeIdsForInstances != null) {
+      // Unpersist the previous one if one exists.
+      prevNodeIdsForInstances.unpersist()
+    }
+
+    prevNodeIdsForInstances = nodeIdsForInstances
+    nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
+      dataPoint => {
+        var treeId = 0
+        while (treeId < nodeIdUpdaters.length) {
+          val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(dataPoint._2(treeId), null)
+          if (nodeIdUpdater != null) {
+            val newNodeIndex = nodeIdUpdater.updateNodeIndex(
+              binnedFeatures = dataPoint._1.datum.binnedFeatures,
+              bins = bins)
+            dataPoint._2(treeId) = newNodeIndex
+          }
+
+          treeId += 1
+        }
+
+        dataPoint._2
+      }
+    }
+
+    // Keep on persisting new ones.
+    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
+    rddUpdateCount += 1
+
+    // Handle checkpointing if the directory is not None.
+    if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
+      (rddUpdateCount % checkpointInterval) == 0) {
+      // Let's see if we can delete previous checkpoints.
+      var canDelete = true
+      while (checkpointQueue.size > 1 && canDelete) {
+        // We can delete the oldest checkpoint iff
+        // the next checkpoint actually exists in the file system.
+        if (checkpointQueue.get(1).get.getCheckpointFile != None) {
+          val old = checkpointQueue.dequeue()
+
+          // Since the old checkpoint is not deleted by Spark,
+          // we'll manually delete it here.
+          val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
+          fs.delete(new Path(old.getCheckpointFile.get), true)
+        } else {
+          canDelete = false
+        }
+      }
+
+      nodeIdsForInstances.checkpoint()
+      checkpointQueue.enqueue(nodeIdsForInstances)
+    }
+  }
+
+  /**
+   * Call this after training is finished to delete any remaining checkpoints.
+   */
+  def deleteAllCheckpoints(): Unit = {
+    while (checkpointQueue.size > 0) {
+      val old = checkpointQueue.dequeue()
+      if (old.getCheckpointFile != None) {
+        val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
+        fs.delete(new Path(old.getCheckpointFile.get), true)
+      }
+    }
+  }
+}
+
+@DeveloperApi
+private[tree] object NodeIdCache {
+  /**
+   * Initialize the node Id cache with initial node Id values.
+   * @param data The RDD of training rows.
+   * @param numTrees The number of trees for which we want to create the cache.
+   * @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
+   * @param checkpointInterval The checkpointing interval
+   *                           (how often the cache should be checkpointed).
+   * @param initVal The initial values in the cache.
+   * @return A node Id cache containing an RDD of initial root node indices.
+   */
+  def init(
+      data: RDD[BaggedPoint[TreePoint]],
+      numTrees: Int,
+      checkpointDir: Option[String],
+      checkpointInterval: Int,
+      initVal: Int = 1): NodeIdCache = {
+    new NodeIdCache(
+      data.map(_ => Array.fill[Int](numTrees)(initVal)),
+      checkpointDir,
+      checkpointInterval)
+  }
+}
diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
index 10c046e07f..73c4393c35 100644
--- a/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/RandomForestSuite.scala
@@ -34,18 +34,11 @@ import org.apache.spark.mllib.util.LocalSparkContext
 * Test suite for [[RandomForest]].
 */
 class RandomForestSuite extends FunSuite with LocalSparkContext {
-
-  test("Binary classification with continuous features:" +
-    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
-
+  def binaryClassificationTestWithContinuousFeatures(strategy: Strategy) {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
     val rdd = sc.parallelize(arr)
-    val categoricalFeaturesInfo = Map.empty[Int, Int]
     val numTrees = 1
 
-    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
-
     val rf = RandomForest.trainClassifier(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
     assert(rf.weakHypotheses.size === 1)
@@ -60,18 +53,27 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     assert(rfTree.toString == dt.toString)
   }
 
-  test("Regression with continuous features:" +
+  test("Binary classification with continuous features:" +
     " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+    binaryClassificationTestWithContinuousFeatures(strategy)
+  }
 
+  test("Binary classification with continuous features and node Id cache:" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+    binaryClassificationTestWithContinuousFeatures(strategy)
+  }
+
+  def regressionTestWithContinuousFeatures(strategy: Strategy) {
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures = 50, 1000)
     val rdd = sc.parallelize(arr)
-    val categoricalFeaturesInfo = Map.empty[Int, Int]
     val numTrees = 1
 
-    val strategy = new Strategy(algo = Regression, impurity = Variance,
-      maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
-      categoricalFeaturesInfo = categoricalFeaturesInfo)
-
     val rf = RandomForest.trainRegressor(rdd, strategy, numTrees = numTrees,
       featureSubsetStrategy = "auto", seed = 123)
     assert(rf.weakHypotheses.size === 1)
@@ -86,14 +88,28 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     assert(rfTree.toString == dt.toString)
   }
 
-  test("Binary classification with continuous features: subsampling features") {
+  test("Regression with continuous features:" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Regression, impurity = Variance,
+      maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+      categoricalFeaturesInfo = categoricalFeaturesInfo)
+    regressionTestWithContinuousFeatures(strategy)
+  }
+
+  test("Regression with continuous features and node Id cache:" +
+    " comparing DecisionTree vs. RandomForest(numTrees = 1)") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Regression, impurity = Variance,
+      maxDepth = 2, maxBins = 10, numClassesForClassification = 2,
+      categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+    regressionTestWithContinuousFeatures(strategy)
+  }
+
+  def binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy: Strategy) {
     val numFeatures = 50
     val arr = EnsembleTestHelper.generateOrderedLabeledPoints(numFeatures, 1000)
     val rdd = sc.parallelize(arr)
-    val categoricalFeaturesInfo = Map.empty[Int, Int]
-
-    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
-      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
 
     // Select feature subset for top nodes.  Return true if OK.
    def checkFeatureSubsetStrategy(
@@ -149,6 +165,20 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
     checkFeatureSubsetStrategy(numTrees = 2, "onethird", (numFeatures / 3.0).ceil.toInt)
   }
 
+  test("Binary classification with continuous features: subsampling features") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo)
+    binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+  }
+
+  test("Binary classification with continuous features and node Id cache: subsampling features") {
+    val categoricalFeaturesInfo = Map.empty[Int, Int]
+    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 2,
+      numClassesForClassification = 2, categoricalFeaturesInfo = categoricalFeaturesInfo, useNodeIdCache = true)
+    binaryClassificationTestWithContinuousFeaturesAndSubsampledFeatures(strategy)
+  }
+
   test("alternating categorical and continuous features with multiclass labels to test indexing") {
     val arr = new Array[LabeledPoint](4)
     arr(0) = new LabeledPoint(0.0, Vectors.dense(1.0, 0.0, 0.0, 3.0, 1.0))
@@ -164,7 +194,6 @@ class RandomForestSuite extends FunSuite with LocalSparkContext {
       featureSubsetStrategy = "sqrt", seed = 12345)
     EnsembleTestHelper.validateClassifier(model, arr, 1.0)
   }
-
 }
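
A note on what the cache buys. Without it, binSeqOp calls predictNodeIndex, which walks every tree from its root for every instance on every iteration, so the growing trees must be shipped to the executors with each job. With it, binSeqOpWithNodeIdCache reads the instance's current node Id for each tree directly out of the zipped RDD. The toy sketch below illustrates the difference; ToyNode and both helpers are illustrative stand-ins, not APIs from this patch.

    // Illustrative stand-in for MLlib's Node plus DecisionTree.predictNodeIndex.
    case class ToyNode(
        id: Int,
        splitFeature: Int,
        threshold: Double,
        left: Option[ToyNode],
        right: Option[ToyNode])

    // Without the cache: walk the tree from the root for every instance,
    // in every iteration (requires the tree itself on the executor).
    def nodeIndexByTraversal(node: ToyNode, features: Array[Double]): Int =
      if (node.left.isEmpty) {
        node.id
      } else if (features(node.splitFeature) <= node.threshold) {
        nodeIndexByTraversal(node.left.get, features)
      } else {
        nodeIndexByTraversal(node.right.get, features)
      }

    // With the cache: the row of node Ids travels zipped with the instance,
    // so the current node of tree `treeIndex` is a constant-time array read.
    def nodeIndexFromCache(cachedNodeIds: Array[Int], treeIndex: Int): Int =
      cachedNodeIds(treeIndex)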
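When a node is split, NodeIndexUpdater.updateNodeIndex moves each cached Id from the split node down to one of its children. Restated as a standalone sketch: for a continuous split the patch compares the upper bound of the instance's bin (bins(featureIndex)(binIndex).highSplit.threshold) against the split threshold, and for a categorical split it tests membership in split.categories. The child-index arithmetic below assumes the 1-based heap layout that Node.leftChildIndex and Node.rightChildIndex implement, where the root is node 1; all parameter names here are illustrative.

    // Routing rule of NodeIndexUpdater.updateNodeIndex, in isolation.
    // `binUpperBound` stands in for bins(featureIndex)(binIndex).highSplit.threshold.
    def childIndex(
        nodeIndex: Int,
        isContinuous: Boolean,
        binUpperBound: Double,
        threshold: Double,
        categories: Seq[Double],
        categoryValue: Double): Int = {
      val goLeft =
        if (isContinuous) binUpperBound <= threshold
        else categories.contains(categoryValue)
      if (goLeft) 2 * nodeIndex else 2 * nodeIndex + 1  // left = 2i, right = 2i + 1
    }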
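Because nodeIdsForInstances is derived from its previous generation at every split, its lineage grows with tree depth; the periodic checkpoint truncates that lineage, and the queue lets an old checkpoint file be deleted once a newer one has been materialized. The bookkeeping generalizes to any iteratively updated RDD. A condensed sketch of the pattern, with illustrative class and member names:

    import scala.collection.mutable
    import org.apache.hadoop.fs.{FileSystem, Path}
    import org.apache.spark.rdd.RDD
    import org.apache.spark.storage.StorageLevel

    // Condensed version of NodeIdCache's persist/checkpoint rotation.
    class RotatingCheckpointer[T](checkpointInterval: Int) {
      private val checkpointQueue = mutable.Queue[RDD[T]]()
      private var updateCount = 0

      // Call once per generation, after deriving `next` from `prev`.
      def update(prev: RDD[T], next: RDD[T]): Unit = {
        next.persist(StorageLevel.MEMORY_AND_DISK)
        if (prev != null) prev.unpersist()
        updateCount += 1
        if (next.sparkContext.getCheckpointDir.nonEmpty &&
            updateCount % checkpointInterval == 0) {
          // Delete the oldest checkpoint only once the next-oldest has been
          // written out, so recovery always has a materialized checkpoint.
          while (checkpointQueue.size > 1 &&
              checkpointQueue(1).getCheckpointFile.nonEmpty) {
            val old = checkpointQueue.dequeue()
            old.getCheckpointFile.foreach { file =>
              val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
              fs.delete(new Path(file), true)
            }
          }
          next.checkpoint()
          checkpointQueue.enqueue(next)
        }
      }
    }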
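For callers, the cache is opt-in through the three new Strategy fields, and nothing else changes. A minimal usage sketch, assuming an active SparkContext; `trainingData` (an RDD[LabeledPoint]) and the checkpoint path are placeholders, and the train call mirrors the test suite above:

    import org.apache.spark.mllib.tree.RandomForest
    import org.apache.spark.mllib.tree.configuration.Algo.Classification
    import org.apache.spark.mllib.tree.configuration.Strategy
    import org.apache.spark.mllib.tree.impurity.Gini

    // `trainingData` is assumed to be an RDD[LabeledPoint].
    val strategy = new Strategy(algo = Classification, impurity = Gini, maxDepth = 5,
      numClassesForClassification = 2, categoricalFeaturesInfo = Map.empty[Int, Int],
      useNodeIdCache = true,                     // enable the node Id cache
      checkpointDir = Some("/tmp/nodeIdCache"),  // placeholder directory
      checkpointInterval = 10)                   // checkpoint the cache every 10 updates

    val model = RandomForest.trainClassifier(trainingData, strategy, numTrees = 10,
      featureSubsetStrategy = "auto", seed = 123)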