cgit navigation: about / summary / refs / log / tree / commit / diff
diff options: context, space, mode
-rw-r--r--  examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala |  3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala             |  1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala   | 10
-rw-r--r--  mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala         | 12
4 files changed, 6 insertions(+), 20 deletions(-)
diff --git a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
index 205d80dd02..262fd2c961 100644
--- a/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
+++ b/examples/src/main/scala/org/apache/spark/examples/mllib/DecisionTreeRunner.scala
@@ -272,6 +272,8 @@ object DecisionTreeRunner {
case Variance => impurity.Variance
}
+ params.checkpointDir.foreach(sc.setCheckpointDir)
+
val strategy
= new Strategy(
algo = params.algo,
@@ -282,7 +284,6 @@ object DecisionTreeRunner {
minInstancesPerNode = params.minInstancesPerNode,
minInfoGain = params.minInfoGain,
useNodeIdCache = params.useNodeIdCache,
- checkpointDir = params.checkpointDir,
checkpointInterval = params.checkpointInterval)
if (params.numTrees == 1) {
val startTime = System.nanoTime()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
index 45b0154c5e..db01f2e229 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/RandomForest.scala
@@ -204,7 +204,6 @@ private class RandomForest (
Some(NodeIdCache.init(
data = baggedInput,
numTrees = numTrees,
- checkpointDir = strategy.checkpointDir,
checkpointInterval = strategy.checkpointInterval,
initVal = 1))
} else {
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
index 3308adb675..8d5c36da32 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/configuration/Strategy.scala
@@ -62,11 +62,10 @@ import org.apache.spark.mllib.tree.configuration.QuantileStrategy._
* @param subsamplingRate Fraction of the training data used for learning decision tree.
* @param useNodeIdCache If this is true, instead of passing trees to executors, the algorithm will
* maintain a separate RDD of node Id cache for each row.
- * @param checkpointDir If the node Id cache is used, it will help to checkpoint
- * the node Id cache periodically. This is the checkpoint directory
- * to be used for the node Id cache.
* @param checkpointInterval How often to checkpoint when the node Id cache gets updated.
- * E.g. 10 means that the cache will get checkpointed every 10 updates.
+ * E.g. 10 means that the cache will get checkpointed every 10 updates. If
+ * the checkpoint directory is not set in
+ * [[org.apache.spark.SparkContext]], this setting is ignored.
*/
@Experimental
class Strategy (
@@ -82,7 +81,6 @@ class Strategy (
@BeanProperty var maxMemoryInMB: Int = 256,
@BeanProperty var subsamplingRate: Double = 1,
@BeanProperty var useNodeIdCache: Boolean = false,
- @BeanProperty var checkpointDir: Option[String] = None,
@BeanProperty var checkpointInterval: Int = 10) extends Serializable {
def isMulticlassClassification =
@@ -165,7 +163,7 @@ class Strategy (
def copy: Strategy = {
new Strategy(algo, impurity, maxDepth, numClasses, maxBins,
quantileCalculationStrategy, categoricalFeaturesInfo, minInstancesPerNode, minInfoGain,
- maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointDir, checkpointInterval)
+ maxMemoryInMB, subsamplingRate, useNodeIdCache, checkpointInterval)
}
}
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
index 83011b48b7..bdd0f576b0 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/tree/impl/NodeIdCache.scala
@@ -71,15 +71,12 @@ private[tree] case class NodeIndexUpdater(
* The nodeIdsForInstances RDD needs to be updated at each iteration.
* @param nodeIdsForInstances The initial values in the cache
* (should be an Array of all 1's (meaning the root nodes)).
- * @param checkpointDir The checkpoint directory where
- * the checkpointed files will be stored.
* @param checkpointInterval The checkpointing interval
* (how often should the cache be checkpointed.).
*/
@DeveloperApi
private[tree] class NodeIdCache(
var nodeIdsForInstances: RDD[Array[Int]],
- val checkpointDir: Option[String],
val checkpointInterval: Int) {
// Keep a reference to a previous node Ids for instances.
@@ -91,12 +88,6 @@ private[tree] class NodeIdCache(
private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
private var rddUpdateCount = 0
- // If a checkpoint directory is given, and there's no prior checkpoint directory,
- // then set the checkpoint directory with the given one.
- if (checkpointDir.nonEmpty && nodeIdsForInstances.sparkContext.getCheckpointDir.isEmpty) {
- nodeIdsForInstances.sparkContext.setCheckpointDir(checkpointDir.get)
- }
-
/**
* Update the node index values in the cache.
* This updates the RDD and its lineage.
@@ -184,7 +175,6 @@ private[tree] object NodeIdCache {
* Initialize the node Id cache with initial node Id values.
* @param data The RDD of training rows.
* @param numTrees The number of trees that we want to create cache for.
- * @param checkpointDir The checkpoint directory where the checkpointed files will be stored.
* @param checkpointInterval The checkpointing interval
* (how often should the cache be checkpointed.).
* @param initVal The initial values in the cache.
@@ -193,12 +183,10 @@ private[tree] object NodeIdCache {
def init(
data: RDD[BaggedPoint[TreePoint]],
numTrees: Int,
- checkpointDir: Option[String],
checkpointInterval: Int,
initVal: Int = 1): NodeIdCache = {
new NodeIdCache(
data.map(_ => Array.fill[Int](numTrees)(initVal)),
- checkpointDir,
checkpointInterval)
}
}