aboutsummaryrefslogtreecommitdiff
path: root/mllib/src
diff options
context:
space:
mode:
authorXusen Yin <yinxusen@gmail.com>2016-03-30 14:32:29 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-30 14:32:29 -0700
commit529d6ce8f96ef2b4a57c2d9066c7d80466e36209 (patch)
treefcb5ade9f0b80cf02959448b39070cac5685c2a2 /mllib/src
parentbdabfd43f6e4900b48010dd00ffa48ed5fd15997 (diff)
downloadspark-529d6ce8f96ef2b4a57c2d9066c7d80466e36209.tar.gz
spark-529d6ce8f96ef2b4a57c2d9066c7d80466e36209.tar.bz2
spark-529d6ce8f96ef2b4a57c2d9066c7d80466e36209.zip
[SPARK-14181] TrainValidationSplit should have HasSeed
https://issues.apache.org/jira/browse/SPARK-14181 TrainValidationSplit should have HasSeed for the random split of RDD. I also changed the random split from the RDD function to the DataFrame function. Author: Xusen Yin <yinxusen@gmail.com> Closes #11985 from yinxusen/SPARK-14181.
Diffstat (limited to 'mllib/src')
-rw-r--r--mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala15
-rw-r--r--mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala4
2 files changed, 14 insertions, 5 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
index 4d1d6364d7..07330bb6b0 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/TrainValidationSplit.scala
@@ -25,6 +25,7 @@ import org.apache.spark.internal.Logging
import org.apache.spark.ml.{Estimator, Model}
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.param.{DoubleParam, ParamMap, ParamValidators}
+import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -32,7 +33,7 @@ import org.apache.spark.sql.types.StructType
/**
* Params for [[TrainValidationSplit]] and [[TrainValidationSplitModel]].
*/
-private[ml] trait TrainValidationSplitParams extends ValidatorParams {
+private[ml] trait TrainValidationSplitParams extends ValidatorParams with HasSeed {
/**
* Param for ratio between train and validation data. Must be between 0 and 1.
* Default: 0.75
@@ -80,6 +81,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
@Since("1.5.0")
def setTrainRatio(value: Double): this.type = set(trainRatio, value)
+ /** @group setParam */
+ @Since("2.0.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
@Since("1.5.0")
override def fit(dataset: DataFrame): TrainValidationSplitModel = {
val schema = dataset.schema
@@ -91,10 +96,10 @@ class TrainValidationSplit @Since("1.5.0") (@Since("1.5.0") override val uid: St
val numModels = epm.length
val metrics = new Array[Double](epm.length)
- val Array(training, validation) =
- dataset.rdd.randomSplit(Array($(trainRatio), 1 - $(trainRatio)))
- val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
- val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
+ val Array(trainingDataset, validationDataset) =
+ dataset.randomSplit(Array($(trainRatio), 1 - $(trainRatio)), $(seed))
+ trainingDataset.cache()
+ validationDataset.cache()
// multi-model training
logDebug(s"Train split with multiple sets of parameters.")
diff --git a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
index 7cf7b3e087..4030956fab 100644
--- a/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
+++ b/mllib/src/test/scala/org/apache/spark/ml/tuning/TrainValidationSplitSuite.scala
@@ -48,6 +48,7 @@ class TrainValidationSplitSuite
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
+ .setSeed(42L)
val cvModel = cv.fit(dataset)
val parent = cvModel.bestModel.parent.asInstanceOf[LogisticRegression]
assert(cv.getTrainRatio === 0.5)
@@ -72,6 +73,7 @@ class TrainValidationSplitSuite
.setEstimatorParamMaps(lrParamMaps)
.setEvaluator(eval)
.setTrainRatio(0.5)
+ .setSeed(42L)
val cvModel = cv.fit(dataset)
val parent = cvModel.bestModel.parent.asInstanceOf[LinearRegression]
assert(parent.getRegParam === 0.001)
@@ -120,6 +122,7 @@ class TrainValidationSplitSuite
.setEvaluator(evaluator)
.setTrainRatio(0.5)
.setEstimatorParamMaps(paramMaps)
+ .setSeed(42L)
val tvs2 = testDefaultReadWrite(tvs, testParams = false)
@@ -140,6 +143,7 @@ class TrainValidationSplitSuite
.set(tvs.evaluator, evaluator)
.set(tvs.trainRatio, 0.5)
.set(tvs.estimatorParamMaps, paramMaps)
+ .set(tvs.seed, 42L)
val tvs2 = testDefaultReadWrite(tvs, testParams = false)