aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorYanbo Liang <ybliang8@gmail.com>2015-09-23 16:41:42 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-09-23 16:41:42 -0700
commit067afb4e9bb227f159bcbc2aafafce9693303ea9 (patch)
tree25c1677d2f34624350d37e971bbba94a81cbac06
parentce2b056d35c0c75d5c162b93680ee2d84152e911 (diff)
downloadspark-067afb4e9bb227f159bcbc2aafafce9693303ea9.tar.gz
spark-067afb4e9bb227f159bcbc2aafafce9693303ea9.tar.bz2
spark-067afb4e9bb227f159bcbc2aafafce9693303ea9.zip
[SPARK-10699] [ML] Support disabling checkpointInterval
Currently users can set ```checkpointInterval``` to specify how often the cache should be checkpointed, but we also need a way for users to disable checkpointing entirely. This PR lets users disable checkpointing by setting ```checkpointInterval = -1```. We also add documentation for the GBT ```cacheNodeIds``` param so that users can understand checkpointing more clearly. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8820 from yanboliang/spark-10699.
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala1
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala6
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala2
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala4
6 files changed, 9 insertions, 10 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index a6f6d463bf..b0157f7ce2 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -19,7 +19,6 @@ package org.apache.spark.ml.classification
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.ParamMap
-import org.apache.spark.ml.param.shared.HasCheckpointInterval
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 8049d51fee..8cb6b5493c 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -56,9 +56,9 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
- ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that " +
- "the cache will get checkpointed every 10 iterations.",
- isValid = "ParamValidators.gtEq(1)"),
+ ParamDesc[Int]("checkpointInterval", "set checkpoint interval (>= 1) or " +
+ "disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed " +
+ "every 10 iterations", isValid = "(interval: Int) => interval == -1 || interval >= 1"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
"will filter out rows with bad values), or error (which will throw an errror). More " +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index aff47fc326..e3625212e5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -223,10 +223,10 @@ private[ml] trait HasOutputCol extends Params {
private[ml] trait HasCheckpointInterval extends Params {
/**
- * Param for checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations..
+ * Param for set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations.
* @group param
*/
- final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1))
+ final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "set checkpoint interval (>= 1) or disable checkpoint (-1). E.g. 10 means that the cache will get checkpointed every 10 iterations", (interval: Int) => interval == -1 || interval >= 1)
/** @group getParam */
final def getCheckpointInterval: Int = $(checkpointInterval)
diff --git a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
index 7db8ad8d27..9a56a75b69 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/recommendation/ALS.scala
@@ -561,7 +561,7 @@ object ALS extends Logging {
var itemFactors = initialize(itemInBlocks, rank, seedGen.nextLong())
var previousCheckpointFile: Option[String] = None
val shouldCheckpoint: Int => Boolean = (iter) =>
- sc.checkpointDir.isDefined && (iter % checkpointInterval == 0)
+ sc.checkpointDir.isDefined && checkpointInterval != -1 && (iter % checkpointInterval == 0)
val deletePreviousCheckpointFile: () => Unit = () =>
previousCheckpointFile.foreach { file =>
try {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
index c5ad8df73f..1ee01131d6 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/NodeIdCache.scala
@@ -122,7 +122,7 @@ private[spark] class NodeIdCache(
rddUpdateCount += 1
// Handle checkpointing if the directory is not None.
- if (canCheckpoint && (rddUpdateCount % checkpointInterval) == 0) {
+ if (canCheckpoint && checkpointInterval != -1 && (rddUpdateCount % checkpointInterval) == 0) {
// Let's see if we can delete previous checkpoints.
var canDelete = true
while (checkpointQueue.size > 1 && canDelete) {
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 42e74ce6d2..281ba6eeff 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -17,7 +17,6 @@
package org.apache.spark.ml.tree
-import org.apache.spark.ml.classification.ClassifierParams
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param._
import org.apache.spark.ml.param.shared._
@@ -87,7 +86,8 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI
/**
* If false, the algorithm will pass trees to executors to match instances with nodes.
* If true, the algorithm will cache node IDs for each instance.
- * Caching can speed up training of deeper trees.
+ * Caching can speed up training of deeper trees. Users can set how often should the
+ * cache be checkpointed or disable it by setting checkpointInterval.
* (default = false)
* @group expertParam
*/