From 8fa8c8375d7015a0332aa9ee613d7c6b6d62bae7 Mon Sep 17 00:00:00 2001 From: Yu ISHIKAWA Date: Thu, 5 Nov 2015 17:59:01 -0800 Subject: [SPARK-11514][ML] Pass random seed to spark.ml DecisionTree* cc jkbradley Author: Yu ISHIKAWA Closes #9486 from yu-iskw/SPARK-11514. --- .../spark/ml/classification/DecisionTreeClassifier.scala | 4 +++- .../apache/spark/ml/regression/DecisionTreeRegressor.scala | 4 +++- .../src/main/scala/org/apache/spark/ml/tree/treeParams.scala | 11 ++++++----- .../spark/ml/classification/DecisionTreeClassifierSuite.scala | 1 + .../spark/ml/regression/DecisionTreeRegressorSuite.scala | 1 + 5 files changed, 14 insertions(+), 7 deletions(-) diff --git a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala index b0157f7ce2..c478aea44a 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/classification/DecisionTreeClassifier.scala @@ -62,6 +62,8 @@ final class DecisionTreeClassifier(override val uid: String) override def setImpurity(value: String): this.type = super.setImpurity(value) + override def setSeed(value: Long): this.type = super.setSeed(value) + override protected def train(dataset: DataFrame): DecisionTreeClassificationModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) @@ -75,7 +77,7 @@ final class DecisionTreeClassifier(override val uid: String) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures, numClasses) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeClassificationModel] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala index 04420fc6e8..477030d9ea 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/regression/DecisionTreeRegressor.scala @@ -71,13 +71,15 @@ final class DecisionTreeRegressor @Since("1.4.0") (@Since("1.4.0") override val @Since("1.4.0") override def setImpurity(value: String): this.type = super.setImpurity(value) + override def setSeed(value: Long): this.type = super.setSeed(value) + override protected def train(dataset: DataFrame): DecisionTreeRegressionModel = { val categoricalFeatures: Map[Int, Int] = MetadataUtils.getCategoricalFeatures(dataset.schema($(featuresCol))) val oldDataset: RDD[LabeledPoint] = extractLabeledPoints(dataset) val strategy = getOldStrategy(categoricalFeatures) val trees = RandomForest.run(oldDataset, strategy, numTrees = 1, featureSubsetStrategy = "all", - seed = 0L, parentUID = Some(uid)) + seed = $(seed), parentUID = Some(uid)) trees.head.asInstanceOf[DecisionTreeRegressionModel] } diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala index 281ba6eeff..1da97db927 100644 --- a/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala +++ b/mllib/src/main/scala/org/apache/spark/ml/tree/treeParams.scala @@ -29,7 +29,8 @@ import org.apache.spark.mllib.tree.loss.{Loss => OldLoss} * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointInterval { +private[ml] trait DecisionTreeParams extends PredictorParams + with HasCheckpointInterval with HasSeed { /** * Maximum depth of the tree (>= 0). @@ -123,6 +124,9 @@ private[ml] trait DecisionTreeParams extends PredictorParams with HasCheckpointI /** @group getParam */ final def getMinInfoGain: Double = $(minInfoGain) + /** @group setParam */ + def setSeed(value: Long): this.type = set(seed, value) + /** @group expertSetParam */ def setMaxMemoryInMB(value: Int): this.type = set(maxMemoryInMB, value) @@ -257,7 +261,7 @@ private[ml] object TreeRegressorParams { * * Note: Marked as private and DeveloperApi since this may be made public in the future. */ -private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { +private[ml] trait TreeEnsembleParams extends DecisionTreeParams { /** * Fraction of the training data used for learning each decision tree, in range (0, 1]. @@ -276,9 +280,6 @@ private[ml] trait TreeEnsembleParams extends DecisionTreeParams with HasSeed { /** @group getParam */ final def getSubsamplingRate: Double = $(subsamplingRate) - /** @group setParam */ - def setSeed(value: Long): this.type = set(seed, value) - /** * Create a Strategy instance to use with the old API. * NOTE: The caller should set impurity and seed. diff --git a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala index 815f6fd997..92b8f84144 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/classification/DecisionTreeClassifierSuite.scala @@ -72,6 +72,7 @@ class DecisionTreeClassifierSuite extends SparkFunSuite with MLlibTestSparkConte .setImpurity("gini") .setMaxDepth(2) .setMaxBins(100) + .setSeed(1) val categoricalFeatures = Map(0 -> 3, 1-> 3) val numClasses = 2 compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures, numClasses) diff --git a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala index 868fb8eecb..e0d5afa7a7 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/regression/DecisionTreeRegressorSuite.scala @@ -49,6 +49,7 @@ class DecisionTreeRegressorSuite extends SparkFunSuite with MLlibTestSparkContex .setImpurity("variance") .setMaxDepth(2) .setMaxBins(100) + .setSeed(1) val categoricalFeatures = Map(0 -> 3, 1-> 3) compareAPIs(categoricalDataPointsRDD, dt, categoricalFeatures) } -- cgit v1.2.3