diff options
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r-- | python/pyspark/ml/tuning.py | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 705ee53685..08f8db57f4 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -19,8 +19,9 @@ import itertools import numpy as np from pyspark import since -from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model +from pyspark.ml.param import Params, Param +from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only from pyspark.sql.functions import rand @@ -89,7 +90,7 @@ class ParamGridBuilder(object): return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] -class CrossValidator(Estimator): +class CrossValidator(Estimator, HasSeed): """ K-fold cross validation. @@ -129,9 +130,11 @@ class CrossValidator(Estimator): numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") @keyword_only - def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3) + __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None) """ super(CrossValidator, self).__init__() #: param for estimator to be cross-validated @@ -151,9 +154,11 @@ class CrossValidator(Estimator): @keyword_only @since("1.4.0") - def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None): Sets params for cross validator. """ kwargs = self.setParams._input_kwargs @@ -225,9 +230,10 @@ class CrossValidator(Estimator): numModels = len(epm) eva = self.getOrDefault(self.evaluator) nFolds = self.getOrDefault(self.numFolds) + seed = self.getOrDefault(self.seed) h = 1.0 / nFolds randCol = self.uid + "_rand" - df = dataset.select("*", rand(0).alias(randCol)) + df = dataset.select("*", rand(seed).alias(randCol)) metrics = np.zeros(numModels) for i in range(nFolds): validateLB = i * h |