Diffstat (limited to 'mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala  195
1 file changed, 0 insertions, 195 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
deleted file mode 100644
index dc7e969f7b..0000000000
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ /dev/null
@@ -1,195 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark.mllib.tree.impl
-
-import scala.collection.mutable
-
-import org.apache.hadoop.fs.{FileSystem, Path}
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.mllib.tree.configuration.FeatureType._
-import org.apache.spark.mllib.tree.model.{Bin, Node, Split}
-import org.apache.spark.rdd.RDD
-import org.apache.spark.storage.StorageLevel
-
-/**
- * :: DeveloperApi ::
- * This is used by the node id cache to find the child id that a data point would belong to.
- * @param split Split information.
- * @param nodeIndex The current node index of a data point that this will update.
- */
-@DeveloperApi
-private[tree] case class NodeIndexUpdater(
-    split: Split,
-    nodeIndex: Int) {
-  /**
-   * Determine a child node index based on the feature value and the split.
-   * @param binnedFeatures Binned feature values.
-   * @param bins Bin information to convert the bin indices to approximate feature values.
-   * @return Child node index to update to.
-   */
-  def updateNodeIndex(binnedFeatures: Array[Int], bins: Array[Array[Bin]]): Int = {
-    if (split.featureType == Continuous) {
-      val featureIndex = split.feature
-      val binIndex = binnedFeatures(featureIndex)
-      val featureValueUpperBound = bins(featureIndex)(binIndex).highSplit.threshold
-      if (featureValueUpperBound <= split.threshold) {
-        Node.leftChildIndex(nodeIndex)
-      } else {
-        Node.rightChildIndex(nodeIndex)
-      }
-    } else {
-      if (split.categories.contains(binnedFeatures(split.feature).toDouble)) {
-        Node.leftChildIndex(nodeIndex)
-      } else {
-        Node.rightChildIndex(nodeIndex)
-      }
-    }
-  }
-}
-
-/**
- * :: DeveloperApi ::
- * A given TreePoint would belong to a particular node per tree.
- * Each row in the nodeIdsForInstances RDD is an array over trees of the node index
- * in each tree. Initially, all values should be 1 (the root node).
- * The nodeIdsForInstances RDD needs to be updated at each iteration.
- * @param nodeIdsForInstances The initial values in the cache
-
 *                            (should be an Array of all 1's (meaning the root nodes)).
- * @param checkpointInterval The checkpointing interval
- *                           (how often the cache should be checkpointed).
- */
-@DeveloperApi
-private[spark] class NodeIdCache(
-    var nodeIdsForInstances: RDD[Array[Int]],
-    val checkpointInterval: Int) {
-
-  // Keep a reference to the previous node Ids for instances.
-  // Because we will keep on re-persisting updated node Ids,
-  // we want to unpersist the previous RDD.
-  private var prevNodeIdsForInstances: RDD[Array[Int]] = null
-
-  // To keep track of the past checkpointed RDDs.
-  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
-  private var rddUpdateCount = 0
-
-  /**
-   * Update the node index values in the cache.
-   * This updates the RDD and its lineage.
-   * TODO: Passing bin information to executors seems unnecessary and costly.
-   * @param data The RDD of training rows.
-   * @param nodeIdUpdaters A map of node index updaters.
-   *                       The keys are the indices of the nodes that we want to update.
-   * @param bins Bin information needed to find child node indices.
-   */
-  def updateNodeIndices(
-      data: RDD[BaggedPoint[TreePoint]],
-      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
-      bins: Array[Array[Bin]]): Unit = {
-    if (prevNodeIdsForInstances != null) {
-      // Unpersist the previous one if one exists.
-      prevNodeIdsForInstances.unpersist()
-    }
-
-    prevNodeIdsForInstances = nodeIdsForInstances
-    nodeIdsForInstances = data.zip(nodeIdsForInstances).map {
-      case (point, node) => {
-        var treeId = 0
-        while (treeId < nodeIdUpdaters.length) {
-          val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(node(treeId), null)
-          if (nodeIdUpdater != null) {
-            val newNodeIndex = nodeIdUpdater.updateNodeIndex(
-              binnedFeatures = point.datum.binnedFeatures,
-              bins = bins)
-            node(treeId) = newNodeIndex
-          }
-
-          treeId += 1
-        }
-
-        node
-      }
-    }
-
-    // Keep on persisting new ones.
-    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
-    rddUpdateCount += 1
-
-    // Handle checkpointing if the directory is not None.
-    if (nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty &&
-        (rddUpdateCount % checkpointInterval) == 0) {
-      // Let's see if we can delete previous checkpoints.
-      var canDelete = true
-      while (checkpointQueue.size > 1 && canDelete) {
-        // We can delete the oldest checkpoint iff
-        // the next checkpoint actually exists in the file system.
-        if (checkpointQueue.get(1).get.getCheckpointFile.isDefined) {
-          val old = checkpointQueue.dequeue()
-
-          // Since the old checkpoint is not deleted by Spark,
-          // we'll manually delete it here.
-          val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
-          fs.delete(new Path(old.getCheckpointFile.get), true)
-        } else {
-          canDelete = false
-        }
-      }
-
-      nodeIdsForInstances.checkpoint()
-      checkpointQueue.enqueue(nodeIdsForInstances)
-    }
-  }
-
-  /**
-   * Call this after training is finished to delete any remaining checkpoints.
-   */
-  def deleteAllCheckpoints(): Unit = {
-    while (checkpointQueue.nonEmpty) {
-      val old = checkpointQueue.dequeue()
-      for (checkpointFile <- old.getCheckpointFile) {
-        val fs = FileSystem.get(old.sparkContext.hadoopConfiguration)
-        fs.delete(new Path(checkpointFile), true)
-      }
-    }
-    if (prevNodeIdsForInstances != null) {
-      // Unpersist the previous one if one exists.
-      prevNodeIdsForInstances.unpersist()
-    }
-  }
-}
-
-private[spark] object NodeIdCache {
-  /**
-   * Initialize the node Id cache with initial node Id values.
-   * @param data The RDD of training rows.
-   * @param numTrees The number of trees that we want to create the cache for.
-   * @param checkpointInterval The checkpointing interval
-   *                           (how often the cache should be checkpointed).
-   * @param initVal The initial values in the cache.
-   * @return A node Id cache containing an RDD of initial root node indices.
-   */
-  def init(
-      data: RDD[BaggedPoint[TreePoint]],
-      numTrees: Int,
-      checkpointInterval: Int,
-      initVal: Int = 1): NodeIdCache = {
-    new NodeIdCache(
-      data.map(_ => Array.fill[Int](numTrees)(initVal)),
-      checkpointInterval)
-  }
-}
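
For orientation, the removed class was the driver-side cache of per-instance node indices used during tree ensemble training with node Id caching enabled. Below is a minimal usage sketch reconstructed only from the signatures visible in the deleted file above; baggedInput, bins, numTrees, numIterations, and trainWithNodeIdCache are placeholder names, the updater maps are left unfilled, and since NodeIdCache is private[spark] (and NodeIndexUpdater private[tree]), code like this could only live inside Spark's own source tree.

// A minimal usage sketch, assuming baggedInput and bins were already built by
// the (pre-removal) mllib.tree internals; not a verbatim excerpt of the old
// DecisionTree training loop.
import scala.collection.mutable

import org.apache.spark.mllib.tree.impl.{BaggedPoint, NodeIdCache, NodeIndexUpdater, TreePoint}
import org.apache.spark.mllib.tree.model.Bin
import org.apache.spark.rdd.RDD

def trainWithNodeIdCache(
    baggedInput: RDD[BaggedPoint[TreePoint]],  // assumed: bagged, binned training data
    bins: Array[Array[Bin]],                   // assumed: bin metadata per feature
    numTrees: Int,
    numIterations: Int): Unit = {
  // Every instance starts at node index 1 (the root) in every tree.
  val nodeIdCache = NodeIdCache.init(
    data = baggedInput,
    numTrees = numTrees,
    checkpointInterval = 10)

  var iter = 0
  while (iter < numIterations) {
    // One map per tree: node index -> updater carrying the split chosen for that
    // node in this iteration. Filling these maps is the split-selection step,
    // which is outside the scope of this sketch, e.g.
    //   nodeIdUpdaters(treeId)(nodeIndex) = NodeIndexUpdater(split, nodeIndex)
    val nodeIdUpdaters = Array.fill(numTrees)(mutable.Map[Int, NodeIndexUpdater]())

    // Move every instance from its current node to the matching child node.
    nodeIdCache.updateNodeIndices(baggedInput, nodeIdUpdaters, bins)
    iter += 1
  }

  // Remove any checkpoint files that are still on the file system.
  nodeIdCache.deleteAllCheckpoints()
}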