author    Yanbo Liang <ybliang8@gmail.com>     2015-09-10 20:34:00 -0700
committer Xiangrui Meng <meng@databricks.com>  2015-09-10 20:34:00 -0700
commit    339a527141984bfb182862b0987d3c4690c9ede1 (patch)
tree      bc940f6f1eccdedfdbe614d2debd3ce60f50033a /mllib
parent    0eabea8a058ad60411c1384930ba12c1c638f5f1 (diff)
[SPARK-10023] [ML] [PySpark] Unified DecisionTreeParams checkpointInterval between Scala and Python API.
"checkpointInterval" is member of DecisionTreeParams in Scala API which is inconsistency with Python API, we should unified them. ``` member of DecisionTreeParams <-> Scala API shared param for all ML Transformer/Estimator <-> Python API ``` Proposal: "checkpointInterval" is also used by ALS, so we make it shared params at Scala. Author: Yanbo Liang <ybliang8@gmail.com> Closes #8528 from yanboliang/spark-10023.
Diffstat (limited to 'mllib')
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala   1
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala        3
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala               4
-rw-r--r--  mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala                        32
4 files changed, 16 insertions, 24 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index 6f70b96b17..0a75d5d222 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -19,6 +19,7 @@ package org.apache.spark.ml.classification
import org.apache.spark.annotation.Experimental
import org.apache.spark.ml.param.ParamMap
+import org.apache.spark.ml.param.shared.HasCheckpointInterval
import org.apache.spark.ml.tree.{DecisionTreeModel, DecisionTreeParams, Node, TreeClassifierParams}
import org.apache.spark.ml.tree.impl.RandomForest
import org.apache.spark.ml.util.{Identifiable, MetadataUtils}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
index 8c16c6149b..e9e99ed1db 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/SharedParamsCodeGen.scala
@@ -56,7 +56,8 @@ private[shared] object SharedParamsCodeGen {
ParamDesc[String]("inputCol", "input column name"),
ParamDesc[Array[String]]("inputCols", "input column names"),
ParamDesc[String]("outputCol", "output column name", Some("uid + \"__output\"")),
- ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1)",
+ ParamDesc[Int]("checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that " +
+ "the cache will get checkpointed every 10 iterations.",
isValid = "ParamValidators.gtEq(1)"),
ParamDesc[Boolean]("fitIntercept", "whether to fit an intercept term", Some("true")),
ParamDesc[String]("handleInvalid", "how to handle invalid entries. Options are skip (which " +
diff --git a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
index c26768953e..3009217086 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/param/shared/sharedParams.scala
@@ -223,10 +223,10 @@ private[ml] trait HasOutputCol extends Params {
private[ml] trait HasCheckpointInterval extends Params {
/**
- * Param for checkpoint interval (>= 1).
+ * Param for checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations..
* @group param
*/
- final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1)", ParamValidators.gtEq(1))
+ final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "checkpoint interval (>= 1). E.g. 10 means that the cache will get checkpointed every 10 iterations.", ParamValidators.gtEq(1))
/** @group getParam */
final def getCheckpointInterval: Int = $(checkpointInterval)
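As an aside on the `ParamValidators.gtEq(1)` bound used above: the validator is a plain `T => Boolean` that the param framework applies when a value is set. A minimal illustration, assuming only the public ParamValidators API:

```scala
import org.apache.spark.ml.param.ParamValidators

val atLeastOne: Int => Boolean = ParamValidators.gtEq[Int](1)
atLeastOne(10) // true: checkpointing every 10 iterations is valid
atLeastOne(0)  // false: set(checkpointInterval, 0) would be rejected
```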
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index dbd8d31571..d29f5253c9 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -20,7 +20,7 @@ package org.apache.spark.ml.tree
import org.apache.spark.ml.classification.ClassifierParams
import org.apache.spark.ml.PredictorParams
import org.apache.spark.ml.param._
-import org.apache.spark.ml.param.shared.{HasMaxIter, HasSeed, HasThresholds}
+import org.apache.spark.ml.param.shared.{HasCheckpointInterval, HasMaxIter, HasSeed, HasThresholds}
import org.apache.spark.mllib.tree.configuration.{Algo => OldAlgo, BoostingStrategy => OldBoostingStrategy, Strategy => OldStrategy}
import org.apache.spark.mllib.tree.impurity.{Entropy => OldEntropy, Gini => OldGini, Impurity => OldImpurity, Variance => OldVariance}
import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
@@ -30,7 +30,7 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-private[ml] trait DecisionTreeParams extends PredictorParams {
+private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval {
/**
* Maximum depth of the tree (>= 0).
@@ -96,21 +96,6 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
" algorithm will cache node IDs for each instance. Caching can speed up training of deeper" +
" trees.")
- /**
- * Specifies how often to checkpoint the cached node IDs.
- * E.g. 10 means that the cache will get checkpointed every 10 iterations.
- * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
- * [[org.apache.spark.SparkContext]].
- * Must be >= 1.
- * (default = 10)
- * @group expertParam
- */
- final val checkpointInterval: IntParam = new IntParam(this, "checkpointInterval", "Specifies" +
- " how often to checkpoint the cached node IDs. E.g. 10 means that the cache will get" +
- " checkpointed every 10 iterations. This is only used if cacheNodeIds is true and if the" +
- " checkpoint directory is set in the SparkContext. Must be >= 1.",
- ParamValidators.gtEq(1))
-
setDefault(maxDepth -> 5, maxBins -> 32, minInstancesPerNode -> 1, minInfoGain -> 0.0,
maxMemoryInMB -> 256, cacheNodeIds -> false, checkpointInterval -> 10)
@@ -150,12 +135,17 @@ private[ml] trait DecisionTreeParams extends PredictorParams {
/** @group expertGetParam */
final def getCacheNodeIds: Boolean = $(cacheNodeIds)
- /** @group expertSetParam */
+ /**
+ * Specifies how often to checkpoint the cached node IDs.
+ * E.g. 10 means that the cache will get checkpointed every 10 iterations.
+ * This is only used if cacheNodeIds is true and if the checkpoint directory is set in
+ * [[org.apache.spark.SparkContext]].
+ * Must be >= 1.
+ * (default = 10)
+ * @group expertSetParam
+ */
def setCheckpointInterval(value: Int): this.type = set(checkpointInterval, value)
- /** @group expertGetParam */
- final def getCheckpointInterval: Int = $(checkpointInterval)
-
/** (private[ml]) Create a Strategy instance to use with the old API. */
private[ml] def getOldStrategy(
categoricalFeatures: Map[Int, Int],
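With the param now shared, setting it looks the same across the tree learners and ALS. A usage sketch, assuming an existing SparkContext `sc`; the checkpoint directory path is illustrative:

```scala
import org.apache.spark.ml.classification.DecisionTreeClassifier

// Node-ID-cache checkpointing only takes effect when cacheNodeIds is true
// and a checkpoint directory has been set on the SparkContext.
sc.setCheckpointDir("/tmp/spark-checkpoints") // illustrative path

val dt = new DecisionTreeClassifier()
  .setCacheNodeIds(true)
  .setCheckpointInterval(10) // checkpoint the node-ID cache every 10 iterations
```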