aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMartin Menestret <martinmenestret@gmail.com>2015-12-16 14:05:35 -0800
committerJoseph K. Bradley <joseph@databricks.com>2015-12-16 14:05:35 -0800
commit3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb (patch)
tree0f9720cbacaa63b6b4c9195073dd62e1d05b30a6
parent9657ee87888422c5596987fe760b49117a0ea4e2 (diff)
downloadspark-3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb.tar.gz
spark-3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb.tar.bz2
spark-3a44aebd0c5331f6ff00734fa44ef63f8d18cfbb.zip
[SPARK-9690][ML][PYTHON] pyspark CrossValidator random seed
Extend CrossValidator with HasSeed in PySpark. This PR replaces [https://github.com/apache/spark/pull/7997] CC: yanboliang thunterdb mmenestret Would one of you mind taking a look? Thanks! Author: Joseph K. Bradley <joseph@databricks.com> Author: Martin MENESTRET <mmenestret@ippon.fr> Closes #10268 from jkbradley/pyspark-cv-seed.
-rw-r--r--python/pyspark/ml/tuning.py20
1 files changed, 13 insertions, 7 deletions
diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py
index 705ee53685..08f8db57f4 100644
--- a/python/pyspark/ml/tuning.py
+++ b/python/pyspark/ml/tuning.py
@@ -19,8 +19,9 @@ import itertools
import numpy as np
from pyspark import since
-from pyspark.ml.param import Params, Param
from pyspark.ml import Estimator, Model
+from pyspark.ml.param import Params, Param
+from pyspark.ml.param.shared import HasSeed
from pyspark.ml.util import keyword_only
from pyspark.sql.functions import rand
@@ -89,7 +90,7 @@ class ParamGridBuilder(object):
return [dict(zip(keys, prod)) for prod in itertools.product(*grid_values)]
-class CrossValidator(Estimator):
+class CrossValidator(Estimator, HasSeed):
"""
K-fold cross validation.
@@ -129,9 +130,11 @@ class CrossValidator(Estimator):
numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation")
@keyword_only
- def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+ def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
+ seed=None):
"""
- __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3)
+ __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
+ seed=None)
"""
super(CrossValidator, self).__init__()
#: param for estimator to be cross-validated
@@ -151,9 +154,11 @@ class CrossValidator(Estimator):
@keyword_only
@since("1.4.0")
- def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+ def setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,
+ seed=None):
"""
- setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3):
+ setParams(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3,\
+ seed=None):
Sets params for cross validator.
"""
kwargs = self.setParams._input_kwargs
@@ -225,9 +230,10 @@ class CrossValidator(Estimator):
numModels = len(epm)
eva = self.getOrDefault(self.evaluator)
nFolds = self.getOrDefault(self.numFolds)
+ seed = self.getOrDefault(self.seed)
h = 1.0 / nFolds
randCol = self.uid + "_rand"
- df = dataset.select("*", rand(0).alias(randCol))
+ df = dataset.select("*", rand(seed).alias(randCol))
metrics = np.zeros(numModels)
for i in range(nFolds):
validateLB = i * h