diff options
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala | 222 |
1 files changed, 222 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala new file mode 100644 index 0000000000..c7cde1563f --- /dev/null +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/DecisionTreeMetadata.scala @@ -0,0 +1,222 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.ml.tree.impl + +import scala.collection.mutable + +import org.apache.spark.internal.Logging +import org.apache.spark.mllib.regression.LabeledPoint +import org.apache.spark.mllib.tree.configuration.Algo._ +import org.apache.spark.mllib.tree.configuration.QuantileStrategy._ +import org.apache.spark.mllib.tree.configuration.Strategy +import org.apache.spark.mllib.tree.impurity.Impurity +import org.apache.spark.rdd.RDD + +/** + * Learning and dataset metadata for DecisionTree. + * + * @param numClasses For classification: labels can take values {0, ..., numClasses - 1}. + * For regression: fixed at 0 (no meaning). + * @param maxBins Maximum number of bins, for all features. + * @param featureArity Map: categorical feature index --> arity. + * I.e., the feature takes values in {0, ..., arity - 1}. + * @param numBins Number of bins for each feature. + */ +private[spark] class DecisionTreeMetadata( + val numFeatures: Int, + val numExamples: Long, + val numClasses: Int, + val maxBins: Int, + val featureArity: Map[Int, Int], + val unorderedFeatures: Set[Int], + val numBins: Array[Int], + val impurity: Impurity, + val quantileStrategy: QuantileStrategy, + val maxDepth: Int, + val minInstancesPerNode: Int, + val minInfoGain: Double, + val numTrees: Int, + val numFeaturesPerNode: Int) extends Serializable { + + def isUnordered(featureIndex: Int): Boolean = unorderedFeatures.contains(featureIndex) + + def isClassification: Boolean = numClasses >= 2 + + def isMulticlass: Boolean = numClasses > 2 + + def isMulticlassWithCategoricalFeatures: Boolean = isMulticlass && (featureArity.size > 0) + + def isCategorical(featureIndex: Int): Boolean = featureArity.contains(featureIndex) + + def isContinuous(featureIndex: Int): Boolean = !featureArity.contains(featureIndex) + + /** + * Number of splits for the given feature. + * For unordered features, there is 1 bin per split. + * For ordered features, there is 1 more bin than split. + */ + def numSplits(featureIndex: Int): Int = if (isUnordered(featureIndex)) { + numBins(featureIndex) + } else { + numBins(featureIndex) - 1 + } + + + /** + * Set number of splits for a continuous feature. + * For a continuous feature, number of bins is number of splits plus 1. + */ + def setNumSplits(featureIndex: Int, numSplits: Int) { + require(isContinuous(featureIndex), + s"Only number of bin for a continuous feature can be set.") + numBins(featureIndex) = numSplits + 1 + } + + /** + * Indicates if feature subsampling is being used. + */ + def subsamplingFeatures: Boolean = numFeatures != numFeaturesPerNode + +} + +private[spark] object DecisionTreeMetadata extends Logging { + + /** + * Construct a [[DecisionTreeMetadata]] instance for this dataset and parameters. + * This computes which categorical features will be ordered vs. unordered, + * as well as the number of splits and bins for each feature. + */ + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy, + numTrees: Int, + featureSubsetStrategy: String): DecisionTreeMetadata = { + + val numFeatures = input.map(_.features.size).take(1).headOption.getOrElse { + throw new IllegalArgumentException(s"DecisionTree requires size of input RDD > 0, " + + s"but was given by empty one.") + } + val numExamples = input.count() + val numClasses = strategy.algo match { + case Classification => strategy.numClasses + case Regression => 0 + } + + val maxPossibleBins = math.min(strategy.maxBins, numExamples).toInt + if (maxPossibleBins < strategy.maxBins) { + logWarning(s"DecisionTree reducing maxBins from ${strategy.maxBins} to $maxPossibleBins" + + s" (= number of training instances)") + } + + // We check the number of bins here against maxPossibleBins. + // This needs to be checked here instead of in Strategy since maxPossibleBins can be modified + // based on the number of training examples. + if (strategy.categoricalFeaturesInfo.nonEmpty) { + val maxCategoriesPerFeature = strategy.categoricalFeaturesInfo.values.max + val maxCategory = + strategy.categoricalFeaturesInfo.find(_._2 == maxCategoriesPerFeature).get._1 + require(maxCategoriesPerFeature <= maxPossibleBins, + s"DecisionTree requires maxBins (= $maxPossibleBins) to be at least as large as the " + + s"number of values in each categorical feature, but categorical feature $maxCategory " + + s"has $maxCategoriesPerFeature values. Considering remove this and other categorical " + + "features with a large number of values, or add more training examples.") + } + + val unorderedFeatures = new mutable.HashSet[Int]() + val numBins = Array.fill[Int](numFeatures)(maxPossibleBins) + if (numClasses > 2) { + // Multiclass classification + val maxCategoriesForUnorderedFeature = + ((math.log(maxPossibleBins / 2 + 1) / math.log(2.0)) + 1).floor.toInt + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + // Hack: If a categorical feature has only 1 category, we treat it as continuous. + // TODO(SPARK-9957): Handle this properly by filtering out those features. + if (numCategories > 1) { + // Decide if some categorical features should be treated as unordered features, + // which require 2 * ((1 << numCategories - 1) - 1) bins. + // We do this check with log values to prevent overflows in case numCategories is large. + // The next check is equivalent to: 2 * ((1 << numCategories - 1) - 1) <= maxBins + if (numCategories <= maxCategoriesForUnorderedFeature) { + unorderedFeatures.add(featureIndex) + numBins(featureIndex) = numUnorderedBins(numCategories) + } else { + numBins(featureIndex) = numCategories + } + } + } + } else { + // Binary classification or regression + strategy.categoricalFeaturesInfo.foreach { case (featureIndex, numCategories) => + // If a categorical feature has only 1 category, we treat it as continuous: SPARK-9957 + if (numCategories > 1) { + numBins(featureIndex) = numCategories + } + } + } + + // Set number of features to use per node (for random forests). + val _featureSubsetStrategy = featureSubsetStrategy match { + case "auto" => + if (numTrees == 1) { + "all" + } else { + if (strategy.algo == Classification) { + "sqrt" + } else { + "onethird" + } + } + case _ => featureSubsetStrategy + } + + val isIntRegex = "^([1-9]\\d*)$".r + val isFractionRegex = "^(0?\\.\\d*[1-9]\\d*|1\\.0+)$".r + val numFeaturesPerNode: Int = _featureSubsetStrategy match { + case "all" => numFeatures + case "sqrt" => math.sqrt(numFeatures).ceil.toInt + case "log2" => math.max(1, (math.log(numFeatures) / math.log(2)).ceil.toInt) + case "onethird" => (numFeatures / 3.0).ceil.toInt + case isIntRegex(number) => if (BigInt(number) > numFeatures) numFeatures else number.toInt + case isFractionRegex(fraction) => (fraction.toDouble * numFeatures).ceil.toInt + } + + new DecisionTreeMetadata(numFeatures, numExamples, numClasses, numBins.max, + strategy.categoricalFeaturesInfo, unorderedFeatures.toSet, numBins, + strategy.impurity, strategy.quantileCalculationStrategy, strategy.maxDepth, + strategy.minInstancesPerNode, strategy.minInfoGain, numTrees, numFeaturesPerNode) + } + + /** + * Version of [[DecisionTreeMetadata#buildMetadata]] for DecisionTree. + */ + def buildMetadata( + input: RDD[LabeledPoint], + strategy: Strategy): DecisionTreeMetadata = { + buildMetadata(input, strategy, numTrees = 1, featureSubsetStrategy = "all") + } + + /** + * Given the arity of a categorical feature (arity = number of categories), + * return the number of bins for the feature if it is to be treated as an unordered feature. + * There is 1 split for every partitioning of categories into 2 disjoint, non-empty sets; + * there are math.pow(2, arity - 1) - 1 such splits. + * Each split has 2 corresponding bins. + */ + def numUnorderedBins(arity: Int): Int = (1 << arity - 1) - 1 + +} |