aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYu ISHIKAWA <yuu.ishikawa@gmail.com>2015-11-05 17:59:01 -0800
committerJoseph K. Bradley <joseph@databricks.com>2015-11-05 17:59:01 -0800
commit8fa8c8375d7015a0332aa9ee613d7c6b6d62bae7 (patch)
tree8cdefd291fe5dc3a555c0424d2bc402c91a980a8 /mllib
parent6091e91fca58078a0f1d9c35d68c0ae7205a534c (diff)
downloadspark-8fa8c8375d7015a0332aa9ee613d7c6b6d62bae7.tar.gz
spark-8fa8c8375d7015a0332aa9ee613d7c6b6d62bae7.tar.bz2
spark-8fa8c8375d7015a0332aa9ee613d7c6b6d62bae7.zip
[SPARK-11514][ML] Pass random seed to spark.ml DecisionTree*
cc jkbradley Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com> Closes #9486 from yu-iskw/SPARK-11514.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala4
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala11
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala1
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala1
5 files changed, 14 insertions, 7 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
index b0157f7ce2..c478aea44a 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala
@@ -62,6 +62,8 @@ final class DecisionTreeClassifier(override val uid: String)
override def setImpurity(value: String): this.type = super.setImpurity(value)
+ override def setSeed(value: Long): this.type = super.setSeed(value)
+
override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
@@ -75,7 +77,7 @@ final class DecisionTreeClassifier(override val uid: String)
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures, numClasses)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
- seed = 0L, parentUID = Some(uid))
+ seed = $(seed), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeClassificationModel]
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
index 04420fc6e8..477030d9ea 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala
@@ -71,13 +71,15 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val
@Since("1.4.0")
override def setImpurity(value: String): this.type = super.setImpurity(value)
+ override def setSeed(value: Long): this.type = super.setSeed(value)
+
override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = {
val categoricalFeatures: Map[Int, Int] =
MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol)))
val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset)
val strategy = getOldStrategy(categoricalFeatures)
val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all",
- seed = 0L, parentUID = Some(uid))
+ seed = $(seed), parentUID = Some(uid))
trees.head.asInstanceOf[DecisionTreeRegressionModel]
}
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
index 281ba6eeff..1da97db927 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala
@@ -29,7 +29,8 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss}
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval {
+private[ml] trait DecisionTreeParams extends PredictorParams
+ with HasCheckpointInterval with HasSeed {
/**
* Maximum depth of the tree (>= 0).
@@ -123,6 +124,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI
/** @group getParam */
final def getMinInfoGain: Double = $(minInfoGain)
+ /** @group setParam */
+ def setSeed(value: Long): this.type = set(seed, value)
+
/** @group expertSetParam */
def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value)
@@ -257,7 +261,7 @@ private[ml] object TreeRegressorParams {
*
* Note: Marked as private and DeveloperApi since this may be made public in the future.
*/
-private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
+private[ml] trait TreeEnsembleParams extends DecisionTreeParams {
/**
* Fraction of the training data used for learning each decision tree, in range (0, 1].
@@ -276,9 +280,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed {
/** @group getParam */
final def getSubsamplingRate: Double = $(subsamplingRate)
- /** @group setParam */
- def setSeed(value: Long): this.type = set(seed, value)
-
/**
* Create a Strategy instance to use with the old API.
* NOTE: The caller should set impurity and seed.
diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
index 815f6fd997..92b8f84144 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala
@@ -72,6 +72,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte
.setImpurity("gini")
.setMaxDepth(2)
.setMaxBins(100)
+ .setSeed(1)
val categoricalFeatures = Map(0 -> 3, 1-> 3)
val numClasses = 2
compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses)
diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
index 868fb8eecb..e0d5afa7a7 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala
@@ -49,6 +49,7 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex
.setImpurity("variance")
.setMaxDepth(2)
.setMaxBins(100)
+ .setSeed(1)
val categoricalFeatures = Map(0 -> 3, 1-> 3)
compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures)
}