From 067afb4e9bb227f159bcbc2aafafce9693303ea9 Mon Sep 17 00:00:00 2001 From: Yanbo Liang Date: Wed, 23 Sep 2015 16:41:42 -0700 Subject: [SPARK-10699] [ML] Support checkpointInterval can be disabled Currently users can set ```checkpointInterval``` to specify how often the cache should be check-pointed. But we also need the ability for users to disable it. This PR lets users disable checkpointing by setting ```checkpointInterval = -1```. We also add documents for GBT ```cacheNodeIds``` to help users understand checkpointing more clearly. Author: Yanbo Liang Closes #8820 from yanboliang/spark-10699. --- .../org/apache/spark/ml/classification/DecisionTreeClassifier.scala | 1 - .../org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala | 6 +++--- .../main/scala/org/apache/spark/ml/param/shared/sharedParams.scala | 4 ++-- mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala | 2 +- .../src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala | 2 +- mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala | 4 ++-- 6 files changed, 9 insertions(+), 10 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index a6f6d463bf..b0157f7ce2 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -19,7 +19,6 @@ package org.apache.spark.ml.classification import org.apache.spark.annotation.Experimental import org.apache.spark.ml.param.ParamMap -import org.apache.spark.ml.param.shared.HasCheckpointInterval import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams} import org.apache.spark.ml.tree.impl.RandomForest import org.apache.spark.ml.util.{Identifiable, MetadataUtils} diff --git 
a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala index 8049d51fee..8cb6b5493c 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala @@ -56,9 +56,9 @@ private[shared] object SharedParamsCodeGen { ParamDesc[String]("inputCol", "input column name"), ParamDesc[Array[String]]("inputCols", "input column names"), ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")), - ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that " + - "the cache will get checkpointed every 10 iterations.", - isValid = "ParamValidators.gtEq(1)"), + ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " + + "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " + + "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"), ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")), ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " + "will filter out rows with bad values), or error (which will throw an errror). More " + diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala index aff47fc326..e3625212e5 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala @@ -223,10 +223,10 @@ private[ml] trait HasOutputCol extends Params { private[ml] trait HasCheckpointInterval extends Params { /** - * Param for checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.. 
+ * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations. * @group param */ - final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1)) + final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations", (interval: Int) => interval == -1 || interval >= 1) /** @group getParam */ final def getCheckpointInterval: Int = $(checkpointInterval) diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala index 7db8ad8d27..9a56a75b69 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala @@ -561,7 +561,7 @@ object ALS extends Logging { var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong()) var previousCheckpointFile: Option[String] = None val shouldCheckpoint: Int => Boolean = (iter) => - sc.checkpointDir.isDefined && (iter % checkpointInterval == 0) + sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0) val deletePreviousCheckpointFile: () => Unit = () => previousCheckpointFile.foreach { file => try { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala index c5ad8df73f..1ee01131d6 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala @@ -122,7 +122,7 @@ private[spark] class NodeIdCache( rddUpdateCount += 1 // Handle checkpointing if the directory is not None. 
- if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) { + if (canCheckpoint && checkpointInterval != -1 && (rddUpdateCount % checkpointInterval) == 0) { // Let's see if we can delete previous checkpoints. var canDelete = true while (checkpointQueue.size > 1 && canDelete) { diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 42e74ce6d2..281ba6eeff 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -17,7 +17,6 @@ package org.apache.spark.ml.tree -import org.apache.spark.ml.classification.ClassifierParams import org.apache.spark.ml.PredictorParams import org.apache.spark.ml.param._ import org.apache.spark.ml.param.shared._ @@ -87,7 +86,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI /** * If false, the algorithm will pass trees to executors to match instances with nodes. * If true, the algorithm will cache node IDs for each instance. - * Caching can speed up training of deeper trees. + * Caching can speed up training of deeper trees. Users can set how often should the + * cache be checkpointed or disable it by setting checkpointInterval. * (default = false) * @group expertParam */ -- cgit v1.2.3