author    Yanbo Liang <ybliang8@gmail.com>    2015-07-31 11:56:52 -0700
committer Joseph K. Bradley <joseph@databricks.com>    2015-07-31 11:56:52 -0700
commit    e8bdcdeabb2df139a656f86686cdb53c891b1f4b (patch)
tree      d6ccab74f50d58e7b18a786ce66dcd8f5fe30f60 /mllib
parent    4011a947154d97a9ffb5a71f077481a12534d36b (diff)
[SPARK-6885] [ML] Decision tree support for predicting class probabilities
Decision tree support for predicting class probabilities. This implements class-probability prediction, following the old DecisionTree API and the [sklearn API](https://github.com/scikit-learn/scikit-learn/blob/master/sklearn/tree/tree.py#L593). DecisionTreeClassificationModel now inherits from ProbabilisticClassificationModel: predictRaw returns the vector of raw class counts, and raw2probabilityInPlace/predictProbability return the probabilities for each prediction.

Author: Yanbo Liang <ybliang8@gmail.com>

Closes #7694 from yanboliang/spark-6885 and squashes the following commits:

08d5b7f [Yanbo Liang] fix ImpurityStats null parameters and raw2probabilityInPlace sum = 0 issue
2174278 [Yanbo Liang] solve merge conflicts
7e90ba8 [Yanbo Liang] fix typos
33ae183 [Yanbo Liang] fix annotation
ff043d3 [Yanbo Liang] raw2probabilityInPlace should operate in-place
c32d6ce [Yanbo Liang] optimize calculateImpurityStats function again
6167fb0 [Yanbo Liang] optimize calculateImpurityStats function
fbbe2ec [Yanbo Liang] eliminate duplicated struct and code
beb1634 [Yanbo Liang] try to eliminate impurityStats for each LearningNode
99e8943 [Yanbo Liang] code optimization
5ec3323 [Yanbo Liang] implement InformationGainAndImpurityStats
227c91b [Yanbo Liang] refactor LearningNode to store ImpurityCalculator
d746ffc [Yanbo Liang] decision tree support predict class probabilities
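As a quick orientation before the diff: a minimal usage sketch of the new behavior, assuming a DataFrame named `training` with the default "label" and "features" columns (the variable name and setup are hypothetical, not part of this patch):

    import org.apache.spark.ml.classification.DecisionTreeClassifier

    // Assumes `training: DataFrame` with "label" and "features" columns.
    val dt = new DecisionTreeClassifier().setImpurity("gini").setMaxDepth(4)
    val model = dt.fit(training)

    // With this patch the model is a ProbabilisticClassificationModel, so
    // transform() emits rawPrediction (per-class counts at the predicted leaf)
    // and probability (those counts normalized) alongside prediction.
    model.transform(training)
      .select("prediction", "rawPrediction", "probability")
      .show()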
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala | 40
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala | 80
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 126
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala | 2
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala | 61
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala | 30
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala | 2
-rw-r--r--  mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala | 2
16 files changed, 229 insertions, 130 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 36fe1bd404..f27cfd0331 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -18,12 +18,11 @@
package org.apache.spark.ml.classification
import org.apache.spark.annotation.Experimental
-import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
-import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.linalg.{DenseVector, SparseVector, Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.model.{DecisionTreeModel => OldDecisionTreeModel}
@@ -39,7 +38,7 @@ import org.apache.spark.sql.DataFrame
*/
@Experimental
final class DecisionTreeClassifier(override val uid: String)
- extends Predictor[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
+ extends ProbabilisticClassifier[Vector, DecisionTreeClassifier, DecisionTreeClassificationModel]
with DecisionTreeParams with TreeClassifierParams {
def this() = this(Identifiable.randomUID("dtc"))
@@ -106,8 +105,9 @@ object DecisionTreeClassifier {
@Experimental
final class DecisionTreeClassificationModel private[ml] (
override val uid: String,
- override val rootNode: Node)
- extends PredictionModel[Vector, DecisionTreeClassificationModel]
+ override val rootNode: Node,
+ override val numClasses: Int)
+ extends ProbabilisticClassificationModel[Vector, DecisionTreeClassificationModel]
with DecisionTreeModel with Serializable {
require(rootNode != null,
@@ -117,14 +117,36 @@ final class DecisionTreeClassificationModel private[ml] (
* Construct a decision tree classification model.
* @param rootNode Root node of tree, with other nodes attached.
*/
- def this(rootNode: Node) = this(Identifiable.randomUID("dtc"), rootNode)
+ def this(rootNode: Node, numClasses: Int) =
+ this(Identifiable.randomUID("dtc"), rootNode, numClasses)
override protected def predict(features: Vector): Double = {
- rootNode.predict(features)
+ rootNode.predictImpl(features).prediction
+ }
+
+ override protected def predictRaw(features: Vector): Vector = {
+ Vectors.dense(rootNode.predictImpl(features).impurityStats.stats.clone())
+ }
+
+ override protected def raw2probabilityInPlace(rawPrediction: Vector): Vector = {
+ rawPrediction match {
+ case dv: DenseVector =>
+ var i = 0
+ val size = dv.size
+ val sum = dv.values.sum
+ while (i < size) {
+ dv.values(i) = if (sum != 0) dv.values(i) / sum else 0.0
+ i += 1
+ }
+ dv
+ case sv: SparseVector =>
+ throw new RuntimeException("Unexpected error in DecisionTreeClassificationModel:" +
+ " raw2probabilityInPlace encountered SparseVector")
+ }
}
override def copy(extra: ParamMap): DecisionTreeClassificationModel = {
- copyValues(new DecisionTreeClassificationModel(uid, rootNode), extra)
+ copyValues(new DecisionTreeClassificationModel(uid, rootNode, numClasses), extra)
}
override def toString: String = {
@@ -149,6 +171,6 @@ private[ml] object DecisionTreeClassificationModel {
s" DecisionTreeClassificationModel (new API). Algo is: ${oldModel.algo}")
val rootNode = Node.fromOld(oldModel.topNode, categoricalFeatures)
val uid = if (parent != null) parent.uid else Identifiable.randomUID("dtc")
- new DecisionTreeClassificationModel(uid, rootNode)
+ new DecisionTreeClassificationModel(uid, rootNode, -1)
}
}
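The raw2probabilityInPlace override above normalizes the leaf's class counts in place. A standalone sketch of the same arithmetic (a hypothetical helper, not part of the model's API):

    // Raw counts are divided by their sum; an all-zero count vector maps to
    // all-zero probabilities, matching the `sum != 0` guard above.
    def normalizeCounts(counts: Array[Double]): Array[Double] = {
      val sum = counts.sum
      if (sum == 0.0) Array.fill(counts.length)(0.0)
      else counts.map(_ / sum)
    }

    // normalizeCounts(Array(3.0, 1.0, 0.0)) yields Array(0.75, 0.25, 0.0)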
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
index eb0b1a0a40..c3891a9599 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/GBTClassifier.scala
@@ -190,7 +190,7 @@ final class GBTClassificationModel(
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
- val treePredictions = _trees.map(_.rootNode.predict(features))
+ val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
val prediction = blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
if (prediction > 0.0) 1.0 else 0.0
}
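GBTClassificationModel still predicts by thresholding the weighted sum of per-tree predictions; only the per-tree call changed. A hypothetical standalone rendering of that logic:

    // Weighted sum of tree outputs, thresholded at zero (binary labels 0/1).
    def gbtPredict(treePredictions: Array[Double], treeWeights: Array[Double]): Double = {
      val weightedSum = treePredictions.zip(treeWeights).map { case (p, w) => p * w }.sum
      if (weightedSum > 0.0) 1.0 else 0.0
    }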
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
index bc19bd6df8..0c7eb4a662 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/RandomForestClassifier.scala
@@ -160,7 +160,7 @@ final class RandomForestClassificationModel private[ml] (
// Ignore the weights since all are 1.0 for now.
val votes = new Array[Double](numClasses)
_trees.view.foreach { tree =>
- val prediction = tree.rootNode.predict(features).toInt
+ val prediction = tree.rootNode.predictImpl(features).prediction.toInt
votes(prediction) = votes(prediction) + 1.0 // 1.0 = weight
}
Vectors.dense(votes)
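The random forest classifier aggregates one unit-weight vote per tree into a counts vector. A hypothetical helper showing the same aggregation:

    // Each tree votes for one class with weight 1.0; votes(c) is the number
    // of trees predicting class c.
    def treeVotes(predictions: Seq[Int], numClasses: Int): Array[Double] = {
      val votes = new Array[Double](numClasses)
      predictions.foreach(p => votes(p) += 1.0)
      votes
    }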
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 6f3340c2f0..4d30e4b554 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -110,7 +110,7 @@ final class DecisionTreeRegressionModel private[ml] (
def this(rootNode: Node) = this(Identifiable.randomUID("dtr"), rootNode)
override protected def predict(features: Vector): Double = {
- rootNode.predict(features)
+ rootNode.predictImpl(features).prediction
}
override def copy(extra: ParamMap): DecisionTreeRegressionModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
index e38dc73ee0..5633bc3202 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/GBTRegressor.scala
@@ -180,7 +180,7 @@ final class GBTRegressionModel(
override protected def predict(features: Vector): Double = {
// TODO: When we add a generic Boosting class, handle transform there? SPARK-7129
// Classifies by thresholding sum of weighted tree predictions
- val treePredictions = _trees.map(_.rootNode.predict(features))
+ val treePredictions = _trees.map(_.rootNode.predictImpl(features).prediction)
blas.ddot(numTrees, treePredictions, 1, _treeWeights, 1)
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
index 506a878c25..17fb1ad5e1 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/RandomForestRegressor.scala
@@ -143,7 +143,7 @@ final class RandomForestRegressionModel private[ml] (
// TODO: When we add a generic Bagging class, handle transform there. SPARK-7128
// Predict average of tree predictions.
// Ignore the weights since all are 1.0 for now.
- _trees.map(_.rootNode.predict(features)).sum / numTrees
+ _trees.map(_.rootNode.predictImpl(features).prediction).sum / numTrees
}
override def copy(extra: ParamMap): RandomForestRegressionModel = {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index bbc2427ca7..8879352a60 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -19,8 +19,9 @@ package org.apache.spark.ml.tree
import org.apache.spark.annotation.DeveloperApi
import org.apache.spark.mllib.linalg.Vector
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
import org.apache.spark.mllib.tree.model.{InformationGainStats => OldInformationGainStats,
- Node => OldNode, Predict => OldPredict}
+ Node => OldNode, Predict => OldPredict, ImpurityStats}
/**
* :: DeveloperApi ::
@@ -38,8 +39,15 @@ sealed abstract class Node extends Serializable {
/** Impurity measure at this node (for training data) */
def impurity: Double
+ /**
+ * Statistics aggregated from training data at this node, used to compute prediction, impurity,
+ * and probabilities.
+ * For classification, the array of class counts must be normalized to a probability distribution.
+ */
+ private[tree] def impurityStats: ImpurityCalculator
+
/** Recursive prediction helper method */
- private[ml] def predict(features: Vector): Double = prediction
+ private[ml] def predictImpl(features: Vector): LeafNode
/**
* Get the number of nodes in tree below this node, including leaf nodes.
@@ -75,7 +83,8 @@ private[ml] object Node {
if (oldNode.isLeaf) {
// TODO: Once the implementation has been moved to this API, then include sufficient
// statistics here.
- new LeafNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity)
+ new LeafNode(prediction = oldNode.predict.predict,
+ impurity = oldNode.impurity, impurityStats = null)
} else {
val gain = if (oldNode.stats.nonEmpty) {
oldNode.stats.get.gain
@@ -85,7 +94,7 @@ private[ml] object Node {
new InternalNode(prediction = oldNode.predict.predict, impurity = oldNode.impurity,
gain = gain, leftChild = fromOld(oldNode.leftNode.get, categoricalFeatures),
rightChild = fromOld(oldNode.rightNode.get, categoricalFeatures),
- split = Split.fromOld(oldNode.split.get, categoricalFeatures))
+ split = Split.fromOld(oldNode.split.get, categoricalFeatures), impurityStats = null)
}
}
}
@@ -99,11 +108,13 @@ private[ml] object Node {
@DeveloperApi
final class LeafNode private[ml] (
override val prediction: Double,
- override val impurity: Double) extends Node {
+ override val impurity: Double,
+ override val impurityStats: ImpurityCalculator) extends Node {
- override def toString: String = s"LeafNode(prediction = $prediction, impurity = $impurity)"
+ override def toString: String =
+ s"LeafNode(prediction = $prediction, impurity = $impurity)"
- override private[ml] def predict(features: Vector): Double = prediction
+ override private[ml] def predictImpl(features: Vector): LeafNode = this
override private[tree] def numDescendants: Int = 0
@@ -115,9 +126,8 @@ final class LeafNode private[ml] (
override private[tree] def subtreeDepth: Int = 0
override private[ml] def toOld(id: Int): OldNode = {
- // NOTE: We do NOT store 'prob' in the new API currently.
- new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = true,
- None, None, None, None)
+ new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)),
+ impurity, isLeaf = true, None, None, None, None)
}
}
@@ -139,17 +149,18 @@ final class InternalNode private[ml] (
val gain: Double,
val leftChild: Node,
val rightChild: Node,
- val split: Split) extends Node {
+ val split: Split,
+ override val impurityStats: ImpurityCalculator) extends Node {
override def toString: String = {
s"InternalNode(prediction = $prediction, impurity = $impurity, split = $split)"
}
- override private[ml] def predict(features: Vector): Double = {
+ override private[ml] def predictImpl(features: Vector): LeafNode = {
if (split.shouldGoLeft(features)) {
- leftChild.predict(features)
+ leftChild.predictImpl(features)
} else {
- rightChild.predict(features)
+ rightChild.predictImpl(features)
}
}
@@ -172,9 +183,8 @@ final class InternalNode private[ml] (
override private[ml] def toOld(id: Int): OldNode = {
assert(id.toLong * 2 < Int.MaxValue, "Decision Tree could not be converted from new to old API"
+ " since the old API does not support deep trees.")
- // NOTE: We do NOT store 'prob' in the new API currently.
- new OldNode(id, new OldPredict(prediction, prob = 0.0), impurity, isLeaf = false,
- Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
+ new OldNode(id, new OldPredict(prediction, prob = impurityStats.prob(prediction)), impurity,
+ isLeaf = false, Some(split.toOld), Some(leftChild.toOld(OldNode.leftChildIndex(id))),
Some(rightChild.toOld(OldNode.rightChildIndex(id))),
Some(new OldInformationGainStats(gain, impurity, leftChild.impurity, rightChild.impurity,
new OldPredict(leftChild.prediction, prob = 0.0),
@@ -223,36 +233,36 @@ private object InternalNode {
*
* @param id We currently use the same indexing as the old implementation in
* [[org.apache.spark.mllib.tree.model.Node]], but this will change later.
- * @param predictionStats Predicted label + class probability (for classification).
- * We will later modify this to store aggregate statistics for labels
- * to provide all class probabilities (for classification) and maybe a
- * distribution (for regression).
* @param isLeaf Indicates whether this node will definitely be a leaf in the learned tree,
* so that we do not need to consider splitting it further.
- * @param stats Old structure for storing stats about information gain, prediction, etc.
- * This is legacy and will be modified in the future.
+ * @param stats Impurity statistics for this node.
*/
private[tree] class LearningNode(
var id: Int,
- var predictionStats: OldPredict,
- var impurity: Double,
var leftChild: Option[LearningNode],
var rightChild: Option[LearningNode],
var split: Option[Split],
var isLeaf: Boolean,
- var stats: Option[OldInformationGainStats]) extends Serializable {
+ var stats: ImpurityStats) extends Serializable {
/**
* Convert this [[LearningNode]] to a regular [[Node]], and recurse on any children.
*/
def toNode: Node = {
if (leftChild.nonEmpty) {
- assert(rightChild.nonEmpty && split.nonEmpty && stats.nonEmpty,
+ assert(rightChild.nonEmpty && split.nonEmpty && stats != null,
"Unknown error during Decision Tree learning. Could not convert LearningNode to Node.")
- new InternalNode(predictionStats.predict, impurity, stats.get.gain,
- leftChild.get.toNode, rightChild.get.toNode, split.get)
+ new InternalNode(stats.impurityCalculator.predict, stats.impurity, stats.gain,
+ leftChild.get.toNode, rightChild.get.toNode, split.get, stats.impurityCalculator)
} else {
- new LeafNode(predictionStats.predict, impurity)
+ if (stats.valid) {
+ new LeafNode(stats.impurityCalculator.predict, stats.impurity,
+ stats.impurityCalculator)
+ } else {
+        // Here we want to keep the same behavior as the old mllib.DecisionTreeModel
+ new LeafNode(stats.impurityCalculator.predict, -1.0, stats.impurityCalculator)
+ }
+
}
}
@@ -263,16 +273,14 @@ private[tree] object LearningNode {
/** Create a node with some of its fields set. */
def apply(
id: Int,
- predictionStats: OldPredict,
- impurity: Double,
- isLeaf: Boolean): LearningNode = {
- new LearningNode(id, predictionStats, impurity, None, None, None, false, None)
+ isLeaf: Boolean,
+ stats: ImpurityStats): LearningNode = {
+ new LearningNode(id, None, None, None, false, stats)
}
/** Create an empty node with the given node index. Values must be set later on. */
def emptyNode(nodeIndex: Int): LearningNode = {
- new LearningNode(nodeIndex, new OldPredict(Double.NaN, Double.NaN), Double.NaN,
- None, None, None, false, None)
+ new LearningNode(nodeIndex, None, None, None, false, null)
}
// The below indexing methods were copied from spark.mllib.tree.model.Node
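The key change in Node.scala is that prediction now returns the matched LeafNode rather than a bare Double, so callers can read the prediction, raw counts, or probabilities off the leaf's impurity stats. A simplified model of the recursion, using hypothetical case classes with continuous threshold splits only:

    sealed trait SimpleNode
    case class SimpleLeaf(prediction: Double, classCounts: Array[Double]) extends SimpleNode
    case class SimpleInternal(featureIndex: Int, threshold: Double,
        left: SimpleNode, right: SimpleNode) extends SimpleNode

    // Descend until a leaf is reached, mirroring InternalNode.predictImpl.
    def findLeaf(node: SimpleNode, features: Array[Double]): SimpleLeaf = node match {
      case leaf: SimpleLeaf => leaf
      case SimpleInternal(f, t, l, r) =>
        if (features(f) <= t) findLeaf(l, features) else findLeaf(r, features)
    }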
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 15b56bd844..a8b90d9d26 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -31,7 +31,7 @@ import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, Strategy => O
import org.apache.spark.mllib.tree.impl.{BaggedPoint, DTStatsAggregator, DecisionTreeMetadata,
TimeTracker}
import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
-import org.apache.spark.mllib.tree.model.{InformationGainStats, Predict}
+import org.apache.spark.mllib.tree.model.ImpurityStats
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel
import org.apache.spark.util.random.{SamplingUtils, XORShiftRandom}
@@ -180,13 +180,17 @@ private[ml] object RandomForest extends Logging {
parentUID match {
case Some(uid) =>
if (strategy.algo == OldAlgo.Classification) {
- topNodes.map(rootNode => new DecisionTreeClassificationModel(uid, rootNode.toNode))
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(uid, rootNode.toNode, strategy.getNumClasses)
+ }
} else {
topNodes.map(rootNode => new DecisionTreeRegressionModel(uid, rootNode.toNode))
}
case None =>
if (strategy.algo == OldAlgo.Classification) {
- topNodes.map(rootNode => new DecisionTreeClassificationModel(rootNode.toNode))
+ topNodes.map { rootNode =>
+ new DecisionTreeClassificationModel(rootNode.toNode, strategy.getNumClasses)
+ }
} else {
topNodes.map(rootNode => new DecisionTreeRegressionModel(rootNode.toNode))
}
@@ -549,9 +553,9 @@ private[ml] object RandomForest extends Logging {
}
// find best split for each node
- val (split: Split, stats: InformationGainStats, predict: Predict) =
+ val (split: Split, stats: ImpurityStats) =
binsToBestSplit(aggStats, splits, featuresForNode, nodes(nodeIndex))
- (nodeIndex, (split, stats, predict))
+ (nodeIndex, (split, stats))
}.collectAsMap()
timer.stop("chooseSplits")
@@ -568,17 +572,15 @@ private[ml] object RandomForest extends Logging {
val nodeIndex = node.id
val nodeInfo = treeToNodeToIndexInfo(treeIndex)(nodeIndex)
val aggNodeIndex = nodeInfo.nodeIndexInGroup
- val (split: Split, stats: InformationGainStats, predict: Predict) =
+ val (split: Split, stats: ImpurityStats) =
nodeToBestSplits(aggNodeIndex)
logDebug("best split = " + split)
// Extract info for this node. Create children if not leaf.
val isLeaf =
(stats.gain <= 0) || (LearningNode.indexToLevel(nodeIndex) == metadata.maxDepth)
- node.predictionStats = predict
node.isLeaf = isLeaf
- node.stats = Some(stats)
- node.impurity = stats.impurity
+ node.stats = stats
logDebug("Node = " + node)
if (!isLeaf) {
@@ -587,9 +589,9 @@ private[ml] object RandomForest extends Logging {
val leftChildIsLeaf = childIsLeaf || (stats.leftImpurity == 0.0)
val rightChildIsLeaf = childIsLeaf || (stats.rightImpurity == 0.0)
node.leftChild = Some(LearningNode(LearningNode.leftChildIndex(nodeIndex),
- stats.leftPredict, stats.leftImpurity, leftChildIsLeaf))
+ leftChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.leftImpurityCalculator)))
node.rightChild = Some(LearningNode(LearningNode.rightChildIndex(nodeIndex),
- stats.rightPredict, stats.rightImpurity, rightChildIsLeaf))
+ rightChildIsLeaf, ImpurityStats.getEmptyImpurityStats(stats.rightImpurityCalculator)))
if (nodeIdCache.nonEmpty) {
val nodeIndexUpdater = NodeIndexUpdater(
@@ -621,28 +623,44 @@ private[ml] object RandomForest extends Logging {
}
/**
- * Calculate the information gain for a given (feature, split) based upon left/right aggregates.
+ * Calculate the impurity statistics for a given (feature, split) based upon left/right aggregates.
+ * @param stats the recycled impurity statistics for all of this feature's splits;
+ *              only 'impurity' and 'impurityCalculator' are valid between iterations
+ * @param leftImpurityCalculator left node aggregates for this (feature, split)
+ * @param rightImpurityCalculator right node aggregates for this (feature, split)
- * @return information gain and statistics for split
+ * @param metadata learning and dataset metadata for DecisionTree
+ * @return Impurity statistics for this (feature, split)
*/
- private def calculateGainForSplit(
+ private def calculateImpurityStats(
+ stats: ImpurityStats,
leftImpurityCalculator: ImpurityCalculator,
rightImpurityCalculator: ImpurityCalculator,
- metadata: DecisionTreeMetadata,
- impurity: Double): InformationGainStats = {
+ metadata: DecisionTreeMetadata): ImpurityStats = {
+
+ val parentImpurityCalculator: ImpurityCalculator = if (stats == null) {
+ leftImpurityCalculator.copy.add(rightImpurityCalculator)
+ } else {
+ stats.impurityCalculator
+ }
+
+ val impurity: Double = if (stats == null) {
+ parentImpurityCalculator.calculate()
+ } else {
+ stats.impurity
+ }
+
val leftCount = leftImpurityCalculator.count
val rightCount = rightImpurityCalculator.count
+ val totalCount = leftCount + rightCount
+
// If left child or right child doesn't satisfy minimum instances per node,
// then this split is invalid, return invalid information gain stats.
if ((leftCount < metadata.minInstancesPerNode) ||
(rightCount < metadata.minInstancesPerNode)) {
- return InformationGainStats.invalidInformationGainStats
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
- val totalCount = leftCount + rightCount
-
val leftImpurity = leftImpurityCalculator.calculate() // Note: This equals 0 if count = 0
val rightImpurity = rightImpurityCalculator.calculate()
@@ -654,39 +672,11 @@ private[ml] object RandomForest extends Logging {
// if information gain doesn't satisfy minimum information gain,
// then this split is invalid, return invalid information gain stats.
if (gain < metadata.minInfoGain) {
- return InformationGainStats.invalidInformationGainStats
+ return ImpurityStats.getInvalidImpurityStats(parentImpurityCalculator)
}
- // calculate left and right predict
- val leftPredict = calculatePredict(leftImpurityCalculator)
- val rightPredict = calculatePredict(rightImpurityCalculator)
-
- new InformationGainStats(gain, impurity, leftImpurity, rightImpurity,
- leftPredict, rightPredict)
- }
-
- private def calculatePredict(impurityCalculator: ImpurityCalculator): Predict = {
- val predict = impurityCalculator.predict
- val prob = impurityCalculator.prob(predict)
- new Predict(predict, prob)
- }
-
- /**
- * Calculate predict value for current node, given stats of any split.
- * Note that this function is called only once for each node.
- * @param leftImpurityCalculator left node aggregates for a split
- * @param rightImpurityCalculator right node aggregates for a split
- * @return predict value and impurity for current node
- */
- private def calculatePredictImpurity(
- leftImpurityCalculator: ImpurityCalculator,
- rightImpurityCalculator: ImpurityCalculator): (Predict, Double) = {
- val parentNodeAgg = leftImpurityCalculator.copy
- parentNodeAgg.add(rightImpurityCalculator)
- val predict = calculatePredict(parentNodeAgg)
- val impurity = parentNodeAgg.calculate()
-
- (predict, impurity)
+ new ImpurityStats(gain, impurity, parentImpurityCalculator,
+ leftImpurityCalculator, rightImpurityCalculator)
}
/**
@@ -698,14 +688,14 @@ private[ml] object RandomForest extends Logging {
binAggregates: DTStatsAggregator,
splits: Array[Array[Split]],
featuresForNode: Option[Array[Int]],
- node: LearningNode): (Split, InformationGainStats, Predict) = {
+ node: LearningNode): (Split, ImpurityStats) = {
- // Calculate prediction and impurity if current node is top node
+ // Calculate InformationGain and ImpurityStats if the current node is the top node
val level = LearningNode.indexToLevel(node.id)
- var predictionAndImpurity: Option[(Predict, Double)] = if (level == 0) {
- None
+ var gainAndImpurityStats: ImpurityStats = if (level == 0) {
+ null
} else {
- Some((node.predictionStats, node.impurity))
+ node.stats
}
// For each (feature, split), calculate the gain, and select the best (feature, split).
@@ -734,11 +724,9 @@ private[ml] object RandomForest extends Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, numSplits)
rightChildStats.subtract(leftChildStats)
- predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
- (splitIdx, gainStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIdx, gainAndImpurityStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else if (binAggregates.metadata.isUnordered(featureIndex)) {
@@ -750,11 +738,9 @@ private[ml] object RandomForest extends Logging {
val leftChildStats = binAggregates.getImpurityCalculator(leftChildOffset, splitIndex)
val rightChildStats =
binAggregates.getImpurityCalculator(rightChildOffset, splitIndex)
- predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
- (splitIndex, gainStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
}.maxBy(_._2.gain)
(splits(featureIndex)(bestFeatureSplitIndex), bestFeatureGainStats)
} else {
@@ -825,11 +811,9 @@ private[ml] object RandomForest extends Logging {
val rightChildStats =
binAggregates.getImpurityCalculator(nodeFeatureOffset, lastCategory)
rightChildStats.subtract(leftChildStats)
- predictionAndImpurity = Some(predictionAndImpurity.getOrElse(
- calculatePredictImpurity(leftChildStats, rightChildStats)))
- val gainStats = calculateGainForSplit(leftChildStats,
- rightChildStats, binAggregates.metadata, predictionAndImpurity.get._2)
- (splitIndex, gainStats)
+ gainAndImpurityStats = calculateImpurityStats(gainAndImpurityStats,
+ leftChildStats, rightChildStats, binAggregates.metadata)
+ (splitIndex, gainAndImpurityStats)
}.maxBy(_._2.gain)
val categoriesForSplit =
categoriesSortedByCentroid.map(_._1.toDouble).slice(0, bestFeatureSplitIndex + 1)
@@ -839,7 +823,7 @@ private[ml] object RandomForest extends Logging {
}
}.maxBy(_._2.gain)
- (bestSplit, bestSplitStats, predictionAndImpurity.get._1)
+ (bestSplit, bestSplitStats)
}
/**
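calculateImpurityStats computes information gain as the parent impurity minus the count-weighted child impurities, reusing the parent's ImpurityCalculator across all of a feature's splits. A worked sketch of the gain formula with Gini impurity (hypothetical helpers, not the ImpurityCalculator API):

    // gain = impurity(parent) - (nL/n) * impurity(left) - (nR/n) * impurity(right)
    def gini(counts: Array[Double]): Double = {
      val total = counts.sum
      if (total == 0.0) 0.0 else 1.0 - counts.map(c => (c / total) * (c / total)).sum
    }

    def infoGain(left: Array[Double], right: Array[Double]): Double = {
      val parent = left.zip(right).map { case (l, r) => l + r }
      val n = parent.sum
      gini(parent) - (left.sum / n) * gini(left) - (right.sum / n) * gini(right)
    }

    // A perfect split of a 50/50 parent recovers the full parent impurity:
    // infoGain(Array(8.0, 0.0), Array(0.0, 8.0)) yields 0.5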
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
index 5ac10f3fd3..0768204c33 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Entropy.scala
@@ -118,7 +118,7 @@ private[tree] class EntropyAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[tree] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class EntropyCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
index 19d318203c..d0077db683 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Gini.scala
@@ -114,7 +114,7 @@ private[tree] class GiniAggregator(numClasses: Int)
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[tree] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class GiniCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
/**
* Make a deep copy of this [[ImpurityCalculator]].
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
index 578749d85a..86cee7e430 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Impurity.scala
@@ -95,7 +95,7 @@ private[spark] abstract class ImpurityAggregator(val statsSize: Int) extends Ser
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) {
+private[spark] abstract class ImpurityCalculator(val stats: Array[Double]) extends Serializable {
/**
* Make a deep copy of this [[ImpurityCalculator]].
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
index 7104a7fa4d..04d0cd24e6 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impurity/Variance.scala
@@ -98,7 +98,7 @@ private[tree] class VarianceAggregator()
* (node, feature, bin).
* @param stats Array of sufficient statistics for a (node, feature, bin).
*/
-private[tree] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
+private[spark] class VarianceCalculator(stats: Array[Double]) extends ImpurityCalculator(stats) {
require(stats.size == 3,
s"VarianceCalculator requires sufficient statistics array stats to be of length 3," +
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
index dc9e0f9f51..508bf9c1bd 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/model/InformationGainStats.scala
@@ -18,6 +18,7 @@
package org.apache.spark.mllib.tree.model
import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.mllib.tree.impurity.ImpurityCalculator
/**
* :: DeveloperApi ::
@@ -66,7 +67,6 @@ class InformationGainStats(
}
}
-
private[spark] object InformationGainStats {
/**
* An [[org.apache.spark.mllib.tree.model.InformationGainStats]] object to
@@ -76,3 +76,62 @@ private[spark] object InformationGainStats {
val invalidInformationGainStats = new InformationGainStats(Double.MinValue, -1.0, -1.0, -1.0,
new Predict(0.0, 0.0), new Predict(0.0, 0.0))
}
+
+/**
+ * :: DeveloperApi ::
+ * Impurity statistics for each split
+ * @param gain information gain value
+ * @param impurity current node impurity
+ * @param impurityCalculator impurity statistics for current node
+ * @param leftImpurityCalculator impurity statistics for left child node
+ * @param rightImpurityCalculator impurity statistics for right child node
+ * @param valid whether the current split satisfies the minimum info gain and
+ * minimum number of instances per node requirements
+ */
+@DeveloperApi
+private[spark] class ImpurityStats(
+ val gain: Double,
+ val impurity: Double,
+ val impurityCalculator: ImpurityCalculator,
+ val leftImpurityCalculator: ImpurityCalculator,
+ val rightImpurityCalculator: ImpurityCalculator,
+ val valid: Boolean = true) extends Serializable {
+
+ override def toString: String = {
+ s"gain = $gain, impurity = $impurity, left impurity = $leftImpurity, " +
+ s"right impurity = $rightImpurity"
+ }
+
+ def leftImpurity: Double = if (leftImpurityCalculator != null) {
+ leftImpurityCalculator.calculate()
+ } else {
+ -1.0
+ }
+
+ def rightImpurity: Double = if (rightImpurityCalculator != null) {
+ rightImpurityCalculator.calculate()
+ } else {
+ -1.0
+ }
+}
+
+private[spark] object ImpurityStats {
+
+ /**
+ * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object to
+ * denote that the current split doesn't satisfy the minimum info gain or
+ * minimum number of instances per node.
+ */
+ def getInvalidImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
+ new ImpurityStats(Double.MinValue, impurityCalculator.calculate(),
+ impurityCalculator, null, null, false)
+ }
+
+ /**
+ * Return an [[org.apache.spark.mllib.tree.model.ImpurityStats]] object
+ * in which only 'impurity' and 'impurityCalculator' are defined.
+ */
+ def getEmptyImpurityStats(impurityCalculator: ImpurityCalculator): ImpurityStats = {
+ new ImpurityStats(Double.NaN, impurityCalculator.calculate(), impurityCalculator, null, null)
+ }
+}
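The class probability that ImpurityStats makes available at the leaves (used above in toOld via impurityStats.prob(prediction)) is simply the class count over the total count. A hypothetical helper mirroring ImpurityCalculator.prob for classification stats:

    // stats(label) holds the training-instance count for that class at the node.
    def classProb(stats: Array[Double], label: Int): Double = {
      val total = stats.sum
      if (total == 0.0) 0.0 else stats(label) / total
    }

    // classProb(Array(2.0, 6.0), 1) yields 0.75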
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 73b4805c4c..c7bbf1ce07 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -21,12 +21,13 @@ import org.apache.spark.SparkFunSuite
import org.apache.spark.ml.impl.TreeTests
import org.apache.spark.ml.param.ParamsSuite
import org.apache.spark.ml.tree.LeafNode
-import org.apache.spark.mllib.linalg.Vectors
+import org.apache.spark.mllib.linalg.{Vector, Vectors}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree => OldDecisionTree, DecisionTreeSuite => OldDecisionTreeSuite}
import org.apache.spark.mllib.util.MLlibTestSparkContext
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.DataFrame
+import org.apache.spark.sql.Row
class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
@@ -57,7 +58,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new DecisionTreeClassifier)
- val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))
+ val model = new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)
ParamsSuite.checkParams(model)
}
@@ -231,6 +232,31 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
compareAPIs(rdd, dt, categoricalFeatures = Map.empty[Int, Int], numClasses)
}
+ test("predictRaw and predictProbability") {
+ val rdd = continuousDataPointsForMulticlassRDD
+ val dt = new DecisionTreeClassifier()
+ .setImpurity("Gini")
+ .setMaxDepth(4)
+ .setMaxBins(100)
+ val categoricalFeatures = Map(0 -> 3)
+ val numClasses = 3
+
+ val newData: DataFrame = TreeTests.setMetadata(rdd, categoricalFeatures, numClasses)
+ val newTree = dt.fit(newData)
+
+ val predictions = newTree.transform(newData)
+ .select(newTree.getPredictionCol, newTree.getRawPredictionCol, newTree.getProbabilityCol)
+ .collect()
+
+ predictions.foreach { case Row(pred: Double, rawPred: Vector, probPred: Vector) =>
+ assert(pred === rawPred.argmax,
+ s"Expected prediction $pred but calculated ${rawPred.argmax} from rawPrediction.")
+ val sum = rawPred.toArray.sum
+ assert(Vectors.dense(rawPred.toArray.map(_ / sum)) === probPred,
+ "probability prediction mismatch")
+ }
+ }
+
/////////////////////////////////////////////////////////////////////////////
// Tests of model save/load
/////////////////////////////////////////////////////////////////////////////
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
index a7bc77965f..d4b5896c12 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/GBTClassifierSuite.scala
@@ -58,7 +58,7 @@ class GBTClassifierSuite extends SparkFunSuite with MLlibTestSparkContext {
test("params") {
ParamsSuite.checkParams(new GBTClassifier)
val model = new GBTClassificationModel("gbtc",
- Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0))),
+ Array(new DecisionTreeRegressionModel("dtr", new LeafNode(0.0, 0.0, null))),
Array(1.0))
ParamsSuite.checkParams(model)
}
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
index ab711c8e4b..dbb2577c62 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/RandomForestClassifierSuite.scala
@@ -66,7 +66,7 @@ class RandomForestClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
test("params") {
ParamsSuite.checkParams(new RandomForestClassifier)
val model = new RandomForestClassificationModel("rfc",
- Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0))), 2)
+ Array(new DecisionTreeClassificationModel("dtc", new LeafNode(0.0, 0.0, null), 2)), 2)
ParamsSuite.checkParams(model)
}