diff options
author | Xusen Yin <yinxusen@gmail.com> | 2016-03-30 14:32:29 -0700 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2016-03-30 14:32:29 -0700 |
commit | 529d6ce8f96ef2b4a57c2d9066c7d80466e36209 (patch) | |
tree | fcb5ade9f0b80cf02959448b39070cac5685c2a2 /mllib/src/test/scala/org/apache | |
parent | bdabfd43f6e4900b48010dd00ffa48ed5fd15997 (diff) | |
download | spark-529d6ce8f96ef2b4a57c2d9066c7d80466e36209.tar.gz spark-529d6ce8f96ef2b4a57c2d9066c7d80466e36209.tar.bz2 spark-529d6ce8f96ef2b4a57c2d9066c7d80466e36209.zip |
[SPARK-14181] TrainValidationSplit should have HasSeed
https://issues.apache.org/jira/browse/SPARK-14181
TrainValidationSplit should have HasSeed for the random split of RDD. I also changed the random split from the RDD function to the DataFrame function.
Author: Xusen Yin <yinxusen@gmail.com>
Closes #11985 from yinxusen/SPARK-14181.
Diffstat (limited to 'mllib/src/test/scala/org/apache')
-rw-r--r-- | mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala | 4 |
1 files changed, 4 insertions, 0 deletions
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala index 7cf7b3e087..4030956fab 100644 --- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala +++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala @@ -48,6 +48,7 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setSeed(42L) val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression] assert(cv.getTrainRatio === 0.5) @@ -72,6 +73,7 @@ class TrainValidationSplitSuite .setEstimatorParamMaps(lrParamMaps) .setEvaluator(eval) .setTrainRatio(0.5) + .setSeed(42L) val cvModel = cv.fit(dataset) val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression] assert(parent.getRegParam === 0.001) @@ -120,6 +122,7 @@ class TrainValidationSplitSuite .setEvaluator(evaluator) .setTrainRatio(0.5) .setEstimatorParamMaps(paramMaps) + .setSeed(42L) val tvs2 = testDefaultReadWrite(tvs, testParams = false) @@ -140,6 +143,7 @@ class TrainValidationSplitSuite .set(tvs.evaluator, evaluator) .set(tvs.trainRatio, 0.5) .set(tvs.estimatorParamMaps, paramMaps) + .set(tvs.seed, 42L) val tvs2 = testDefaultReadWrite(tvs, testParams = false) |