From e8bdcdeabb2df139a656f86686cdb53c891b1f4b Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Fri, 31 Jul 2015 11:56:52 -0700 Subject: [SPARK-6885] [ML] decision tree support predict class probabilities Decision tree support predict class probabilities. Implement the prediction probabilities function referred the old DecisionTree API and the [sklean API](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L593). I make the DecisionTreeClassificationModel inherit from ProbabilisticClassificationModel, make the predictRaw to return the raw counts vector and make raw2probabilityInPlace/predictProbability return the probabilities for each prediction. Author: Yanbo Liang Closes #7694 from yanboliang/spark-6885 and squashes the following commits: 08d5b7f [Yanbo Liang] fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 issue 2174278 [Yanbo Liang] solve merge conflicts 7e90ba8 [Yanbo Liang] fix typos 33ae183 [Yanbo Liang] fix annotation ff043d3 [Yanbo Liang] raw2probabilityInPlace should operate in-place c32d6ce [Yanbo Liang] optimize calculateImpurityStats function again 6167fb0 [Yanbo Liang] optimize calculateImpurityStats function fbbe2ec [Yanbo Liang] eliminate duplicated struct and code beb1634 [Yanbo Liang] try to eliminate impurityStats for each LearningNode 99e8943 [Yanbo Liang] code optimization 5ec3323 [Yanbo Liang] implement InformationGainAndImpurityStats 227c91b [Yanbo Liang] refactor LearningNode to store ImpurityCalculator d746ffc [Yanbo Liang] decision tree support predict class probabilities --- .../ml/classification/DecisionTreeClassifier.scala | 40 +++++-- .../spark/ml/classification/GBTClassifier.scala | 2 +- .../ml/classification/RandomForestClassifier.scala | 2 +- .../ml/regression/DecisionTreeRegressor.scala | 2 +- .../apache/spark/ml/regression/GBTRegressor.scala | 2 +- .../ml/regression/RandomForestRegressor.scala | 2 +- .../main/scala/org/apache/spark/ml/tree/Node.scala | 80 +++++++------ .../apache/spark/ml/tree/impl/RandomForest.scala | 126 +++++++++------------ .../apache/spark/mllib/tree/impurity/Entropy.scala | 2 +- .../apache/spark/mllib/tree/impurity/Gini.scala | 2 +- .../spark/mllib/tree/impurity/Impurity.scala | 2 +- .../spark/mllib/tree/impurity/Variance.scala | 2 +- .../mllib/tree/model/InformationGainStats.scala | 61 +++++++++- .../DecisionTreeClassifierSuite.scala | 30 ++++- .../ml/classification/GBTClassifierSuite.scala | 2 +- .../RandomForestClassifierSuite.scala | 2 +- 16 files changed, 229 insertions(+), 130 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index 36fe1bd404..f27cfd0331 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -18,12 +18,11 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Experimental -import org.apache.spark.ml.{PredictionModel, Predictor} import org.apache.spark.ml.param.ParamMap import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} -import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy} import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel} @@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame */ @Experimental final class DecisionTreeClassifier(override val uid: String) - extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] + extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel] with DecisionTreeParams with TreeClassifierParams { def this() = this(Identifiable.randomUID("dtc")) @@ -106,8 +105,9 @@ object DecisionTreeClassifier { @Experimental final class DecisionTreeClassificationModel private[ml] ( override val uid: String, - override val rootNode: Node) - extends PredictionModel[Vector, DecisionTreeClassificationModel] + override val rootNode: Node, + override val numClasses: Int) + extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel] with DecisionTreeModel with Serializable { require(rootNode != null, @@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] ( * Construct a decision tree classification model. * @param rootNode Root node of tree, with other nodes attached. */ - def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode) + def this(rootNode: Node, numClasses: Int) = + this(Identifiable.randomUID("dtc"), rootNode, numClasses) override protected def predict(features: Vector): Double = { - rootNode.predict(features) + rootNode.predictImpl(features).prediction + } + + override protected def predictRaw(features: Vector): Vector = { + Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone()) + } + + override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = { + rawPrediction match { + case dv: DenseVector => + var i = 0 + val size = dv.size + val sum = dv.values.sum + while (i < size) { + dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0 + i += 1 + } + dv + case sv: SparseVector => + throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" + + " raw2probabilityInPlace encountered SparseVector") + } } override def copy(extra: ParamMap): DecisionTreeClassificationModel = { - copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra) + copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra) } override def toString: String = { @@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel { s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}") val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures) val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc") - new DecisionTreeClassificationModel(uid, rootNode) + new DecisionTreeClassificationModel(uid, rootNode, -1) } } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala index eb0b1a0a40..c3891a9599 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala @@ -190,7 +190,7 @@ final class GBTClassificationModel( override protected def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) if (prediction > 0.0) 1.0 else 0.0 } diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala index bc19bd6df8..0c7eb4a662 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala @@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] ( // Ignore the weights since all are 1.0 for now. val votes = new Array[Double](numClasses) _trees.view.foreach { tree => - val prediction = tree.rootNode.predict(features).toInt + val prediction = tree.rootNode.predictImpl(features).prediction.toInt votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight } Vectors.dense(votes) diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 6f3340c2f0..4d30e4b554 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] ( def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode) override protected def predict(features: Vector): Double = { - rootNode.predict(features) + rootNode.predictImpl(features).prediction } override def copy(extra: ParamMap): DecisionTreeRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala index e38dc73ee0..5633bc3202 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala @@ -180,7 +180,7 @@ final class GBTRegressionModel( override protected def predict(features: Vector): Double = { // TODO: When we add a generic Boosting class, handle transform there? SPARK-7129 // Classifies by thresholding sum of weighted tree predictions - val treePredictions = _trees.map(_.rootNode.predict(features)) + val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction) blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1) } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala index 506a878c25..17fb1ad5e1 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala @@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] ( // TODO: When we add a generic Bagging class, handle transform there. SPARK-7128 // Predict average of tree predictions. // Ignore the weights since all are 1.0 for now. - _trees.map(_.rootNode.predict(features)).sum / numTrees + _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees } override def copy(extra: ParamMap): RandomForestRegressionModel = { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala index bbc2427ca7..8879352a60 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala @@ -19,8 +19,9 @@ package org.apache.spark.ml.tree import org.apache.spark.annotation.DeveloperApi import org.apache.spark.mllib.linalg.Vector +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats, - Node => OldNode, Predict => OldPredict} + Node => OldNode, Predict => OldPredict, ImpurityStats} /** * :: DeveloperApi :: @@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable { /** Impurity measure at this node (for training data) */ def impurity: Double + /** + * Statistics aggregated from training data at this node, used to compute prediction, impurity, + * and probabilities. + * For classification, the array of class counts must be normalized to a probability distribution. + */ + private[tree] def impurityStats: ImpurityCalculator + /** Recursive prediction helper method */ - private[ml] def predict(features: Vector): Double = prediction + private[ml] def predictImpl(features: Vector): LeafNode /** * Get the number of nodes in tree below this node, including leaf nodes. @@ -75,7 +83,8 @@ private[ml] object Node { if (oldNode.isLeaf) { // TODO: Once the implementation has been moved to this API, then include sufficient // statistics here. - new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity) + new LeafNode(prediction = oldNode.predict.predict, + impurity = oldNode.impurity, impurityStats = null) } else { val gain = if (oldNode.stats.nonEmpty) { oldNode.stats.get.gain @@ -85,7 +94,7 @@ private[ml] object Node { new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity, gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures), rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures), - split = Split.fromOld(oldNode.split.get, categoricalFeatures)) + split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null) } } } @@ -99,11 +108,13 @@ private[ml] object Node { @DeveloperApi final class LeafNode private[ml] ( override val prediction: Double, - override val impurity: Double) extends Node { + override val impurity: Double, + override val impurityStats: ImpurityCalculator) extends Node { - override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)" + override def toString: String = + s"LeafNode(prediction = $prediction, impurity = $impurity)" - override private[ml] def predict(features: Vector): Double = prediction + override private[ml] def predictImpl(features: Vector): LeafNode = this override private[tree] def numDescendants: Int = 0 @@ -115,9 +126,8 @@ final class LeafNode private[ml] ( override private[tree] def subtreeDepth: Int = 0 override private[ml] def toOld(id: Int): OldNode = { - // NOTE: We do NOT store 'prob' in the new API currently. - new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true, - None, None, None, None) + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), + impurity, isLeaf = true, None, None, None, None) } } @@ -139,17 +149,18 @@ final class InternalNode private[ml] ( val gain: Double, val leftChild: Node, val rightChild: Node, - val split: Split) extends Node { + val split: Split, + override val impurityStats: ImpurityCalculator) extends Node { override def toString: String = { s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)" } - override private[ml] def predict(features: Vector): Double = { + override private[ml] def predictImpl(features: Vector): LeafNode = { if (split.shouldGoLeft(features)) { - leftChild.predict(features) + leftChild.predictImpl(features) } else { - rightChild.predict(features) + rightChild.predictImpl(features) } } @@ -172,9 +183,8 @@ final class InternalNode private[ml] ( override private[ml] def toOld(id: Int): OldNode = { assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API" + " since the old API does not support deep trees.") - // NOTE: We do NOT store 'prob' in the new API currently. - new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false, - Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), + new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity, + isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))), Some(rightChild.toOld(OldNode.rightChildIndex(id))), Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity, new OldPredict(leftChild.prediction, prob = 0.0), @@ -223,36 +233,36 @@ private object InternalNode { * * @param id We currently use the same indexing as the old implementation in * [[org.apache.spark.mllib.tree.model.Node]], but this will change later. - * @param predictionStats Predicted label + class probability (for classification). - * We will later modify this to store aggregate statistics for labels - * to provide all class probabilities (for classification) and maybe a - * distribution (for regression). * @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree, * so that we do not need to consider splitting it further. - * @param stats Old structure for storing stats about information gain, prediction, etc. - * This is legacy and will be modified in the future. + * @param stats Impurity statistics for this node. */ private[tree] class LearningNode( var id: Int, - var predictionStats: OldPredict, - var impurity: Double, var leftChild: Option[LearningNode], var rightChild: Option[LearningNode], var split: Option[Split], var isLeaf: Boolean, - var stats: Option[OldInformationGainStats]) extends Serializable { + var stats: ImpurityStats) extends Serializable { /** * Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children. */ def toNode: Node = { if (leftChild.nonEmpty) { - assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty, + assert(rightChild.nonEmpty && split.nonEmpty && stats != null, "Unknown error during Decision Tree learning. Could not convert LearningNode to Node.") - new InternalNode(predictionStats.predict, impurity, stats.get.gain, - leftChild.get.toNode, rightChild.get.toNode, split.get) + new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain, + leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator) } else { - new LeafNode(predictionStats.predict, impurity) + if (stats.valid) { + new LeafNode(stats.impurityCalculator.predict, stats.impurity, + stats.impurityCalculator) + } else { + // Here we want to keep same behavior with the old mllib.DecisionTreeModel + new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator) + } + } } @@ -263,16 +273,14 @@ private[tree] object LearningNode { /** Create a node with some of its fields set. */ def apply( id: Int, - predictionStats: OldPredict, - impurity: Double, - isLeaf: Boolean): LearningNode = { - new LearningNode(id, predictionStats, impurity, None, None, None, false, None) + isLeaf: Boolean, + stats: ImpurityStats): LearningNode = { + new LearningNode(id, None, None, None, false, stats) } /** Create an empty node with the given node index. Values must be set later on. */ def emptyNode(nodeIndex: Int): LearningNode = { - new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN, - None, None, None, false, None) + new LearningNode(nodeIndex, None, None, None, false, null) } // The below indexing methods were copied from spark.mllib.tree.model.Node diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala index 15b56bd844..a8b90d9d26 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala @@ -31,7 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata, TimeTracker} import org.apache.spark.mllib.tree.impurity.ImpurityCalculator -import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict} +import org.apache.spark.mllib.tree.model.ImpurityStats import org.apache.spark.rdd.RDD import org.apache.spark.storage.StorageLevel import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom} @@ -180,13 +180,17 @@ private[ml] object RandomForest extends Logging { parentUID match { case Some(uid) => if (strategy.algo == OldAlgo.Classification) { - topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses) + } } else { topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode)) } case None => if (strategy.algo == OldAlgo.Classification) { - topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode)) + topNodes.map { rootNode => + new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses) + } } else { topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode)) } @@ -549,9 +553,9 @@ private[ml] object RandomForest extends Logging { } // find best split for each node - val (split: Split, stats: InformationGainStats, predict: Predict) = + val (split: Split, stats: ImpurityStats) = binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex)) - (nodeIndex, (split, stats, predict)) + (nodeIndex, (split, stats)) }.collectAsMap() timer.stop("chooseSplits") @@ -568,17 +572,15 @@ private[ml] object RandomForest extends Logging { val nodeIndex = node.id val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex) val aggNodeIndex = nodeInfo.nodeIndexInGroup - val (split: Split, stats: InformationGainStats, predict: Predict) = + val (split: Split, stats: ImpurityStats) = nodeToBestSplits(aggNodeIndex) logDebug("best split = " + split) // Extract info for this node. Create children if not leaf. val isLeaf = (stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth) - node.predictionStats = predict node.isLeaf = isLeaf - node.stats = Some(stats) - node.impurity = stats.impurity + node.stats = stats logDebug("Node = " + node) if (!isLeaf) { @@ -587,9 +589,9 @@ private[ml] object RandomForest extends Logging { val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0) val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0) node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex), - stats.leftPredict, stats.leftImpurity, leftChildIsLeaf)) + leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator))) node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex), - stats.rightPredict, stats.rightImpurity, rightChildIsLeaf)) + rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator))) if (nodeIdCache.nonEmpty) { val nodeIndexUpdater = NodeIndexUpdater( @@ -621,28 +623,44 @@ private[ml] object RandomForest extends Logging { } /** - * Calculate the information gain for a given (feature, split) based upon left/right aggregates. + * Calculate the impurity statistics for a give (feature, split) based upon left/right aggregates. + * @param stats the recycle impurity statistics for this feature's all splits, + * only 'impurity' and 'impurityCalculator' are valid between each iteration * @param leftImpurityCalculator left node aggregates for this (feature, split) * @param rightImpurityCalculator right node aggregate for this (feature, split) - * @return information gain and statistics for split + * @param metadata learning and dataset metadata for DecisionTree + * @return Impurity statistics for this (feature, split) */ - private def calculateGainForSplit( + private def calculateImpurityStats( + stats: ImpurityStats, leftImpurityCalculator: ImpurityCalculator, rightImpurityCalculator: ImpurityCalculator, - metadata: DecisionTreeMetadata, - impurity: Double): InformationGainStats = { + metadata: DecisionTreeMetadata): ImpurityStats = { + + val parentImpurityCalculator: ImpurityCalculator = if (stats == null) { + leftImpurityCalculator.copy.add(rightImpurityCalculator) + } else { + stats.impurityCalculator + } + + val impurity: Double = if (stats == null) { + parentImpurityCalculator.calculate() + } else { + stats.impurity + } + val leftCount = leftImpurityCalculator.count val rightCount = rightImpurityCalculator.count + val totalCount = leftCount + rightCount + // If left child or right child doesn't satisfy minimum instances per node, // then this split is invalid, return invalid information gain stats. if ((leftCount < metadata.minInstancesPerNode) || (rightCount < metadata.minInstancesPerNode)) { - return InformationGainStats.invalidInformationGainStats + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - val totalCount = leftCount + rightCount - val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0 val rightImpurity = rightImpurityCalculator.calculate() @@ -654,39 +672,11 @@ private[ml] object RandomForest extends Logging { // if information gain doesn't satisfy minimum information gain, // then this split is invalid, return invalid information gain stats. if (gain < metadata.minInfoGain) { - return InformationGainStats.invalidInformationGainStats + return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator) } - // calculate left and right predict - val leftPredict = calculatePredict(leftImpurityCalculator) - val rightPredict = calculatePredict(rightImpurityCalculator) - - new InformationGainStats(gain, impurity, leftImpurity, rightImpurity, - leftPredict, rightPredict) - } - - private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = { - val predict = impurityCalculator.predict - val prob = impurityCalculator.prob(predict) - new Predict(predict, prob) - } - - /** - * Calculate predict value for current node, given stats of any split. - * Note that this function is called only once for each node. - * @param leftImpurityCalculator left node aggregates for a split - * @param rightImpurityCalculator right node aggregates for a split - * @return predict value and impurity for current node - */ - private def calculatePredictImpurity( - leftImpurityCalculator: ImpurityCalculator, - rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = { - val parentNodeAgg = leftImpurityCalculator.copy - parentNodeAgg.add(rightImpurityCalculator) - val predict = calculatePredict(parentNodeAgg) - val impurity = parentNodeAgg.calculate() - - (predict, impurity) + new ImpurityStats(gain, impurity, parentImpurityCalculator, + leftImpurityCalculator, rightImpurityCalculator) } /** @@ -698,14 +688,14 @@ private[ml] object RandomForest extends Logging { binAggregates: DTStatsAggregator, splits: Array[Array[Split]], featuresForNode: Option[Array[Int]], - node: LearningNode): (Split, InformationGainStats, Predict) = { + node: LearningNode): (Split, ImpurityStats) = { - // Calculate prediction and impurity if current node is top node + // Calculate InformationGain and ImpurityStats if current node is top node val level = LearningNode.indexToLevel(node.id) - var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) { - None + var gainAndImpurityStats: ImpurityStats = if (level ==0) { + null } else { - Some((node.predictionStats, node.impurity)) + node.stats } // For each (feature, split), calculate the gain, and select the best (feature, split). @@ -734,11 +724,9 @@ private[ml] object RandomForest extends Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits) rightChildStats.subtract(leftChildStats) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIdx, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIdx, gainAndImpurityStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else if (binAggregates.metadata.isUnordered(featureIndex)) { @@ -750,11 +738,9 @@ private[ml] object RandomForest extends Logging { val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex) val rightChildStats = binAggregates.getImpurityCalculator(rightChildOffset, splitIndex) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIndex, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) (splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats) } else { @@ -825,11 +811,9 @@ private[ml] object RandomForest extends Logging { val rightChildStats = binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory) rightChildStats.subtract(leftChildStats) - predictionAndImpurity = Some(predictionAndImpurity.getOrElse( - calculatePredictImpurity(leftChildStats, rightChildStats))) - val gainStats = calculateGainForSplit(leftChildStats, - rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2) - (splitIndex, gainStats) + gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats, + leftChildStats, rightChildStats, binAggregates.metadata) + (splitIndex, gainAndImpurityStats) }.maxBy(_._2.gain) val categoriesForSplit = categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1) @@ -839,7 +823,7 @@ private[ml] object RandomForest extends Logging { } }.maxBy(_._2.gain) - (bestSplit, bestSplitStats, predictionAndImpurity.get._1) + (bestSplit, bestSplitStats) } /** diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala index 5ac10f3fd3..0768204c33 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala @@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala index 19d318203c..d0077db683 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala @@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int) * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala index 578749d85a..86cee7e430 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala @@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) { +private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable { /** * Make a deep copy of this [[ImpurityCalculator]]. diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala index 7104a7fa4d..04d0cd24e6 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala @@ -98,7 +98,7 @@ private[tree] class VarianceAggregator() * (node, feature, bin). * @param stats Array of sufficient statistics for a (node, feature, bin). */ -private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { +private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) { require(stats.size == 3, s"VarianceCalculator requires sufficient statistics array stats to be of length 3," + diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala index dc9e0f9f51..508bf9c1bd 100644 --- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala +++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala @@ -18,6 +18,7 @@ package org.apache.spark.mllib.tree.model import org.apache.spark.annotation.DeveloperApi +import org.apache.spark.mllib.tree.impurity.ImpurityCalculator /** * :: DeveloperApi :: @@ -66,7 +67,6 @@ class InformationGainStats( } } - private[spark] object InformationGainStats { /** * An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to @@ -76,3 +76,62 @@ private[spark] object InformationGainStats { val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0, new Predict(0.0, 0.0), new Predict(0.0, 0.0)) } + +/** + * :: DeveloperApi :: + * Impurity statistics for each split + * @param gain information gain value + * @param impurity current node impurity + * @param impurityCalculator impurity statistics for current node + * @param leftImpurityCalculator impurity statistics for left child node + * @param rightImpurityCalculator impurity statistics for right child node + * @param valid whether the current split satisfies minimum info gain or + * minimum number of instances per node + */ +@DeveloperApi +private[spark] class ImpurityStats( + val gain: Double, + val impurity: Double, + val impurityCalculator: ImpurityCalculator, + val leftImpurityCalculator: ImpurityCalculator, + val rightImpurityCalculator: ImpurityCalculator, + val valid: Boolean = true) extends Serializable { + + override def toString: String = { + s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " + + s"right impurity = $rightImpurity" + } + + def leftImpurity: Double = if (leftImpurityCalculator != null) { + leftImpurityCalculator.calculate() + } else { + -1.0 + } + + def rightImpurity: Double = if (rightImpurityCalculator != null) { + rightImpurityCalculator.calculate() + } else { + -1.0 + } +} + +private[spark] object ImpurityStats { + + /** + * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to + * denote that current split doesn't satisfies minimum info gain or + * minimum number of instances per node. + */ + def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.MinValue, impurityCalculator.calculate(), + impurityCalculator, null, null, false) + } + + /** + * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object + * that only 'impurity' and 'impurityCalculator' are defined. + */ + def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = { + new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null) + } +} diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 73b4805c4c..c7bbf1ce07 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite import org.apache.spark.ml.impl.TreeTests import org.apache.spark.ml.param.ParamsSuite import org.apache.spark.ml.tree.LeafNode -import org.apache.spark.mllib.linalg.Vectors +import org.apache.spark.mllib.linalg.{Vector, Vectors} import org.apache.spark.mllib.regression.LabeledPoint import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite} import org.apache.spark.mllib.util.MLlibTestSparkContext import org.apache.spark.rdd.RDD import org.apache.spark.sql.DataFrame +import org.apache.spark.sql.Row class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { @@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new DecisionTreeClassifier) - val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0)) + val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2) ParamsSuite.checkParams(model) } @@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses) } + test("predictRaw and predictProbability") { + val rdd = continuousDataPointsForMulticlassRDD + val dt = new DecisionTreeClassifier() + .setImpurity("Gini") + .setMaxDepth(4) + .setMaxBins(100) + val categoricalFeatures = Map(0 -> 3) + val numClasses = 3 + + val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses) + val newTree = dt.fit(newData) + + val predictions = newTree.transform(newData) + .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol) + .collect() + + predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) => + assert(pred === rawPred.argmax, + s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.") + val sum = rawPred.toArray.sum + assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred, + "probability prediction mismatch") + } + } + ///////////////////////////////////////////////////////////////////////////// // Tests of model save/load ///////////////////////////////////////////////////////////////////////////// diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala index a7bc77965f..d4b5896c12 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala @@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext { test("params") { ParamsSuite.checkParams(new GBTClassifier) val model = new GBTClassificationModel("gbtc", - Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))), + Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))), Array(1.0)) ParamsSuite.checkParams(model) } diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala index ab711c8e4b..dbb2577c62 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala @@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte test("params") { ParamsSuite.checkParams(new RandomForestClassifier) val model = new RandomForestClassificationModel("rfc", - Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2) + Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2) ParamsSuite.checkParams(model) } -- cgit v1.2.3