aboutsummaryrefslogtreecommitdiff
path: root/mllib
diff options
context:
space:
mode:
author Yanbo Liang <ybliang8@gmail.com> 2015-12-16 11:05:37 -0800
committer Joseph K. Bradley <joseph@databricks.com> 2015-12-16 11:05:37 -0800
commit 860dc7f2f8dd01f2562ba83b7af27ba29d91cb62 (patch)
tree 7fb192d6b6c212dd6d9ad7bd2e8b6fecb9e2b676 /mllib
parent 7b6dc29d0ebbfb3bb941130f8542120b6bc3e234 (diff)
download spark-860dc7f2f8dd01f2562ba83b7af27ba29d91cb62.tar.gz
spark-860dc7f2f8dd01f2562ba83b7af27ba29d91cb62.tar.bz2
spark-860dc7f2f8dd01f2562ba83b7af27ba29d91cb62.zip
[SPARK-9694][ML] Add random seed Param to Scala CrossValidator
Add random seed Param to Scala CrossValidator
Author: Yanbo Liang <ybliang8@gmail.com>
Closes #9108 from yanboliang/spark-9694.
Diffstat (limited to 'mllib')
-rw-r--r-- mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala 11
-rw-r--r-- mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala 8
2 files changed, 16 insertions, 3 deletions
diff --git a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
index 5c09f1aaff..40f8857fc5 100644
--- a/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
+++ b/mllib/src/main/scala/org/apache/spark/ml/tuning/CrossValidator.scala
@@ -29,8 +29,9 @@ import org.apache.spark.ml.classification.OneVsRestParams
import org.apache.spark.ml.evaluation.Evaluator
import org.apache.spark.ml.feature.RFormulaModel
import org.apache.spark.ml.param._
-import org.apache.spark.ml.util.DefaultParamsReader.Metadata
+import org.apache.spark.ml.param.shared.HasSeed
import org.apache.spark.ml.util._
+import org.apache.spark.ml.util.DefaultParamsReader.Metadata
import org.apache.spark.mllib.util.MLUtils
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.types.StructType
@@ -39,7 +40,7 @@ import org.apache.spark.sql.types.StructType
/**
* Params for [[CrossValidator]] and [[CrossValidatorModel]].
*/
-private[ml] trait CrossValidatorParams extends ValidatorParams {
+private[ml] trait CrossValidatorParams extends ValidatorParams with HasSeed {
/**
* Param for number of folds for cross validation. Must be >= 2.
* Default: 3
@@ -85,6 +86,10 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
@Since("1.2.0")
def setNumFolds(value: Int): this.type = set(numFolds, value)
+ /** @group setParam */
+ @Since("2.0.0")
+ def setSeed(value: Long): this.type = set(seed, value)
+
@Since("1.4.0")
override def fit(dataset: DataFrame): CrossValidatorModel = {
val schema = dataset.schema
@@ -95,7 +100,7 @@ class CrossValidator @Since("1.2.0") (@Since("1.4.0") override val uid: String)
val epm = $(estimatorParamMaps)
val numModels = epm.length
val metrics = new Array[Double](epm.length)
- val splits = MLUtils.kFold(dataset.rdd, $(numFolds), 0)
+ val splits = MLUtils.kFold(dataset.rdd, $(numFolds), $(seed))
splits.zipWithIndex.foreach { case ((training, validation), splitIndex) =>
val trainingDataset = sqlCtx.createDataFrame(training, schema).cache()
val validationDataset = sqlCtx.createDataFrame(validation, schema).cache()
diff --git a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
index 414ea99cfd..4c9151f0cb 100644
--- a/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
+++ b/mllib/src/main/scala/org/apache/spark/mllib/util/MLUtils.scala
@@ -265,6 +265,14 @@ object MLUtils {
*/
@Since("1.0.0")
def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Int): Array[(RDD[T], RDD[T])] = {
+ kFold(rdd, numFolds, seed.toLong)
+ }
+
+ /**
+ * Version of [[kFold()]] taking a Long seed.
+ */
+ @Since("2.0.0")
+ def kFold[T: ClassTag](rdd: RDD[T], numFolds: Int, seed: Long): Array[(RDD[T], RDD[T])] = {
val numFoldsF = numFolds.toFloat
(1 to numFolds).map { fold =>
val sampler = new BernoulliCellSampler[T]((fold - 1) / numFoldsF, fold / numFoldsF,