aboutsummaryrefslogtreecommitdiff
path: root/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
diff options
context:
space:
mode:
Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala37
1 files changed, 37 insertions, 0 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
index cd24931293..d89682611e 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/Node.scala
@@ -279,6 +279,43 @@ private[tree] class LearningNode(
}
}
+ /**
+ * Get the node index corresponding to this data point.
+ * This function mimics prediction, passing an example from the root node down to a leaf
+ * or unsplit node; that node's index is returned.
+ *
+ * @param binnedFeatures Binned feature vector for data point.
+ * @param splits possible splits for all features, indexed (numFeatures)(numSplits)
+ * @return Leaf index if the data point reaches a leaf.
+ * Otherwise, last node reachable in tree matching this example.
+ * Note: This is the global node index, i.e., the index used in the tree.
+ * This index is different from the index used during training a particular
+ * group of nodes on one call to [[findBestSplits()]].
+ */
+ def predictImpl(binnedFeatures: Array[Int], splits: Array[Array[Split]]): Int = {
+ if (this.isLeaf || this.split.isEmpty) {
+ this.id
+ } else {
+ val split = this.split.get
+ val featureIndex = split.featureIndex
+ val splitLeft = split.shouldGoLeft(binnedFeatures(featureIndex), splits(featureIndex))
+ if (this.leftChild.isEmpty) {
+ // Not yet split. Return next layer of nodes to train
+ if (splitLeft) {
+ LearningNode.leftChildIndex(this.id)
+ } else {
+ LearningNode.rightChildIndex(this.id)
+ }
+ } else {
+ if (splitLeft) {
+ this.leftChild.get.predictImpl(binnedFeatures, splits)
+ } else {
+ this.rightChild.get.predictImpl(binnedFeatures, splits)
+ }
+ }
+ }
+ }
+
}
private[tree] object LearningNode {