From 3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb Mon Sep 17 00:00:00 2001 From: Martin Menestret Date: Wed, 16 Dec 2015 14:05:35 -0800 Subject: [SPARK-9690][ML][PYTHON] pyspark CrossValidator random seed Extend CrossValidator with HasSeed in PySpark. This PR replaces [https://github.com/apache/spark/pull/7997] CC: yanboliang thunterdb mmenestret Would one of you mind taking a look? Thanks! Author: Joseph K. Bradley Author: Martin MENESTRET Closes #10268 from jkbradley/pyspark-cv-seed. --- python/pyspark/ml/tuning.py | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 705ee53685..08f8db57f4 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -19,8 +19,9 @@ import itertools import numpy as np from pyspark import since -from pyspark.ml.param import Params, Param from pyspark.ml import Estimator, Model +from pyspark.ml.param import Params, Param +from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only from pyspark.sql.functions import rand @@ -89,7 +90,7 @@ class ParamGridBuilder(object): return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)] -class CrossValidator(Estimator): +class CrossValidator(Estimator, HasSeed): """ K-fold cross validation. @@ -129,9 +130,11 @@ class CrossValidator(Estimator): numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") @keyword_only - def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3) + __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None) """ super(CrossValidator, self).__init__() #: param for estimator to be cross-validated @@ -151,9 +154,11 @@ class CrossValidator(Estimator): @keyword_only @since("1.4.0") - def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, + seed=None): """ - setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3): + setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\ + seed=None): Sets params for cross validator. """ kwargs = self.setParams._input_kwargs @@ -225,9 +230,10 @@ class CrossValidator(Estimator): numModels = len(epm) eva = self.getOrDefault(self.evaluator) nFolds = self.getOrDefault(self.numFolds) + seed = self.getOrDefault(self.seed) h = 1.0 / nFolds randCol = self.uid + "_rand" - df = dataset.select("*", rand(0).alias(randCol)) + df = dataset.select("*", rand(seed).alias(randCol)) metrics = np.zeros(numModels) for i in range(nFolds): validateLB = i * h -- cgit v1.2.3