From 4fc35e6f5c590feb47cbcb5b1136f2e985677b3f Mon Sep 17 00:00:00 2001 From: sethah Date: Fri, 1 Apr 2016 21:23:35 -0700 Subject: [SPARK-14308][ML][MLLIB] Remove unused mllib tree classes and move private classes to ML ## What changes were proposed in this pull request? Decision tree helper classes will be migrated to ML. This patch moves those internal classes that are not part of the public API and removes ones that are no longer used, after [SPARK-12183](https://github.com/apache/spark/pull/11855). No functional changes are made. Details: * Bin.scala is removed as the ML implementation does not require bins * mllib NodeIdCache is removed. It was only used by the mllib implementation previously, which no longer exists * mllib TreePoint is removed. It was only used by the mllib implementation previously, which no longer exists * BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, BaggedPointSuite and TimeTracker are all moved to ML. ## How was this patch tested? No functional changes are made. Existing unit tests ensure behavior is unchanged. Author: sethah Closes #12097 from sethah/cleanup_mllib_tree. --- .../apache/spark/ml/tree/impl/BaggedPoint.scala | 125 ++++++++++++ .../spark/ml/tree/impl/DTStatsAggregator.scala | 181 +++++++++++++++++ .../spark/ml/tree/impl/DecisionTreeMetadata.scala | 217 +++++++++++++++++++++ .../spark/ml/tree/impl/GradientBoostedTrees.scala | 1 - .../apache/spark/ml/tree/impl/NodeIdCache.scala | 1 - .../apache/spark/ml/tree/impl/RandomForest.scala | 4 +- .../apache/spark/ml/tree/impl/TimeTracker.scala | 70 +++++++ .../org/apache/spark/ml/tree/impl/TreePoint.scala | 1 - .../spark/mllib/tree/GradientBoostedTrees.scala | 3 +- .../apache/spark/mllib/tree/impl/BaggedPoint.scala | 125 ------------ .../spark/mllib/tree/impl/DTStatsAggregator.scala | 178 ----------------- .../mllib/tree/impl/DecisionTreeMetadata.scala | 217 --------------------- .../apache/spark/mllib/tree/impl/NodeIdCache.scala | 195 ------------------ .../apache/spark/mllib/tree/impl/TimeTracker.scala | 70 ------- .../apache/spark/mllib/tree/impl/TreePoint.scala | 150 -------------- .../apache/spark/mllib/tree/impurity/Entropy.scala | 2 +- .../apache/spark/mllib/tree/impurity/Gini.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 2 +- .../org/apache/spark/mllib/tree/model/Bin.scala | 47 ----- .../spark/ml/tree/impl/BaggedPointSuite.scala | 99 ++++++++++ .../spark/ml/tree/impl/RandomForestSuite.scala | 1 - .../spark/mllib/tree/DecisionTreeSuite.scala | 2 +- .../spark/mllib/tree/impl/BaggedPointSuite.scala | 99 ---------- 23 files changed, 699 insertions(+), 1093 deletions(-) create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala create mode 100644 mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala delete mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala delete 
mode 100644 mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala create mode 100644 mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala delete mode 100644 mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala (limited to 'mllib') diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala new file mode 100644 index 0000000000..4e372702f0 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/BaggedPoint.scala @@ -0,0 +1,125 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.commons.math3.distribution.PoissonDistribution + +import org.apache.spark.rdd.RDD +import org.apache.spark.util.Utils +import org.apache.spark.util.random.XORShiftRandom + +/** + * Internal representation of a datapoint which belongs to several subsamples of the same dataset, + * particularly for bagging (e.g., for random forests). + * + * This holds one instance, as well as an array of weights which represent the (weighted) + * number of times which this instance appears in each subsamplingRate. + * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that + * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. + * + * @param datum Data instance + * @param subsampleWeights Weight of this instance in each subsampled dataset. + * + * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted + * dataset support, update. (We store subsampleWeights as Double for this future extension.) + */ +private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) + extends Serializable + +private[spark] object BaggedPoint { + + /** + * Convert an input dataset into its BaggedPoint representation, + * choosing subsamplingRate counts for each instance. + * Each subsamplingRate has the same number of instances as the original dataset, + * and is created by subsampling without replacement. + * @param input Input dataset. + * @param subsamplingRate Fraction of the training data used for learning decision tree. + * @param numSubsamples Number of subsamples of this RDD to take. + * @param withReplacement Sampling with/without replacement. + * @param seed Random seed. + * @return BaggedPoint dataset representation. 
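A minimal usage sketch for the converter documented above; `trainingData` and the parameter values are assumptions for illustration, not taken from this patch:

    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.rdd.RDD

    // Bag a training set into 3 subsamples at a 0.6 rate, without replacement,
    // using a fixed seed so the result is reproducible across runs.
    def bagExample(trainingData: RDD[LabeledPoint]): RDD[BaggedPoint[LabeledPoint]] =
      BaggedPoint.convertToBaggedRDD(
        input = trainingData,
        subsamplingRate = 0.6,
        numSubsamples = 3,
        withReplacement = false,
        seed = 42L)
    // Each resulting element carries 3 weights, e.g. Array(1.0, 0.0, 1.0), indicating
    // whether the instance appears in each of the 3 subsamples.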
+ */ + def convertToBaggedRDD[Datum] ( + input: RDD[Datum], + subsamplingRate: Double, + numSubsamples: Int, + withReplacement: Boolean, + seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = { + if (withReplacement) { + convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) + } else { + if (numSubsamples == 1 && subsamplingRate == 1.0) { + convertToBaggedRDDWithoutSampling(input) + } else { + convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) + } + } + } + + private def convertToBaggedRDDSamplingWithoutReplacement[Datum] ( + input: RDD[Datum], + subsamplingRate: Double, + numSubsamples: Int, + seed: Long): RDD[BaggedPoint[Datum]] = { + input.mapPartitionsWithIndex { (partitionIndex, instances) => + // Use random seed = seed + partitionIndex + 1 to make generation reproducible. + val rng = new XORShiftRandom + rng.setSeed(seed + partitionIndex + 1) + instances.map { instance => + val subsampleWeights = new Array[Double](numSubsamples) + var subsampleIndex = 0 + while (subsampleIndex < numSubsamples) { + val x = rng.nextDouble() + subsampleWeights(subsampleIndex) = { + if (x < subsamplingRate) 1.0 else 0.0 + } + subsampleIndex += 1 + } + new BaggedPoint(instance, subsampleWeights) + } + } + } + + private def convertToBaggedRDDSamplingWithReplacement[Datum] ( + input: RDD[Datum], + subsample: Double, + numSubsamples: Int, + seed: Long): RDD[BaggedPoint[Datum]] = { + input.mapPartitionsWithIndex { (partitionIndex, instances) => + // Use random seed = seed + partitionIndex + 1 to make generation reproducible. + val poisson = new PoissonDistribution(subsample) + poisson.reseedRandomGenerator(seed + partitionIndex + 1) + instances.map { instance => + val subsampleWeights = new Array[Double](numSubsamples) + var subsampleIndex = 0 + while (subsampleIndex < numSubsamples) { + subsampleWeights(subsampleIndex) = poisson.sample() + subsampleIndex += 1 + } + new BaggedPoint(instance, subsampleWeights) + } + } + } + + private def convertToBaggedRDDWithoutSampling[Datum] ( + input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { + input.map(datum => new BaggedPoint(datum, Array(1.0))) + } + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala new file mode 100644 index 0000000000..61091bb803 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DTStatsAggregator.scala @@ -0,0 +1,181 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.mllib.tree.impurity._ + + + +/** + * DecisionTree statistics aggregator for a node. 
+ * This holds a flat array of statistics for a set of (features, bins) + * and helps with indexing. + * This class is abstract to support learning with and without feature subsampling. + */ +private[spark] class DTStatsAggregator( + val metadata: DecisionTreeMetadata, + featureSubset: Option[Array[Int]]) extends Serializable { + + /** + * [[ImpurityAggregator]] instance specifying the impurity type. + */ + val impurityAggregator: ImpurityAggregator = metadata.impurity match { + case Gini => new GiniAggregator(metadata.numClasses) + case Entropy => new EntropyAggregator(metadata.numClasses) + case Variance => new VarianceAggregator() + case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") + } + + /** + * Number of elements (Double values) used for the sufficient statistics of each bin. + */ + private val statsSize: Int = impurityAggregator.statsSize + + /** + * Number of bins for each feature. This is indexed by the feature index. + */ + private val numBins: Array[Int] = { + if (featureSubset.isDefined) { + featureSubset.get.map(metadata.numBins(_)) + } else { + metadata.numBins + } + } + + /** + * Offset for each feature for calculating indices into the [[allStats]] array. + */ + private val featureOffsets: Array[Int] = { + numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) + } + + /** + * Total number of elements stored in this aggregator + */ + private val allStatsSize: Int = featureOffsets.last + + /** + * Flat array of elements. + * Index for start of stats for a (feature, bin) is: + * index = featureOffsets(featureIndex) + binIndex * statsSize + */ + private val allStats: Array[Double] = new Array[Double](allStatsSize) + + /** + * Array of parent node sufficient stats. + * + * Note: this is necessary because stats for the parent node are not available + * on the first iteration of tree learning. + */ + private val parentStats: Array[Double] = new Array[Double](statsSize) + + /** + * Get an [[ImpurityCalculator]] for a given (node, feature, bin). + * + * @param featureOffset This is a pre-computed (node, feature) offset + * from [[getFeatureOffset]]. + */ + def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { + impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) + } + + /** + * Get an [[ImpurityCalculator]] for the parent node. + */ + def getParentImpurityCalculator(): ImpurityCalculator = { + impurityAggregator.getCalculator(parentStats, 0) + } + + /** + * Update the stats for a given (feature, bin) for ordered features, using the given label. + */ + def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { + val i = featureOffsets(featureIndex) + binIndex * statsSize + impurityAggregator.update(allStats, i, label, instanceWeight) + } + + /** + * Update the parent node stats using the given label. + */ + def updateParent(label: Double, instanceWeight: Double): Unit = { + impurityAggregator.update(parentStats, 0, label, instanceWeight) + } + + /** + * Faster version of [[update]]. + * Update the stats for a given (feature, bin), using the given label. + * + * @param featureOffset This is a pre-computed feature offset + * from [[getFeatureOffset]]. + */ + def featureUpdate( + featureOffset: Int, + binIndex: Int, + label: Double, + instanceWeight: Double): Unit = { + impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, + label, instanceWeight) + } + + /** + * Pre-compute feature offset for use with [[featureUpdate]]. 
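A worked example of the flat-array indexing described above; the sizes are invented for illustration:

    // Suppose each bin stores 3 Doubles of sufficient statistics (statsSize = 3)
    // and there are two features with 4 and 2 bins respectively.
    val statsSize = 3
    val numBins = Array(4, 2)
    val featureOffsets = numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins)
    // featureOffsets == Array(0, 12, 18), so allStats has length 18 and the stats
    // for (feature = 1, bin = 0) start at featureOffsets(1) + 0 * statsSize = 12,
    // occupying allStats(12), allStats(13) and allStats(14).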
+ * For ordered features only. + */ + def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) + + /** + * For a given feature, merge the stats for two bins. + * + * @param featureOffset This is a pre-computed feature offset + * from [[getFeatureOffset]]. + * @param binIndex The other bin is merged into this bin. + * @param otherBinIndex This bin is not modified. + */ + def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { + impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize, + featureOffset + otherBinIndex * statsSize) + } + + /** + * Merge this aggregator with another, and returns this aggregator. + * This method modifies this aggregator in-place. + */ + def merge(other: DTStatsAggregator): DTStatsAggregator = { + require(allStatsSize == other.allStatsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." + + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") + var i = 0 + // TODO: Test BLAS.axpy + while (i < allStatsSize) { + allStats(i) += other.allStats(i) + i += 1 + } + + require(statsSize == other.statsSize, + s"DTStatsAggregator.merge requires that both aggregators have the same length parent " + + s"stats vectors. This aggregator's parent stats are length $statsSize, " + + s"but the other is ${other.statsSize}.") + var j = 0 + while (j < statsSize) { + parentStats(j) += other.parentStats(j) + j += 1 + } + + this + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala new file mode 100644 index 0000000000..df8eb5d1f9 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -0,0 +1,217 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.rdd.RDD + +/** + * Learning and dataset metadata for DecisionTree. + * + * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. + * For regression: fixed at 0 (no meaning). + * @param maxBins Maximum number of bins, for all features. + * @param featureArity Map: categorical feature index --> arity. + * I.e., the feature takes values in {0, ..., arity - 1}. + * @param numBins Number of bins for each feature. 
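A short sketch of how this metadata is built and how bins relate to splits; `input` and `strategy` are assumed to be an existing `RDD[LabeledPoint]` and `Strategy`:

    // Build metadata for a 10-tree forest using sqrt feature subsampling.
    val metadata = DecisionTreeMetadata.buildMetadata(
      input, strategy, numTrees = 10, featureSubsetStrategy = "sqrt")
    // For a continuous (ordered) feature f: metadata.numSplits(f) == metadata.numBins(f) - 1.
    // For an unordered categorical feature f: metadata.numSplits(f) == metadata.numBins(f),
    // since unordered features use one bin per split.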
+ */ +private[spark] class DecisionTreeMetadata( + val numFeatures: Int, + val numExamples: Long, + val numClasses: Int, + val maxBins: Int, + val featureArity: Map[Int, Int], + val unorderedFeatures: Set[Int], + val numBins: Array[Int], + val impurity: Impurity, + val quantileStrategy: QuantileStrategy, + val maxDepth: Int, + val minInstancesPerNode: Int, + val minInfoGain: Double, + val numTrees: Int, + val numFeaturesPerNode: Int) extends Serializable { + + def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) + + def isClassification: Boolean = numClasses >= 2 + + def isMulticlass: Boolean = numClasses > 2 + + def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) + + def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) + + def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + + /** + * Number of splits for the given feature. + * For unordered features, there is 1 bin per split. + * For ordered features, there is 1 more bin than split. + */ + def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { + numBins(featureIndex) + } else { + numBins(featureIndex) - 1 + } + + + /** + * Set number of splits for a continuous feature. + * For a continuous feature, number of bins is number of splits plus 1. + */ + def setNumSplits(featureIndex: Int, numSplits: Int) { + require(isContinuous(featureIndex), + s"Only number of bin for a continuous feature can be set.") + numBins(featureIndex) = numSplits + 1 + } + + /** + * Indicates if feature subsampling is being used. + */ + def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode + +} + +private[spark] object DecisionTreeMetadata extends Logging { + + /** + * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. + * This computes which categorical features will be ordered vs. unordered, + * as well as the number of splits and bins for each feature. + */ + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String): DecisionTreeMetadata = { + + val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse { + throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " + + s"but was given by empty one.") + } + val numExamples = input.count() + val numClasses = strategy.algo match { + case Classification => strategy.numClasses + case Regression => 0 + } + + val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt + if (maxPossibleBins < strategy.maxBins) { + logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" + + s" (= number of training instances)") + } + + // We check the number of bins here against maxPossibleBins. + // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified + // based on the number of training examples. + if (strategy.categoricalFeaturesInfo.nonEmpty) { + val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + val maxCategory = + strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 + require(maxCategoriesPerFeature <= maxPossibleBins, + s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + + s"number of values in each categorical feature, but categorical feature $maxCategory " + + s"has $maxCategoriesPerFeature values. 
Considering remove this and other categorical " + + "features with a large number of values, or add more training examples.") + } + + val unorderedFeatures = new mutable.HashSet[Int]() + val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) + if (numClasses > 2) { + // Multiclass classification + val maxCategoriesForUnorderedFeature = + ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + // Hack: If a categorical feature has only 1 category, we treat it as continuous. + // TODO(SPARK-9957): Handle this properly by filtering out those features. + if (numCategories > 1) { + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. + // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) + } else { + numBins(featureIndex) = numCategories + } + } + } + } else { + // Binary classification or regression + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 + if (numCategories > 1) { + numBins(featureIndex) = numCategories + } + } + } + + // Set number of features to use per node (for random forests). + val _featureSubsetStrategy = featureSubsetStrategy match { + case "auto" => + if (numTrees == 1) { + "all" + } else { + if (strategy.algo == Classification) { + "sqrt" + } else { + "onethird" + } + } + case _ => featureSubsetStrategy + } + val numFeaturesPerNode: Int = _featureSubsetStrategy match { + case "all" => numFeatures + case "sqrt" => math.sqrt(numFeatures).ceil.toInt + case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) + case "onethird" => (numFeatures / 3.0).ceil.toInt + } + + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, + strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, + strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, + strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) + } + + /** + * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree. + */ + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy): DecisionTreeMetadata = { + buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") + } + + /** + * Given the arity of a categorical feature (arity = number of categories), + * return the number of bins for the feature if it is to be treated as an unordered feature. + * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; + * there are math.pow(2, arity - 1) - 1 such splits. + * Each split has 2 corresponding bins. 
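A quick arithmetic check of the formula implemented just below; arity = 4 is an assumed example value:

    // numUnorderedBins(4) == (1 << (4 - 1)) - 1 == 7,
    // i.e. there are 2^(4-1) - 1 = 7 ways to split 4 categories into two
    // disjoint, non-empty subsets, and hence 7 candidate splits.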
+ */ + def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1 + +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala index b37f4e891e..0749d93b7d 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/GradientBoostedTrees.scala @@ -25,7 +25,6 @@ import org.apache.spark.mllib.linalg.Vector import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo} import org.apache.spark.mllib.tree.configuration.{BoostingStrategy => OldBoostingStrategy} -import org.apache.spark.mllib.tree.impl.TimeTracker import org.apache.spark.mllib.tree.impurity.{Variance => OldVariance} import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index 2c8286766f..9d697a36b6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -26,7 +26,6 @@ import org.apache.hadoop.fs.{FileSystem, Path} import org.apache.spark.annotation.DeveloperApi import org.apache.spark.internal.Logging import org.apache.spark.ml.tree.{LearningNode, Split} -import org.apache.spark.mllib.tree.impl.BaggedPoint import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index cccf052b3e..7b1fd089f2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -28,8 +28,6 @@ import org.apache.spark.ml.regression.DecisionTreeRegressionModel import org.apache.spark.ml.tree._ import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata, DTStatsAggregator, - TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD @@ -330,7 +328,7 @@ private[spark] object RandomForest extends Logging { /** * Given a group of nodes, this finds the best split for each node. * - * @param input Training data: RDD of [[org.apache.spark.mllib.tree.impl.TreePoint]] + * @param input Training data: RDD of [[org.apache.spark.ml.tree.impl.TreePoint]] * @param metadata Learning and dataset metadata * @param topNodes Root node for each tree. Used for matching instances with nodes. * @param nodesForGroup Mapping: treeIndex --> nodes to be split in tree diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala new file mode 100644 index 0000000000..4cc250aa46 --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TimeTracker.scala @@ -0,0 +1,70 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. 
+ * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.collection.mutable.{HashMap => MutableHashMap} + +/** + * Time tracker implementation which holds labeled timers. + */ +private[spark] class TimeTracker extends Serializable { + + private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() + + /** + * Starts a new timer, or re-starts a stopped timer. + */ + def start(timerLabel: String): Unit = { + val currentTime = System.nanoTime() + if (starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + + s" timerLabel = $timerLabel before that timer was stopped.") + } + starts(timerLabel) = currentTime + } + + /** + * Stops a timer and returns the elapsed time in seconds. + */ + def stop(timerLabel: String): Double = { + val currentTime = System.nanoTime() + if (!starts.contains(timerLabel)) { + throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + + s" timerLabel = $timerLabel, but that timer was not started.") + } + val elapsed = currentTime - starts(timerLabel) + starts.remove(timerLabel) + if (totals.contains(timerLabel)) { + totals(timerLabel) += elapsed + } else { + totals(timerLabel) = elapsed + } + elapsed / 1e9 + } + + /** + * Print all timing results in seconds. 
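An illustrative usage sketch for the tracker; the timer labels and the printed message are assumptions:

    val timer = new TimeTracker()
    timer.start("total")
    timer.start("findSplits")
    // ... do the work being measured ...
    val findSplitsSeconds: Double = timer.stop("findSplits")   // elapsed time in seconds
    timer.stop("total")
    // toString (defined just below) prints each label's accumulated total in seconds.
    println(s"Internal timing:\n$timer")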
+ */ + override def toString: String = { + totals.map { case (label, elapsed) => + s" $label: ${elapsed / 1e9}" + }.mkString("\n") + } +} diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala index 9fa27e5e1f..3a2bf3c725 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/TreePoint.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.tree.impl import org.apache.spark.ml.tree.{ContinuousSplit, Split} import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.rdd.RDD diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala index d166dc7905..0f0c6b466d 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/GradientBoostedTrees.scala @@ -20,11 +20,11 @@ package org.apache.spark.mllib.tree import org.apache.spark.annotation.Since import org.apache.spark.api.java.JavaRDD import org.apache.spark.internal.Logging +import org.apache.spark.ml.tree.impl.TimeTracker import org.apache.spark.mllib.impl.PeriodicRDDCheckpointer import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.BoostingStrategy -import org.apache.spark.mllib.tree.impl.TimeTracker import org.apache.spark.mllib.tree.impurity.Variance import org.apache.spark.mllib.tree.model.{DecisionTreeModel, GradientBoostedTreesModel} import org.apache.spark.rdd.RDD @@ -165,6 +165,7 @@ object GradientBoostedTrees extends Logging { /** * Internal method for performing regression using trees as base learners. + * * @param input Training dataset. * @param validationInput Validation dataset, ignored if validate is set to false. * @param boostingStrategy Boosting parameters. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala deleted file mode 100644 index 572815df0b..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/BaggedPoint.scala +++ /dev/null @@ -1,125 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.tree.impl - -import org.apache.commons.math3.distribution.PoissonDistribution - -import org.apache.spark.rdd.RDD -import org.apache.spark.util.Utils -import org.apache.spark.util.random.XORShiftRandom - -/** - * Internal representation of a datapoint which belongs to several subsamples of the same dataset, - * particularly for bagging (e.g., for random forests). - * - * This holds one instance, as well as an array of weights which represent the (weighted) - * number of times which this instance appears in each subsamplingRate. - * E.g., (datum, [1, 0, 4]) indicates that there are 3 subsamples of the dataset and that - * this datum has 1 copy, 0 copies, and 4 copies in the 3 subsamples, respectively. - * - * @param datum Data instance - * @param subsampleWeights Weight of this instance in each subsampled dataset. - * - * TODO: This does not currently support (Double) weighted instances. Once MLlib has weighted - * dataset support, update. (We store subsampleWeights as Double for this future extension.) - */ -private[spark] class BaggedPoint[Datum](val datum: Datum, val subsampleWeights: Array[Double]) - extends Serializable - -private[spark] object BaggedPoint { - - /** - * Convert an input dataset into its BaggedPoint representation, - * choosing subsamplingRate counts for each instance. - * Each subsamplingRate has the same number of instances as the original dataset, - * and is created by subsampling without replacement. - * @param input Input dataset. - * @param subsamplingRate Fraction of the training data used for learning decision tree. - * @param numSubsamples Number of subsamples of this RDD to take. - * @param withReplacement Sampling with/without replacement. - * @param seed Random seed. - * @return BaggedPoint dataset representation. - */ - def convertToBaggedRDD[Datum] ( - input: RDD[Datum], - subsamplingRate: Double, - numSubsamples: Int, - withReplacement: Boolean, - seed: Long = Utils.random.nextLong()): RDD[BaggedPoint[Datum]] = { - if (withReplacement) { - convertToBaggedRDDSamplingWithReplacement(input, subsamplingRate, numSubsamples, seed) - } else { - if (numSubsamples == 1 && subsamplingRate == 1.0) { - convertToBaggedRDDWithoutSampling(input) - } else { - convertToBaggedRDDSamplingWithoutReplacement(input, subsamplingRate, numSubsamples, seed) - } - } - } - - private def convertToBaggedRDDSamplingWithoutReplacement[Datum] ( - input: RDD[Datum], - subsamplingRate: Double, - numSubsamples: Int, - seed: Long): RDD[BaggedPoint[Datum]] = { - input.mapPartitionsWithIndex { (partitionIndex, instances) => - // Use random seed = seed + partitionIndex + 1 to make generation reproducible. - val rng = new XORShiftRandom - rng.setSeed(seed + partitionIndex + 1) - instances.map { instance => - val subsampleWeights = new Array[Double](numSubsamples) - var subsampleIndex = 0 - while (subsampleIndex < numSubsamples) { - val x = rng.nextDouble() - subsampleWeights(subsampleIndex) = { - if (x < subsamplingRate) 1.0 else 0.0 - } - subsampleIndex += 1 - } - new BaggedPoint(instance, subsampleWeights) - } - } - } - - private def convertToBaggedRDDSamplingWithReplacement[Datum] ( - input: RDD[Datum], - subsample: Double, - numSubsamples: Int, - seed: Long): RDD[BaggedPoint[Datum]] = { - input.mapPartitionsWithIndex { (partitionIndex, instances) => - // Use random seed = seed + partitionIndex + 1 to make generation reproducible. 
- val poisson = new PoissonDistribution(subsample) - poisson.reseedRandomGenerator(seed + partitionIndex + 1) - instances.map { instance => - val subsampleWeights = new Array[Double](numSubsamples) - var subsampleIndex = 0 - while (subsampleIndex < numSubsamples) { - subsampleWeights(subsampleIndex) = poisson.sample() - subsampleIndex += 1 - } - new BaggedPoint(instance, subsampleWeights) - } - } - } - - private def convertToBaggedRDDWithoutSampling[Datum] ( - input: RDD[Datum]): RDD[BaggedPoint[Datum]] = { - input.map(datum => new BaggedPoint(datum, Array(1.0))) - } - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala deleted file mode 100644 index c745e9f8db..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DTStatsAggregator.scala +++ /dev/null @@ -1,178 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import org.apache.spark.mllib.tree.impurity._ - - - -/** - * DecisionTree statistics aggregator for a node. - * This holds a flat array of statistics for a set of (features, bins) - * and helps with indexing. - * This class is abstract to support learning with and without feature subsampling. - */ -private[spark] class DTStatsAggregator( - val metadata: DecisionTreeMetadata, - featureSubset: Option[Array[Int]]) extends Serializable { - - /** - * [[ImpurityAggregator]] instance specifying the impurity type. - */ - val impurityAggregator: ImpurityAggregator = metadata.impurity match { - case Gini => new GiniAggregator(metadata.numClasses) - case Entropy => new EntropyAggregator(metadata.numClasses) - case Variance => new VarianceAggregator() - case _ => throw new IllegalArgumentException(s"Bad impurity parameter: ${metadata.impurity}") - } - - /** - * Number of elements (Double values) used for the sufficient statistics of each bin. - */ - private val statsSize: Int = impurityAggregator.statsSize - - /** - * Number of bins for each feature. This is indexed by the feature index. - */ - private val numBins: Array[Int] = { - if (featureSubset.isDefined) { - featureSubset.get.map(metadata.numBins(_)) - } else { - metadata.numBins - } - } - - /** - * Offset for each feature for calculating indices into the [[allStats]] array. - */ - private val featureOffsets: Array[Int] = { - numBins.scanLeft(0)((total, nBins) => total + statsSize * nBins) - } - - /** - * Total number of elements stored in this aggregator - */ - private val allStatsSize: Int = featureOffsets.last - - /** - * Flat array of elements. 
- * Index for start of stats for a (feature, bin) is: - * index = featureOffsets(featureIndex) + binIndex * statsSize - */ - private val allStats: Array[Double] = new Array[Double](allStatsSize) - - /** - * Array of parent node sufficient stats. - * - * Note: this is necessary because stats for the parent node are not available - * on the first iteration of tree learning. - */ - private val parentStats: Array[Double] = new Array[Double](statsSize) - - /** - * Get an [[ImpurityCalculator]] for a given (node, feature, bin). - * @param featureOffset This is a pre-computed (node, feature) offset - * from [[getFeatureOffset]]. - */ - def getImpurityCalculator(featureOffset: Int, binIndex: Int): ImpurityCalculator = { - impurityAggregator.getCalculator(allStats, featureOffset + binIndex * statsSize) - } - - /** - * Get an [[ImpurityCalculator]] for the parent node. - */ - def getParentImpurityCalculator(): ImpurityCalculator = { - impurityAggregator.getCalculator(parentStats, 0) - } - - /** - * Update the stats for a given (feature, bin) for ordered features, using the given label. - */ - def update(featureIndex: Int, binIndex: Int, label: Double, instanceWeight: Double): Unit = { - val i = featureOffsets(featureIndex) + binIndex * statsSize - impurityAggregator.update(allStats, i, label, instanceWeight) - } - - /** - * Update the parent node stats using the given label. - */ - def updateParent(label: Double, instanceWeight: Double): Unit = { - impurityAggregator.update(parentStats, 0, label, instanceWeight) - } - - /** - * Faster version of [[update]]. - * Update the stats for a given (feature, bin), using the given label. - * @param featureOffset This is a pre-computed feature offset - * from [[getFeatureOffset]]. - */ - def featureUpdate( - featureOffset: Int, - binIndex: Int, - label: Double, - instanceWeight: Double): Unit = { - impurityAggregator.update(allStats, featureOffset + binIndex * statsSize, - label, instanceWeight) - } - - /** - * Pre-compute feature offset for use with [[featureUpdate]]. - * For ordered features only. - */ - def getFeatureOffset(featureIndex: Int): Int = featureOffsets(featureIndex) - - /** - * For a given feature, merge the stats for two bins. - * @param featureOffset This is a pre-computed feature offset - * from [[getFeatureOffset]]. - * @param binIndex The other bin is merged into this bin. - * @param otherBinIndex This bin is not modified. - */ - def mergeForFeature(featureOffset: Int, binIndex: Int, otherBinIndex: Int): Unit = { - impurityAggregator.merge(allStats, featureOffset + binIndex * statsSize, - featureOffset + otherBinIndex * statsSize) - } - - /** - * Merge this aggregator with another, and returns this aggregator. - * This method modifies this aggregator in-place. - */ - def merge(other: DTStatsAggregator): DTStatsAggregator = { - require(allStatsSize == other.allStatsSize, - s"DTStatsAggregator.merge requires that both aggregators have the same length stats vectors." - + s" This aggregator is of length $allStatsSize, but the other is ${other.allStatsSize}.") - var i = 0 - // TODO: Test BLAS.axpy - while (i < allStatsSize) { - allStats(i) += other.allStats(i) - i += 1 - } - - require(statsSize == other.statsSize, - s"DTStatsAggregator.merge requires that both aggregators have the same length parent " + - s"stats vectors. 
This aggregator's parent stats are length $statsSize, " + - s"but the other is ${other.statsSize}.") - var j = 0 - while (j < statsSize) { - parentStats(j) += other.parentStats(j) - j += 1 - } - - this - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala deleted file mode 100644 index 4f27dc44ef..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/DecisionTreeMetadata.scala +++ /dev/null @@ -1,217 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import scala.collection.mutable - -import org.apache.spark.internal.Logging -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.configuration.Algo._ -import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ -import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impurity.Impurity -import org.apache.spark.rdd.RDD - -/** - * Learning and dataset metadata for DecisionTree. - * - * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. - * For regression: fixed at 0 (no meaning). - * @param maxBins Maximum number of bins, for all features. - * @param featureArity Map: categorical feature index --> arity. - * I.e., the feature takes values in {0, ..., arity - 1}. - * @param numBins Number of bins for each feature. - */ -private[spark] class DecisionTreeMetadata( - val numFeatures: Int, - val numExamples: Long, - val numClasses: Int, - val maxBins: Int, - val featureArity: Map[Int, Int], - val unorderedFeatures: Set[Int], - val numBins: Array[Int], - val impurity: Impurity, - val quantileStrategy: QuantileStrategy, - val maxDepth: Int, - val minInstancesPerNode: Int, - val minInfoGain: Double, - val numTrees: Int, - val numFeaturesPerNode: Int) extends Serializable { - - def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) - - def isClassification: Boolean = numClasses >= 2 - - def isMulticlass: Boolean = numClasses > 2 - - def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) - - def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) - - def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) - - /** - * Number of splits for the given feature. - * For unordered features, there is 1 bin per split. - * For ordered features, there is 1 more bin than split. - */ - def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { - numBins(featureIndex) - } else { - numBins(featureIndex) - 1 - } - - - /** - * Set number of splits for a continuous feature. 
- * For a continuous feature, number of bins is number of splits plus 1. - */ - def setNumSplits(featureIndex: Int, numSplits: Int) { - require(isContinuous(featureIndex), - s"Only number of bin for a continuous feature can be set.") - numBins(featureIndex) = numSplits + 1 - } - - /** - * Indicates if feature subsampling is being used. - */ - def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode - -} - -private[spark] object DecisionTreeMetadata extends Logging { - - /** - * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. - * This computes which categorical features will be ordered vs. unordered, - * as well as the number of splits and bins for each feature. - */ - def buildMetadata( - input: RDD[LabeledPoint], - strategy: Strategy, - numTrees: Int, - featureSubsetStrategy: String): DecisionTreeMetadata = { - - val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse { - throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " + - s"but was given by empty one.") - } - val numExamples = input.count() - val numClasses = strategy.algo match { - case Classification => strategy.numClasses - case Regression => 0 - } - - val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt - if (maxPossibleBins < strategy.maxBins) { - logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" + - s" (= number of training instances)") - } - - // We check the number of bins here against maxPossibleBins. - // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified - // based on the number of training examples. - if (strategy.categoricalFeaturesInfo.nonEmpty) { - val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max - val maxCategory = - strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 - require(maxCategoriesPerFeature <= maxPossibleBins, - s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + - s"number of values in each categorical feature, but categorical feature $maxCategory " + - s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " + - "features with a large number of values, or add more training examples.") - } - - val unorderedFeatures = new mutable.HashSet[Int]() - val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) - if (numClasses > 2) { - // Multiclass classification - val maxCategoriesForUnorderedFeature = - ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt - strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // Hack: If a categorical feature has only 1 category, we treat it as continuous. - // TODO(SPARK-9957): Handle this properly by filtering out those features. - if (numCategories > 1) { - // Decide if some categorical features should be treated as unordered features, - // which require 2 * ((1 << numCategories - 1) - 1) bins. - // We do this check with log values to prevent overflows in case numCategories is large. 
- // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins - if (numCategories <= maxCategoriesForUnorderedFeature) { - unorderedFeatures.add(featureIndex) - numBins(featureIndex) = numUnorderedBins(numCategories) - } else { - numBins(featureIndex) = numCategories - } - } - } - } else { - // Binary classification or regression - strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => - // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 - if (numCategories > 1) { - numBins(featureIndex) = numCategories - } - } - } - - // Set number of features to use per node (for random forests). - val _featureSubsetStrategy = featureSubsetStrategy match { - case "auto" => - if (numTrees == 1) { - "all" - } else { - if (strategy.algo == Classification) { - "sqrt" - } else { - "onethird" - } - } - case _ => featureSubsetStrategy - } - val numFeaturesPerNode: Int = _featureSubsetStrategy match { - case "all" => numFeatures - case "sqrt" => math.sqrt(numFeatures).ceil.toInt - case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) - case "onethird" => (numFeatures / 3.0).ceil.toInt - } - - new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, - strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, - strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, - strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) - } - - /** - * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree. - */ - def buildMetadata( - input: RDD[LabeledPoint], - strategy: Strategy): DecisionTreeMetadata = { - buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") - } - - /** - * Given the arity of a categorical feature (arity = number of categories), - * return the number of bins for the feature if it is to be treated as an unordered feature. - * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; - * there are math.pow(2, arity - 1) - 1 such splits. - * Each split has 2 corresponding bins. - */ - def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1 - -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala deleted file mode 100644 index dc7e969f7b..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala +++ /dev/null @@ -1,195 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark.mllib.tree.impl - -import scala.collection.mutable - -import org.apache.hadoop.fs.{FileSystem, Path} - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.mllib.tree.configuration.FeatureType._ -import org.apache.spark.mllib.tree.model.{Bin, Node, Split} -import org.apache.spark.rdd.RDD -import org.apache.spark.storage.StorageLevel - -/** - * :: DeveloperApi :: - * This is used by the node id cache to find the child id that a data point would belong to. - * @param split Split information. - * @param nodeIndex The current node index of a data point that this will update. - */ -@DeveloperApi -private[tree] case class NodeIndexUpdater( - split: Split, - nodeIndex: Int) { - /** - * Determine a child node index based on the feature value and the split. - * @param binnedFeatures Binned feature values. - * @param bins Bin information to convert the bin indices to approximate feature values. - * @return Child node index to update to. - */ - def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = { - if (split.featureType == Continuous) { - val featureIndex = split.feature - val binIndex = binnedFeatures(featureIndex) - val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold - if (featureValueUpperBound <= split.threshold) { - Node.leftChildIndex(nodeIndex) - } else { - Node.rightChildIndex(nodeIndex) - } - } else { - if (split.categories.contains(binnedFeatures(split.feature).toDouble)) { - Node.leftChildIndex(nodeIndex) - } else { - Node.rightChildIndex(nodeIndex) - } - } - } -} - -/** - * :: DeveloperApi :: - * A given TreePoint would belong to a particular node per tree. - * Each row in the nodeIdsForInstances RDD is an array over trees of the node index - * in each tree. Initially, values should all be 1 for root node. - * The nodeIdsForInstances RDD needs to be updated at each iteration. - * @param nodeIdsForInstances The initial values in the cache - * (should be an Array of all 1's (meaning the root nodes)). - * @param checkpointInterval The checkpointing interval - * (how often should the cache be checkpointed.). - */ -@DeveloperApi -private[spark] class NodeIdCache( - var nodeIdsForInstances: RDD[Array[Int]], - val checkpointInterval: Int) { - - // Keep a reference to a previous node Ids for instances. - // Because we will keep on re-persisting updated node Ids, - // we want to unpersist the previous RDD. - private var prevNodeIdsForInstances: RDD[Array[Int]] = null - - // To keep track of the past checkpointed RDDs. - private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]() - private var rddUpdateCount = 0 - - /** - * Update the node index values in the cache. - * This updates the RDD and its lineage. - * TODO: Passing bin information to executors seems unnecessary and costly. - * @param data The RDD of training rows. - * @param nodeIdUpdaters A map of node index updaters. - * The key is the indices of nodes that we want to update. - * @param bins Bin information needed to find child node indices. - */ - def updateNodeIndices( - data: RDD[BaggedPoint[TreePoint]], - nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]], - bins: Array[Array[Bin]]): Unit = { - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. 
- prevNodeIdsForInstances.unpersist() - } - - prevNodeIdsForInstances = nodeIdsForInstances - nodeIdsForInstances = data.zip(nodeIdsForInstances).map { - case (point, node) => { - var treeId = 0 - while (treeId < nodeIdUpdaters.length) { - val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null) - if (nodeIdUpdater != null) { - val newNodeIndex = nodeIdUpdater.updateNodeIndex( - binnedFeatures = point.datum.binnedFeatures, - bins = bins) - node(treeId) = newNodeIndex - } - - treeId += 1 - } - - node - } - } - - // Keep on persisting new ones. - nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK) - rddUpdateCount += 1 - - // Handle checkpointing if the directory is not None. - if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty && - (rddUpdateCount % checkpointInterval) == 0) { - // Let's see if we can delete previous checkpoints. - var canDelete = true - while (checkpointQueue.size > 1 && canDelete) { - // We can delete the oldest checkpoint iff - // the next checkpoint actually exists in the file system. - if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) { - val old = checkpointQueue.dequeue() - - // Since the old checkpoint is not deleted by Spark, - // we'll manually delete it here. - val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) - fs.delete(new Path(old.getCheckpointFile.get), true) - } else { - canDelete = false - } - } - - nodeIdsForInstances.checkpoint() - checkpointQueue.enqueue(nodeIdsForInstances) - } - } - - /** - * Call this after training is finished to delete any remaining checkpoints. - */ - def deleteAllCheckpoints(): Unit = { - while (checkpointQueue.nonEmpty) { - val old = checkpointQueue.dequeue() - for (checkpointFile <- old.getCheckpointFile) { - val fs = FileSystem.get(old.sparkContext.hadoopConfiguration) - fs.delete(new Path(checkpointFile), true) - } - } - if (prevNodeIdsForInstances != null) { - // Unpersist the previous one if one exists. - prevNodeIdsForInstances.unpersist() - } - } -} - -private[spark] object NodeIdCache { - /** - * Initialize the node Id cache with initial node Id values. - * @param data The RDD of training rows. - * @param numTrees The number of trees that we want to create cache for. - * @param checkpointInterval The checkpointing interval - * (how often should the cache be checkpointed.). - * @param initVal The initial values in the cache. - * @return A node Id cache containing an RDD of initial root node Indices. - */ - def init( - data: RDD[BaggedPoint[TreePoint]], - numTrees: Int, - checkpointInterval: Int, - initVal: Int = 1): NodeIdCache = { - new NodeIdCache( - data.map(_ => Array.fill[Int](numTrees)(initVal)), - checkpointInterval) - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala deleted file mode 100644 index 70afaa162b..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TimeTracker.scala +++ /dev/null @@ -1,70 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import scala.collection.mutable.{HashMap => MutableHashMap} - -/** - * Time tracker implementation which holds labeled timers. - */ -private[spark] class TimeTracker extends Serializable { - - private val starts: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() - - private val totals: MutableHashMap[String, Long] = new MutableHashMap[String, Long]() - - /** - * Starts a new timer, or re-starts a stopped timer. - */ - def start(timerLabel: String): Unit = { - val currentTime = System.nanoTime() - if (starts.contains(timerLabel)) { - throw new RuntimeException(s"TimeTracker.start(timerLabel) called again on" + - s" timerLabel = $timerLabel before that timer was stopped.") - } - starts(timerLabel) = currentTime - } - - /** - * Stops a timer and returns the elapsed time in seconds. - */ - def stop(timerLabel: String): Double = { - val currentTime = System.nanoTime() - if (!starts.contains(timerLabel)) { - throw new RuntimeException(s"TimeTracker.stop(timerLabel) called on" + - s" timerLabel = $timerLabel, but that timer was not started.") - } - val elapsed = currentTime - starts(timerLabel) - starts.remove(timerLabel) - if (totals.contains(timerLabel)) { - totals(timerLabel) += elapsed - } else { - totals(timerLabel) = elapsed - } - elapsed / 1e9 - } - - /** - * Print all timing results in seconds. - */ - override def toString: String = { - totals.map { case (label, elapsed) => - s" $label: ${elapsed / 1e9}" - }.mkString("\n") - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala deleted file mode 100644 index 21919d69a3..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/TreePoint.scala +++ /dev/null @@ -1,150 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import org.apache.spark.mllib.regression.LabeledPoint -import org.apache.spark.mllib.tree.model.Bin -import org.apache.spark.rdd.RDD - - -/** - * Internal representation of LabeledPoint for DecisionTree. - * This bins feature values based on a subsampled of data as follows: - * (a) Continuous features are binned into ranges. - * (b) Unordered categorical features are binned based on subsets of feature values. 
- * "Unordered categorical features" are categorical features with low arity used in - * multiclass classification. - * (c) Ordered categorical features are binned based on feature values. - * "Ordered categorical features" are categorical features with high arity, - * or any categorical feature used in regression or binary classification. - * - * @param label Label from LabeledPoint - * @param binnedFeatures Binned feature values. - * Same length as LabeledPoint.features, but values are bin indices. - */ -private[spark] class TreePoint(val label: Double, val binnedFeatures: Array[Int]) - extends Serializable { -} - -private[spark] object TreePoint { - - /** - * Convert an input dataset into its TreePoint representation, - * binning feature values in preparation for DecisionTree training. - * @param input Input dataset. - * @param bins Bins for features, of size (numFeatures, numBins). - * @param metadata Learning and dataset metadata - * @return TreePoint dataset representation - */ - def convertToTreeRDD( - input: RDD[LabeledPoint], - bins: Array[Array[Bin]], - metadata: DecisionTreeMetadata): RDD[TreePoint] = { - // Construct arrays for featureArity for efficiency in the inner loop. - val featureArity: Array[Int] = new Array[Int](metadata.numFeatures) - var featureIndex = 0 - while (featureIndex < metadata.numFeatures) { - featureArity(featureIndex) = metadata.featureArity.getOrElse(featureIndex, 0) - featureIndex += 1 - } - input.map { x => - TreePoint.labeledPointToTreePoint(x, bins, featureArity) - } - } - - /** - * Convert one LabeledPoint into its TreePoint representation. - * @param bins Bins for features, of size (numFeatures, numBins). - * @param featureArity Array indexed by feature, with value 0 for continuous and numCategories - * for categorical features. - */ - private def labeledPointToTreePoint( - labeledPoint: LabeledPoint, - bins: Array[Array[Bin]], - featureArity: Array[Int]): TreePoint = { - val numFeatures = labeledPoint.features.size - val arr = new Array[Int](numFeatures) - var featureIndex = 0 - while (featureIndex < numFeatures) { - arr(featureIndex) = findBin(featureIndex, labeledPoint, featureArity(featureIndex), - bins) - featureIndex += 1 - } - new TreePoint(labeledPoint.label, arr) - } - - /** - * Find bin for one (labeledPoint, feature). - * - * @param featureArity 0 for continuous features; number of categories for categorical features. - * @param bins Bins for features, of size (numFeatures, numBins). - */ - private def findBin( - featureIndex: Int, - labeledPoint: LabeledPoint, - featureArity: Int, - bins: Array[Array[Bin]]): Int = { - - /** - * Binary search helper method for continuous feature. - */ - def binarySearchForBins(): Int = { - val binForFeatures = bins(featureIndex) - val feature = labeledPoint.features(featureIndex) - var left = 0 - var right = binForFeatures.length - 1 - while (left <= right) { - val mid = left + (right - left) / 2 - val bin = binForFeatures(mid) - val lowThreshold = bin.lowSplit.threshold - val highThreshold = bin.highSplit.threshold - if ((lowThreshold < feature) && (highThreshold >= feature)) { - return mid - } else if (lowThreshold >= feature) { - right = mid - 1 - } else { - left = mid + 1 - } - } - -1 - } - - if (featureArity == 0) { - // Perform binary search for finding bin for continuous features. - val binIndex = binarySearchForBins() - if (binIndex == -1) { - throw new RuntimeException("No bin was found for continuous feature." + - " This error can occur when given invalid data values (such as NaN)." 
+ - s" Feature index: $featureIndex. Feature value: ${labeledPoint.features(featureIndex)}") - } - binIndex - } else { - // Categorical feature bins are indexed by feature values. - val featureValue = labeledPoint.features(featureIndex) - if (featureValue < 0 || featureValue >= featureArity) { - throw new IllegalArgumentException( - s"DecisionTree given invalid data:" + - s" Feature $featureIndex is categorical with values in" + - s" {0,...,${featureArity - 1}," + - s" but a data point gives it value $featureValue.\n" + - " Bad data point: " + labeledPoint.toString) - } - featureValue.toInt - } - } -} diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 13aff11007..ff7700d2d1 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -85,7 +85,7 @@ object Entropy extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. */ -private[tree] class EntropyAggregator(numClasses: Int) +private[spark] class EntropyAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 39c7f9c3be..58dc79b739 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -81,7 +81,7 @@ object Gini extends Impurity { * Note: Instances of this class do not hold the data; they operate on views of the data. * @param numClasses Number of classes for label. */ -private[tree] class GiniAggregator(numClasses: Int) +private[spark] class GiniAggregator(numClasses: Int) extends ImpurityAggregator(numClasses) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 92d74a1b83..2423516123 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -71,7 +71,7 @@ object Variance extends Impurity { * in order to compute impurity from a sample. * Note: Instances of this class do not hold the data; they operate on views of the data. */ -private[tree] class VarianceAggregator() +private[spark] class VarianceAggregator() extends ImpurityAggregator(statsSize = 3) with Serializable { /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala deleted file mode 100644 index 0cad473782..0000000000 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/Bin.scala +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.model - -import org.apache.spark.mllib.tree.configuration.FeatureType._ - -/** - * Used for "binning" the feature values for faster best split calculation. - * - * For a continuous feature, the bin is determined by a low and a high split, - * where an example with featureValue falls into the bin s.t. - * lowSplit.threshold < featureValue <= highSplit.threshold. - * - * For ordered categorical features, there is a 1-1-1 correspondence between - * bins, splits, and feature values. The bin is determined by category/feature value. - * However, the bins are not necessarily ordered by feature value; - * they are ordered using impurity. - * - * For unordered categorical features, there is a 1-1 correspondence between bins, splits, - * where bins and splits correspond to subsets of feature values (in highSplit.categories). - * An unordered feature with k categories uses (1 << k - 1) - 1 bins, corresponding to all - * partitionings of categories into 2 disjoint, non-empty sets. - * - * @param lowSplit signifying the lower threshold for the continuous feature to be - * accepted in the bin - * @param highSplit signifying the upper threshold for the continuous feature to be - * accepted in the bin - * @param featureType type of feature -- categorical or continuous - * @param category categorical label value accepted in the bin for ordered features - */ -private[tree] -case class Bin(lowSplit: Split, highSplit: Split, featureType: FeatureType, category: Double) diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala new file mode 100644 index 0000000000..77ab3d8bb7 --- /dev/null +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/BaggedPointSuite.scala @@ -0,0 +1,99 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import org.apache.spark.SparkFunSuite +import org.apache.spark.mllib.tree.EnsembleTestHelper +import org.apache.spark.mllib.util.MLlibTestSparkContext + +/** + * Test suite for [[BaggedPoint]]. 
+ */ +class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { + + test("BaggedPoint RDD: without subsampling") { + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) + baggedRDD.collect().foreach { baggedPoint => + assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) + } + } + + test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 1.0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") { + val numSubsamples = 100 + val subsample = 0.5 + val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample)) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") { + val numSubsamples = 100 + val (expectedMean, expectedStddev) = (1.0, 0) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } + + test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") { + val numSubsamples = 100 + val subsample = 0.5 + val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample))) + + val seeds = Array(123, 5354, 230, 349867, 23987) + val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) + val rdd = sc.parallelize(arr) + seeds.foreach { seed => + val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) + val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() + EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, + expectedStddev, epsilon = 0.01) + } + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala index 441338e74e..e64551f03c 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tree/impl/RandomForestSuite.scala @@ -26,7 +26,6 @@ import 
org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTreeSuite => OldDTSuite, EnsembleTestHelper} import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, QuantileStrategy, Strategy => OldStrategy} -import org.apache.spark.mllib.tree.impl.{BaggedPoint, DecisionTreeMetadata} import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, GiniCalculator} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.mllib.util.TestingUtils._ diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala index bb1041b109..49cb7e1f24 100644 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/mllib/tree/DecisionTreeSuite.scala @@ -20,12 +20,12 @@ package org.apache.spark.mllib.tree import scala.collection.JavaConverters._ import org.apache.spark.SparkFunSuite +import org.apache.spark.ml.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.linalg.Vectors import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.Algo._ import org.apache.spark.mllib.tree.configuration.FeatureType._ import org.apache.spark.mllib.tree.configuration.Strategy -import org.apache.spark.mllib.tree.impl.DecisionTreeMetadata import org.apache.spark.mllib.tree.impurity.{Entropy, Gini, Variance} import org.apache.spark.mllib.tree.model._ import org.apache.spark.mllib.util.MLlibTestSparkContext diff --git a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala b/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala deleted file mode 100644 index 9d756da410..0000000000 --- a/mllib/src/test/scala/org/apache/spark/mllib/tree/impl/BaggedPointSuite.scala +++ /dev/null @@ -1,99 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package org.apache.spark.mllib.tree.impl - -import org.apache.spark.SparkFunSuite -import org.apache.spark.mllib.tree.EnsembleTestHelper -import org.apache.spark.mllib.util.MLlibTestSparkContext - -/** - * Test suite for [[BaggedPoint]]. 
- */ -class BaggedPointSuite extends SparkFunSuite with MLlibTestSparkContext { - - test("BaggedPoint RDD: without subsampling") { - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, 1, false, 42) - baggedRDD.collect().foreach { baggedPoint => - assert(baggedPoint.subsampleWeights.size == 1 && baggedPoint.subsampleWeights(0) == 1) - } - } - - test("BaggedPoint RDD: with subsampling with replacement (fraction = 1.0)") { - val numSubsamples = 100 - val (expectedMean, expectedStddev) = (1.0, 1.0) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, true, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } - - test("BaggedPoint RDD: with subsampling with replacement (fraction = 0.5)") { - val numSubsamples = 100 - val subsample = 0.5 - val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample)) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, true, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } - - test("BaggedPoint RDD: with subsampling without replacement (fraction = 1.0)") { - val numSubsamples = 100 - val (expectedMean, expectedStddev) = (1.0, 0) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, 1.0, numSubsamples, false, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } - - test("BaggedPoint RDD: with subsampling without replacement (fraction = 0.5)") { - val numSubsamples = 100 - val subsample = 0.5 - val (expectedMean, expectedStddev) = (subsample, math.sqrt(subsample * (1 - subsample))) - - val seeds = Array(123, 5354, 230, 349867, 23987) - val arr = EnsembleTestHelper.generateOrderedLabeledPoints(1, 1000) - val rdd = sc.parallelize(arr) - seeds.foreach { seed => - val baggedRDD = BaggedPoint.convertToBaggedRDD(rdd, subsample, numSubsamples, false, seed) - val subsampleCounts: Array[Array[Double]] = baggedRDD.map(_.subsampleWeights).collect() - EnsembleTestHelper.testRandomArrays(subsampleCounts, numSubsamples, expectedMean, - expectedStddev, epsilon = 0.01) - } - } -} -- cgit v1.2.3
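The checkpoint handling that moves along with NodeIdCache (re-persist the updated RDD each iteration, checkpoint every checkpointInterval-th update, and delete a queued checkpoint only once the next one has materialized on disk) is a general pattern for any iteratively updated RDD lineage. The condensed sketch below re-expresses just that rotation logic; RotatingCheckpointer is an illustrative name, not a Spark class, and it assumes a SparkContext whose checkpoint directory has been set.

import scala.collection.mutable
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

// Sketch of the checkpoint-rotation pattern used by NodeIdCache: persist every
// update, checkpoint every `interval`-th one, and delete the oldest queued
// checkpoint only after the next checkpoint is known to exist on disk.
class RotatingCheckpointer[T](interval: Int) {
  private val queue = mutable.Queue[RDD[T]]()
  private var updateCount = 0

  def update(rdd: RDD[T]): Unit = {
    rdd.persist(StorageLevel.MEMORY_AND_DISK)
    updateCount += 1
    if (rdd.sparkContext.getCheckpointDir.nonEmpty && updateCount % interval == 0) {
      // Drop the oldest checkpoint only once a newer one has been written out.
      while (queue.size > 1 && queue(1).getCheckpointFile.isDefined) {
        val old = queue.dequeue()
        old.getCheckpointFile.foreach { file =>
          // Spark does not delete old checkpoint files itself, so remove them here.
          val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
          fs.delete(new Path(file), true)
        }
      }
      rdd.checkpoint()
      queue.enqueue(rdd)
    }
  }
}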
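TimeTracker, now under org.apache.spark.ml.tree.impl, keeps labeled nanosecond timers and reports per-label totals in seconds via toString. It is private[spark], so outside callers cannot use it directly; the snippet below is only a usage sketch of the API shown in the removed file, with the "total" and "findSplits" labels chosen for illustration.

val timer = new TimeTracker()
timer.start("total")
timer.start("findSplits")
// ... work being timed ...
val findSplitsSeconds: Double = timer.stop("findSplits")  // elapsed time in seconds
timer.stop("total")
println(s"Internal timing:\n$timer")                      // per-label totals in seconds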
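The bin-lookup rule encoded by the removed Bin and TreePoint classes is: a continuous value v belongs to the bin whose splits satisfy lowSplit.threshold < v <= highSplit.threshold (found by binary search over that feature's bins), while an ordered categorical value is itself the bin index, provided it lies in {0, ..., arity - 1}. The standalone sketch below restates that rule over plain (low, high) threshold pairs; BinLookupSketch and its method names are illustrative and not part of the Spark API.

// Self-contained restatement of the bin-lookup rule described above.
// `bounds(i)` stands in for bin i of one continuous feature:
// (lowSplit.threshold, highSplit.threshold).
object BinLookupSketch {

  /** Binary search: return the index i with bounds(i)._1 < value <= bounds(i)._2, or -1. */
  def findContinuousBin(value: Double, bounds: Array[(Double, Double)]): Int = {
    var left = 0
    var right = bounds.length - 1
    while (left <= right) {
      val mid = left + (right - left) / 2
      val (low, high) = bounds(mid)
      if (low < value && value <= high) {
        return mid
      } else if (value <= low) {
        right = mid - 1
      } else {
        left = mid + 1
      }
    }
    -1  // no bin found, e.g. for NaN or a value outside all bins
  }

  /** Ordered categorical feature: the category value is the bin index. */
  def findCategoricalBin(value: Double, arity: Int): Int = {
    require(value >= 0 && value < arity, s"category $value is outside {0,...,${arity - 1}}")
    value.toInt
  }
}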
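The expected statistics asserted in BaggedPointSuite follow from the sampling scheme behind BaggedPoint.convertToBaggedRDD as exercised above: with replacement the per-row weight is a sampled count with mean fraction and standard deviation sqrt(fraction) (Poisson-style sampling), and without replacement it is a keep/drop indicator with mean fraction and standard deviation sqrt(fraction * (1 - fraction)) (Bernoulli). A small sketch of that arithmetic, matching the suite's constants:

// Expected (mean, stddev) of the per-row subsample weights checked by the suite.
object SubsampleWeightStats {

  /** With replacement: Poisson-style counts give (fraction, sqrt(fraction)). */
  def withReplacement(fraction: Double): (Double, Double) =
    (fraction, math.sqrt(fraction))

  /** Without replacement: Bernoulli keep/drop gives (fraction, sqrt(fraction * (1 - fraction))). */
  def withoutReplacement(fraction: Double): (Double, Double) =
    (fraction, math.sqrt(fraction * (1 - fraction)))
}

// withReplacement(1.0)    == (1.0, 1.0)
// withReplacement(0.5)    == (0.5, math.sqrt(0.5))
// withoutReplacement(1.0) == (1.0, 0.0)
// withoutReplacement(0.5) == (0.5, 0.5)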