author     Luvsandondov Lkhamsuren <lkhamsurenl@gmail.com>    2015-10-17 10:07:42 -0700
committer  Joseph K. Bradley <joseph@databricks.com>          2015-10-17 10:07:42 -0700
commit     cca2258685147be6c950c9f5c4e50eaa1e090714 (patch)
tree       6f3d9bf93fd18e26d43170c564f02f8dff50b9bf /mllib
parent     e1e77b22b3b577909a12c3aa898eb53be02267fd (diff)
[SPARK-9963] [ML] RandomForest cleanup: replace predictNodeIndex with predictImpl
predictNodeIndex is moved to LearningNode and renamed predictImpl for consistency with Node.predictImpl.

Author: Luvsandondov Lkhamsuren <lkhamsurenl@gmail.com>

Closes #8609 from lkhamsurenl/SPARK-9963.
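At the call site in findBestSplits, the change amounts to replacing the static helper with a method call on the tree's root LearningNode. A minimal sketch of the before/after shape, taken from the RandomForest.scala hunk below (not standalone code):

    // Before: free helper in RandomForest.scala
    val nodeIndex =
      predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)

    // After: method on LearningNode, parallel to Node.predictImpl
    val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)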
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala              | 37
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 44
2 files changed, 38 insertions(+), 43 deletions(-)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index cd24931293..d89682611e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -279,6 +279,43 @@ private[tree] class LearningNode(
}
}
+ /**
+ * Get the node index corresponding to this data point.
+ * This function mimics prediction, passing an example from the root node down to a leaf
+ * or unsplit node; that node's index is returned.
+ *
+ * @param binnedFeatures Binned feature vector for data point.
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @return Leaf index if the data point reaches a leaf.
+ * Otherwise, last node reachable in tree matching this example.
+ * Note: This is the global node index, i.e., the index used in the tree.
+ * This index is different from the index used during training a particular
+ * group of nodes on one call to [[findBestSplits()]].
+ */
+ def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
+ if (this.isLeaf || this.split.isEmpty) {
+ this.id
+ } else {
+ val split = this.split.get
+ val featureIndex = split.featureIndex
+ val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
+ if (this.leftChild.isEmpty) {
+ // Not yet split. Return index from next layer of nodes to train
+ if (splitLeft) {
+ LearningNode.leftChildIndex(this.id)
+ } else {
+ LearningNode.rightChildIndex(this.id)
+ }
+ } else {
+ if (splitLeft) {
+ this.leftChild.get.predictImpl(binnedFeatures, splits)
+ } else {
+ this.rightChild.get.predictImpl(binnedFeatures, splits)
+ }
+ }
+ }
+ }
+
}
private[tree] object LearningNode {
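For readers skimming the hunk above, the following is a self-contained sketch of the descent that predictImpl performs. It does not use Spark's private[tree] classes: ToySplit and ToyNode are hypothetical stand-ins, shouldGoLeft is reduced to a continuous (threshold-bin) test, and leftChildIndex/rightChildIndex assume the usual 1-based binary-heap numbering (left = 2*id, right = 2*id + 1).

    // Hypothetical simplified types; not Spark's Split/LearningNode.
    object PredictImplSketch {

      // A continuous split on one feature: go left when the binned value is at or
      // below the threshold bin. (The real Split.shouldGoLeft also handles
      // categorical features; that is omitted here.)
      final case class ToySplit(featureIndex: Int, thresholdBin: Int) {
        def shouldGoLeft(binnedFeature: Int): Boolean = binnedFeature <= thresholdBin
      }

      final case class ToyNode(
          id: Int,
          isLeaf: Boolean,
          split: Option[ToySplit],
          leftChild: Option[ToyNode],
          rightChild: Option[ToyNode])

      // Assumed 1-based binary-heap node numbering.
      def leftChildIndex(id: Int): Int = 2 * id
      def rightChildIndex(id: Int): Int = 2 * id + 1

      // Mirrors the shape of LearningNode.predictImpl: descend until a leaf or an
      // unsplit node, returning that node's id or the id its child would receive.
      def predictImpl(node: ToyNode, binnedFeatures: Array[Int]): Int = {
        if (node.isLeaf || node.split.isEmpty) {
          node.id
        } else {
          val split = node.split.get
          val goLeft = split.shouldGoLeft(binnedFeatures(split.featureIndex))
          if (node.leftChild.isEmpty) {
            // Not yet split: return the index the next layer of training will use.
            if (goLeft) leftChildIndex(node.id) else rightChildIndex(node.id)
          } else {
            if (goLeft) predictImpl(node.leftChild.get, binnedFeatures)
            else predictImpl(node.rightChild.get, binnedFeatures)
          }
        }
      }

      def main(args: Array[String]): Unit = {
        // Root (id 1) splits on feature 0 at bin 2; left child (id 2) is a leaf,
        // right child (id 3) exists but has not been split yet.
        val left  = ToyNode(id = 2, isLeaf = true, split = None, leftChild = None, rightChild = None)
        val right = ToyNode(id = 3, isLeaf = false, split = Some(ToySplit(1, 5)),
          leftChild = None, rightChild = None)
        val root  = ToyNode(id = 1, isLeaf = false, split = Some(ToySplit(0, 2)),
          leftChild = Some(left), rightChild = Some(right))

        println(predictImpl(root, Array(1, 9)))  // feature 0 in bin 1 -> leaf id 2
        println(predictImpl(root, Array(4, 9)))  // goes right, node 3 unsplit -> child index 7
      }
    }

The second call lands on an unsplit node, so the returned value (7) is the index that node's right child would get in the next training layer, matching the "last node reachable" behaviour documented in the scaladoc above.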
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index c494556085..96d5652857 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -206,47 +206,6 @@ private[ml] object RandomForest extends Logging {
}
/**
- * Get the node index corresponding to this data point.
- * This function mimics prediction, passing an example from the root node down to a leaf
- * or unsplit node; that node's index is returned.
- *
- * @param node Node in tree from which to classify the given data point.
- * @param binnedFeatures Binned feature vector for data point.
- * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
- * @return Leaf index if the data point reaches a leaf.
- * Otherwise, last node reachable in tree matching this example.
- * Note: This is the global node index, i.e., the index used in the tree.
- * This index is different from the index used during training a particular
- * group of nodes on one call to [[findBestSplits()]].
- */
- private def predictNodeIndex(
- node: LearningNode,
- binnedFeatures: Array[Int],
- splits: Array[Array[Split]]): Int = {
- if (node.isLeaf || node.split.isEmpty) {
- node.id
- } else {
- val split = node.split.get
- val featureIndex = split.featureIndex
- val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
- if (node.leftChild.isEmpty) {
- // Not yet split. Return index from next layer of nodes to train
- if (splitLeft) {
- LearningNode.leftChildIndex(node.id)
- } else {
- LearningNode.rightChildIndex(node.id)
- }
- } else {
- if (splitLeft) {
- predictNodeIndex(node.leftChild.get, binnedFeatures, splits)
- } else {
- predictNodeIndex(node.rightChild.get, binnedFeatures, splits)
- }
- }
- }
- }
-
- /**
* Helper for binSeqOp, for data which can contain a mix of ordered and unordered features.
*
* For ordered features, a single bin is updated.
@@ -453,8 +412,7 @@ private[ml] object RandomForest extends Logging {
agg: Array[DTStatsAggregator],
baggedPoint: BaggedPoint[TreePoint]): Array[DTStatsAggregator] = {
treeToNodeToIndexInfo.foreach { case (treeIndex, nodeIndexToInfo) =>
- val nodeIndex =
- predictNodeIndex(topNodes(treeIndex), baggedPoint.datum.binnedFeatures, splits)
+ val nodeIndex = topNodes(treeIndex).predictImpl(baggedPoint.datum.binnedFeatures, splits)
nodeBinSeqOp(treeIndex, nodeIndexToInfo.getOrElse(nodeIndex, null), agg, baggedPoint)
}
agg