diff options
author | Yu ISHIKAWA <yuu.ishikawa@gmail.com> | 2015-11-01 23:52:50 -0800 |
---|---|---|
committer | DB Tsai <dbt@netflix.com> | 2015-11-01 23:52:50 -0800 |
commit | e963070c13f56fbc2dfaf9f5d4e69d34afd0957c (patch) | |
tree | 53c8ab17fc6c2a6d3cc3e737df06a50490011759 /mllib | |
parent | 3e770a64a48c271c5829d2bcbdc1d6430cda2ac9 (diff) | |
download | spark-e963070c13f56fbc2dfaf9f5d4e69d34afd0957c.tar.gz spark-e963070c13f56fbc2dfaf9f5d4e69d34afd0957c.tar.bz2 spark-e963070c13f56fbc2dfaf9f5d4e69d34afd0957c.zip |
[SPARK-9722] [ML] Pass random seed to spark.ml DecisionTree*
Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com>
Closes #9402 from yu-iskw/SPARK-9722.
Diffstat (limited to 'mllib')
-rw-r--r-- | mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala | 8 |
1 file changed, 5 insertions, 3 deletions
```diff
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 96d5652857..4a3b12d144 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -74,7 +74,7 @@ private[ml] object RandomForest extends Logging {
     // Find the splits and the corresponding bins (interval between the splits) using a sample
     // of the input data.
     timer.start("findSplitsBins")
-    val splits = findSplits(retaggedInput, metadata)
+    val splits = findSplits(retaggedInput, metadata, seed)
     timer.stop("findSplitsBins")
     logDebug("numBins: feature: number of bins")
     logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
@@ -815,6 +815,7 @@ private[ml] object RandomForest extends Logging {
    *
    * @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
    * @param metadata Learning and dataset metadata
+   * @param seed random seed
    * @return A tuple of (splits, bins).
    *        Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
    *        of size (numFeatures, numSplits).
@@ -823,7 +824,8 @@ private[ml] object RandomForest extends Logging {
    */
   protected[tree] def findSplits(
       input: RDD[LabeledPoint],
-      metadata: DecisionTreeMetadata): Array[Array[Split]] = {
+      metadata: DecisionTreeMetadata,
+      seed : Long): Array[Array[Split]] = {
 
     logDebug("isMulticlass = " + metadata.isMulticlass)
 
@@ -840,7 +842,7 @@ private[ml] object RandomForest extends Logging {
         1.0
       }
       logDebug("fraction of data used for calculating quantiles = " + fraction)
-      input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
+      input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
     } else {
       new Array[LabeledPoint](0)
     }
```