diff options
author | Martin Menestret <martinmenestret@gmail.com> | 2015-12-16 14:05:35 -0800 |
---|---|---|
committer | Joseph K. Bradley <joseph@databricks.com> | 2015-12-16 14:05:35 -0800 |
commit | 3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb (patch) | |
tree | 0f9720cbacaa63b6b4c9195073dd62e1d05b30a6 /python/pyspark/ml | |
parent | 9657ee87888422c5596987fe760b49117a0ea4e2 (diff) | |
download | spark-3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb.tar.gz spark-3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb.tar.bz2 spark-3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb.zip |
[SPARK-9690][ML][PYTHON] pyspark CrossValidator random seed
Extend CrossValidator with HasSeed in PySpark.
This PR replaces [https://github.com/apache/spark/pull/7997]
CC: yanboliang thunterdb mmenestret Would one of you mind taking a look? Thanks!
Author: Joseph K. Bradley <joseph@databricks.com>
Author: Martin MENESTRET <mmenestret@ippon.fr>
Closes #10268 from jkbradley/pyspark-cv-seed.
Diffstat (limited to 'python/pyspark/ml')
-rw-r--r-- | python/pyspark/ml/tuning.py | 20 |
1 files changed, 13 insertions, 7 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 705ee53685..08f8db57f4 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -19,8 +19,9 @@ import itertools import numpy as np from pyspark import since -from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model +from pyspark.ml.param import Params, Param +from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only from pyspark.sql.functions import rand @@ -89,7 +90,7 @@ class ParamGridBuilder(object): return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] -class CrossValidator(Estimator): +class CrossValidator(Estimator, HasSeed): """ K-fold cross validation. @@ -129,9 +130,11 @@ class CrossValidator(Estimator): numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") @keyword_only - def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3) + __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None) """ super(CrossValidator, self).__init__() #: param for estimator to be cross-validated @@ -151,9 +154,11 @@ class CrossValidator(Estimator): @keyword_only @since("1.4.0") - def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None): Sets params for cross validator. """ kwargs = self.setParams._input_kwargs @@ -225,9 +230,10 @@ class CrossValidator(Estimator): numModels = len(epm) eva = self.getOrDefault(self.evaluator) nFolds = self.getOrDefault(self.numFolds) + seed = self.getOrDefault(self.seed) h = 1.0 / nFolds randCol = self.uid + "_rand" - df = dataset.select("*", rand(0).alias(randCol)) + df = dataset.select("*", rand(seed).alias(randCol)) metrics = np.zeros(numModels) for i in range(nFolds): validateLB = i * h |