aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
authorYu ISHIKAWA <yuu.ishikawa@gmail.com>2015-11-01 23:52:50 -0800
committerDB Tsai <dbt@netflix.com>2015-11-01 23:52:50 -0800
commite963070c13f56fbc2dfaf9f5d4e69d34afd0957c (patch)
tree53c8ab17fc6c2a6d3cc3e737df06a50490011759 /mllib
parent3e770a64a48c271c5829d2bcbdc1d6430cda2ac9 (diff)
downloadspark-e963070c13f56fbc2dfaf9f5d4e69d34afd0957c.tar.gz
spark-e963070c13f56fbc2dfaf9f5d4e69d34afd0957c.tar.bz2
spark-e963070c13f56fbc2dfaf9f5d4e69d34afd0957c.zip
[SPARK-9722] [ML] Pass random seed to spark.ml DecisionTree*
Author: Yu ISHIKAWA <yuu.ishikawa@gmail.com> Closes #9402 from yu-iskw/SPARK-9722.
Diffstat (limited to 'mllib')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala8
1 files changed, 5 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
index 96d5652857..4a3b12d144 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tree/impl/RandomForest.scala
@@ -74,7 +74,7 @@ private[ml] object RandomForest extends Logging {
// Find the splits and the corresponding bins (interval between the splits) using a sample
// of the input data.
timer.start("findSplitsBins")
- val splits = findSplits(retaggedInput, metadata)
+ val splits = findSplits(retaggedInput, metadata, seed)
timer.stop("findSplitsBins")
logDebug("numBins: feature: number of bins")
logDebug(Range(0, metadata.numFeatures).map { featureIndex =>
@@ -815,6 +815,7 @@ private[ml] object RandomForest extends Logging {
*
* @param input Training data: RDD of [[org.apache.spark.mllib.regression.LabeledPoint]]
* @param metadata Learning and dataset metadata
+ * @param seed random seed
* @return A tuple of (splits, bins).
* Splits is an Array of [[org.apache.spark.mllib.tree.model.Split]]
* of size (numFeatures, numSplits).
@@ -823,7 +824,8 @@ private[ml] object RandomForest extends Logging {
*/
protected[tree] def findSplits(
input: RDD[LabeledPoint],
- metadata: DecisionTreeMetadata): Array[Array[Split]] = {
+ metadata: DecisionTreeMetadata,
+ seed : Long): Array[Array[Split]] = {
logDebug("isMulticlass = " + metadata.isMulticlass)
@@ -840,7 +842,7 @@ private[ml] object RandomForest extends Logging {
1.0
}
logDebug("fraction of data used for calculating quantiles = " + fraction)
- input.sample(withReplacement = false, fraction, new XORShiftRandom(1).nextInt()).collect()
+ input.sample(withReplacement = false, fraction, new XORShiftRandom(seed).nextInt()).collect()
} else {
new Array[LabeledPoint](0)
}