aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
authorHolden Karau <holden@pigscanfly.ca>2015-05-20 15:16:12 -0700
committerJoseph K. Bradley <joseph@databricks.com>2015-05-20 15:16:12 -0700
commit191ee474527530246ac3164ae9631e01bdd1e647 (patch)
tree24a8b2d7991f39f478b05f2cb3a7b14c539a9eb9 /python
parent6338c40da61de045485c51aa11a5b1e425d22144 (diff)
downloadspark-191ee474527530246ac3164ae9631e01bdd1e647.tar.gz
spark-191ee474527530246ac3164ae9631e01bdd1e647.tar.bz2
spark-191ee474527530246ac3164ae9631e01bdd1e647.zip
[SPARK-7511] [MLLIB] pyspark ml seed param should be random by default or 42 is quite funny but not very random
Author: Holden Karau <holden@pigscanfly.ca> Closes #6139 from holdenk/SPARK-7511-pyspark-ml-seed-param-should-be-random-by-default-or-42-is-quite-funny-but-not-very-random and squashes the following commits: 591f8e5 [Holden Karau] specify old seed for doc tests 2470004 [Holden Karau] Fix a bunch of seeds with default values to have None as the default which will then result in using the hash of the class name cbad96d [Holden Karau] Add the setParams function that is used in the real code 423b8d7 [Holden Karau] Switch the test code to behave slightly more like production code. also don't check the param map value only check for key existence 140d25d [Holden Karau] remove extra space 926165a [Holden Karau] Add some missing newlines for pep8 style 8616751 [Holden Karau] merge in master 58532e6 [Holden Karau] its the __name__ method, also treat None values as not set 56ef24a [Holden Karau] fix test and regenerate base afdaa5c [Holden Karau] make sure different classes have different results 68eb528 [Holden Karau] switch default seed to hash of type of self 89c4611 [Holden Karau] Merge branch 'master' into SPARK-7511-pyspark-ml-seed-param-should-be-random-by-default-or-42-is-quite-funny-but-not-very-random 31cd96f [Holden Karau] specify the seed to randomforestregressor test e1b947f [Holden Karau] Style fixes ce90ec8 [Holden Karau] merge in master bcdf3c9 [Holden Karau] update docstring seeds to none and some other default seeds from 42 65eba21 [Holden Karau] pep8 fixes 0e3797e [Holden Karau] Make seed default to random in more places 213a543 [Holden Karau] Simplify the generated code to only include set default if there is a default rather than having None is note None in the generated code 1ff17c2 [Holden Karau] Make the seed random for HasSeed in python
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/ml/classification.py12
-rw-r--r--python/pyspark/ml/feature.py10
-rw-r--r--python/pyspark/ml/param/__init__.py2
-rw-r--r--python/pyspark/ml/param/_shared_params_code_gen.py9
-rw-r--r--python/pyspark/ml/param/shared.py37
-rw-r--r--python/pyspark/ml/recommendation.py10
-rw-r--r--python/pyspark/ml/regression.py13
-rw-r--r--python/pyspark/ml/tests.py67
8 files changed, 96 insertions, 64 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 4e645519c4..7abbde8b26 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -292,7 +292,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
>>> stringIndexer = StringIndexer(inputCol="label", outputCol="indexed")
>>> si_model = stringIndexer.fit(df)
>>> td = si_model.transform(df)
- >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed")
+ >>> rf = RandomForestClassifier(numTrees=2, maxDepth=2, labelCol="indexed", seed=42)
>>> model = rf.fit(td)
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
@@ -319,12 +319,12 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini",
- numTrees=20, featureSubsetStrategy="auto", seed=42):
+ numTrees=20, featureSubsetStrategy="auto", seed=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="gini", \
- numTrees=20, featureSubsetStrategy="auto", seed=42)
+ numTrees=20, featureSubsetStrategy="auto", seed=None)
"""
super(RandomForestClassifier, self).__init__()
self._java_obj = self._new_java_obj(
@@ -347,7 +347,7 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
"The number of features to consider for splits at each tree node. Supported " +
"options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies))
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="gini", numTrees=20, featureSubsetStrategy="auto")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -355,12 +355,12 @@ class RandomForestClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPred
@keyword_only
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="gini", numTrees=20, featureSubsetStrategy="auto"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, \
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
impurity="gini", numTrees=20, featureSubsetStrategy="auto")
Sets params for linear classification.
"""
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index c8115cb5bc..5511dceb70 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -876,10 +876,10 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
@keyword_only
def __init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
- seed=42, inputCol=None, outputCol=None):
+ seed=None, inputCol=None, outputCol=None):
"""
__init__(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, \
- seed=42, inputCol=None, outputCol=None)
+ seed=None, inputCol=None, outputCol=None)
"""
super(Word2Vec, self).__init__()
self._java_obj = self._new_java_obj("org.apache.spark.ml.feature.Word2Vec", self.uid)
@@ -891,15 +891,15 @@ class Word2Vec(JavaEstimator, HasStepSize, HasMaxIter, HasSeed, HasInputCol, Has
"the minimum number of times a token must appear to be included " +
"in the word2vec model's vocabulary")
self._setDefault(vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
- seed=42)
+ seed=None)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
def setParams(self, vectorSize=100, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1,
- seed=42, inputCol=None, outputCol=None):
+ seed=None, inputCol=None, outputCol=None):
"""
- setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=42, \
+ setParams(self, minCount=5, numPartitions=1, stepSize=0.025, maxIter=1, seed=None, \
inputCol=None, outputCol=None)
Sets params for this Word2Vec.
"""
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index 67fb6e3dc7..7845536161 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -147,7 +147,7 @@ class Params(Identifiable):
def getOrDefault(self, param):
"""
Gets the value of a param in the user-supplied param map or its
- default value. Raises an error if either is set.
+ default value. Raises an error if neither is set.
"""
param = self._resolveParam(param)
if param in self._paramMap:
diff --git a/python/pyspark/ml/param/_shared_params_code_gen.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 91e45ec373..ccb929af18 100644
--- a/python/pyspark/ml/param/_shared_params_code_gen.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -56,9 +56,10 @@ def _gen_param_header(name, doc, defaultValueStr):
def __init__(self):
super(Has$Name, self).__init__()
#: param for $doc
- self.$name = Param(self, "$name", "$doc")
- if $defaultValueStr is not None:
- self._setDefault($name=$defaultValueStr)'''
+ self.$name = Param(self, "$name", "$doc")'''
+ if defaultValueStr is not None:
+ template += '''
+ self._setDefault($name=$defaultValueStr)'''
Name = name[0].upper() + name[1:]
return template \
@@ -118,7 +119,7 @@ if __name__ == "__main__":
("outputCol", "output column name", None),
("numFeatures", "number of features", None),
("checkpointInterval", "checkpoint interval (>= 1)", None),
- ("seed", "random seed", None),
+ ("seed", "random seed", "hash(type(self).__name__)"),
("tol", "the convergence tolerance for iterative algorithms", None),
("stepSize", "Step size to be used for each iteration of optimization.", None)]
code = []
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index a5dc9b7ef2..0b93788899 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -32,8 +32,6 @@ class HasMaxIter(Params):
super(HasMaxIter, self).__init__()
#: param for max number of iterations (>= 0)
self.maxIter = Param(self, "maxIter", "max number of iterations (>= 0)")
- if None is not None:
- self._setDefault(maxIter=None)
def setMaxIter(self, value):
"""
@@ -61,8 +59,6 @@ class HasRegParam(Params):
super(HasRegParam, self).__init__()
#: param for regularization parameter (>= 0)
self.regParam = Param(self, "regParam", "regularization parameter (>= 0)")
- if None is not None:
- self._setDefault(regParam=None)
def setRegParam(self, value):
"""
@@ -90,8 +86,7 @@ class HasFeaturesCol(Params):
super(HasFeaturesCol, self).__init__()
#: param for features column name
self.featuresCol = Param(self, "featuresCol", "features column name")
- if 'features' is not None:
- self._setDefault(featuresCol='features')
+ self._setDefault(featuresCol='features')
def setFeaturesCol(self, value):
"""
@@ -119,8 +114,7 @@ class HasLabelCol(Params):
super(HasLabelCol, self).__init__()
#: param for label column name
self.labelCol = Param(self, "labelCol", "label column name")
- if 'label' is not None:
- self._setDefault(labelCol='label')
+ self._setDefault(labelCol='label')
def setLabelCol(self, value):
"""
@@ -148,8 +142,7 @@ class HasPredictionCol(Params):
super(HasPredictionCol, self).__init__()
#: param for prediction column name
self.predictionCol = Param(self, "predictionCol", "prediction column name")
- if 'prediction' is not None:
- self._setDefault(predictionCol='prediction')
+ self._setDefault(predictionCol='prediction')
def setPredictionCol(self, value):
"""
@@ -177,8 +170,7 @@ class HasProbabilityCol(Params):
super(HasProbabilityCol, self).__init__()
#: param for Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.
self.probabilityCol = Param(self, "probabilityCol", "Column name for predicted class conditional probabilities. Note: Not all models output well-calibrated probability estimates! These probabilities should be treated as confidences, not precise probabilities.")
- if 'probability' is not None:
- self._setDefault(probabilityCol='probability')
+ self._setDefault(probabilityCol='probability')
def setProbabilityCol(self, value):
"""
@@ -206,8 +198,7 @@ class HasRawPredictionCol(Params):
super(HasRawPredictionCol, self).__init__()
#: param for raw prediction (a.k.a. confidence) column name
self.rawPredictionCol = Param(self, "rawPredictionCol", "raw prediction (a.k.a. confidence) column name")
- if 'rawPrediction' is not None:
- self._setDefault(rawPredictionCol='rawPrediction')
+ self._setDefault(rawPredictionCol='rawPrediction')
def setRawPredictionCol(self, value):
"""
@@ -235,8 +226,6 @@ class HasInputCol(Params):
super(HasInputCol, self).__init__()
#: param for input column name
self.inputCol = Param(self, "inputCol", "input column name")
- if None is not None:
- self._setDefault(inputCol=None)
def setInputCol(self, value):
"""
@@ -264,8 +253,6 @@ class HasInputCols(Params):
super(HasInputCols, self).__init__()
#: param for input column names
self.inputCols = Param(self, "inputCols", "input column names")
- if None is not None:
- self._setDefault(inputCols=None)
def setInputCols(self, value):
"""
@@ -293,8 +280,6 @@ class HasOutputCol(Params):
super(HasOutputCol, self).__init__()
#: param for output column name
self.outputCol = Param(self, "outputCol", "output column name")
- if None is not None:
- self._setDefault(outputCol=None)
def setOutputCol(self, value):
"""
@@ -322,8 +307,6 @@ class HasNumFeatures(Params):
super(HasNumFeatures, self).__init__()
#: param for number of features
self.numFeatures = Param(self, "numFeatures", "number of features")
- if None is not None:
- self._setDefault(numFeatures=None)
def setNumFeatures(self, value):
"""
@@ -351,8 +334,6 @@ class HasCheckpointInterval(Params):
super(HasCheckpointInterval, self).__init__()
#: param for checkpoint interval (>= 1)
self.checkpointInterval = Param(self, "checkpointInterval", "checkpoint interval (>= 1)")
- if None is not None:
- self._setDefault(checkpointInterval=None)
def setCheckpointInterval(self, value):
"""
@@ -380,8 +361,7 @@ class HasSeed(Params):
super(HasSeed, self).__init__()
#: param for random seed
self.seed = Param(self, "seed", "random seed")
- if None is not None:
- self._setDefault(seed=None)
+ self._setDefault(seed=hash(type(self).__name__))
def setSeed(self, value):
"""
@@ -409,8 +389,6 @@ class HasTol(Params):
super(HasTol, self).__init__()
#: param for the convergence tolerance for iterative algorithms
self.tol = Param(self, "tol", "the convergence tolerance for iterative algorithms")
- if None is not None:
- self._setDefault(tol=None)
def setTol(self, value):
"""
@@ -438,8 +416,6 @@ class HasStepSize(Params):
super(HasStepSize, self).__init__()
#: param for Step size to be used for each iteration of optimization.
self.stepSize = Param(self, "stepSize", "Step size to be used for each iteration of optimization.")
- if None is not None:
- self._setDefault(stepSize=None)
def setStepSize(self, value):
"""
@@ -467,6 +443,7 @@ class DecisionTreeParams(Params):
minInfoGain = Param(Params._dummy(), "minInfoGain", "Minimum information gain for a split to be considered at a tree node.")
maxMemoryInMB = Param(Params._dummy(), "maxMemoryInMB", "Maximum memory in MB allocated to histogram aggregation.")
cacheNodeIds = Param(Params._dummy(), "cacheNodeIds", "If false, the algorithm will pass trees to executors to match instances with nodes. If true, the algorithm will cache node IDs for each instance. Caching can speed up training of deeper trees.")
+
def __init__(self):
super(DecisionTreeParams, self).__init__()
diff --git a/python/pyspark/ml/recommendation.py b/python/pyspark/ml/recommendation.py
index 39c2527543..b3e0dd7abf 100644
--- a/python/pyspark/ml/recommendation.py
+++ b/python/pyspark/ml/recommendation.py
@@ -89,11 +89,11 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
@keyword_only
def __init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
- implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
+ implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
ratingCol="rating", nonnegative=False, checkpointInterval=10):
"""
__init__(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
- implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=0, \
+ implicitPrefs=false, alpha=1.0, userCol="user", itemCol="item", seed=None, \
ratingCol="rating", nonnegative=false, checkpointInterval=10)
"""
super(ALS, self).__init__()
@@ -109,18 +109,18 @@ class ALS(JavaEstimator, HasCheckpointInterval, HasMaxIter, HasPredictionCol, Ha
self.nonnegative = Param(self, "nonnegative",
"whether to use nonnegative constraint for least squares")
self._setDefault(rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
- implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
+ implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
ratingCol="rating", nonnegative=False, checkpointInterval=10)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
def setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10,
- implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0,
+ implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None,
ratingCol="rating", nonnegative=False, checkpointInterval=10):
"""
setParams(self, rank=10, maxIter=10, regParam=0.1, numUserBlocks=10, numItemBlocks=10, \
- implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=0, \
+ implicitPrefs=False, alpha=1.0, userCol="user", itemCol="item", seed=None, \
ratingCol="rating", nonnegative=False, checkpointInterval=10)
Sets params for ALS.
"""
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index ff809cdafd..b139e27372 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -257,7 +257,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
>>> df = sqlContext.createDataFrame([
... (1.0, Vectors.dense(1.0)),
... (0.0, Vectors.sparse(1, [], []))], ["label", "features"])
- >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2)
+ >>> rf = RandomForestRegressor(numTrees=2, maxDepth=2, seed=42)
>>> model = rf.fit(df)
>>> test0 = sqlContext.createDataFrame([(Vectors.dense(-1.0),)], ["features"])
>>> model.transform(test0).head().prediction
@@ -284,12 +284,13 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, impurity="variance",
- numTrees=20, featureSubsetStrategy="auto", seed=42):
+ numTrees=20, featureSubsetStrategy="auto", seed=None):
"""
__init__(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, \
- impurity="variance", numTrees=20, featureSubsetStrategy="auto", seed=42)
+ impurity="variance", numTrees=20, \
+ featureSubsetStrategy="auto", seed=None)
"""
super(RandomForestRegressor, self).__init__()
self._java_obj = self._new_java_obj(
@@ -312,7 +313,7 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
"The number of features to consider for splits at each tree node. Supported " +
"options: " + ", ".join(RandomForestParams.supportedFeatureSubsetStrategies))
self._setDefault(maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="variance", numTrees=20, featureSubsetStrategy="auto")
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -320,12 +321,12 @@ class RandomForestRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
@keyword_only
def setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction",
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0,
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42,
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None,
impurity="variance", numTrees=20, featureSubsetStrategy="auto"):
"""
setParams(self, featuresCol="features", labelCol="label", predictionCol="prediction", \
maxDepth=5, maxBins=32, minInstancesPerNode=1, minInfoGain=0.0, \
- maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=42, \
+ maxMemoryInMB=256, cacheNodeIds=False, checkpointInterval=10, seed=None, \
impurity="variance", numTrees=20, featureSubsetStrategy="auto")
Sets params for linear regression.
"""
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 10fe0ef8db..6adbf166f3 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -33,7 +33,8 @@ else:
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.sql import DataFrame, SQLContext
from pyspark.ml.param import Param, Params
-from pyspark.ml.param.shared import HasMaxIter, HasInputCol
+from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
+from pyspark.ml.util import keyword_only
from pyspark.ml import Estimator, Model, Pipeline, Transformer
from pyspark.ml.feature import *
from pyspark.mllib.linalg import DenseVector
@@ -111,14 +112,46 @@ class PipelineTests(PySparkTestCase):
self.assertEqual(6, dataset.index)
-class TestParams(HasMaxIter, HasInputCol):
+class TestParams(HasMaxIter, HasInputCol, HasSeed):
"""
- A subclass of Params mixed with HasMaxIter and HasInputCol.
+ A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
"""
-
- def __init__(self):
+ @keyword_only
+ def __init__(self, seed=None):
super(TestParams, self).__init__()
self._setDefault(maxIter=10)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, seed=None):
+ """
+ setParams(self, seed=None)
+ Sets params for this test.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+
+class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
+ """
+ A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
+ """
+ @keyword_only
+ def __init__(self, seed=None):
+ super(OtherTestParams, self).__init__()
+ self._setDefault(maxIter=10)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, seed=None):
+ """
+ setParams(self, seed=None)
+ Sets params for this test.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
class ParamTests(PySparkTestCase):
@@ -134,9 +167,10 @@ class ParamTests(PySparkTestCase):
testParams = TestParams()
maxIter = testParams.maxIter
inputCol = testParams.inputCol
+ seed = testParams.seed
params = testParams.params
- self.assertEqual(params, [inputCol, maxIter])
+ self.assertEqual(params, [inputCol, maxIter, seed])
self.assertTrue(testParams.hasParam(maxIter))
self.assertTrue(testParams.hasDefault(maxIter))
@@ -154,10 +188,29 @@ class ParamTests(PySparkTestCase):
with self.assertRaises(KeyError):
testParams.getInputCol()
+ # Since the default is normally random, set it to a known number for debug str
+ testParams._setDefault(seed=41)
+ testParams.setSeed(43)
+
self.assertEquals(
testParams.explainParams(),
"\n".join(["inputCol: input column name (undefined)",
- "maxIter: max number of iterations (>= 0) (default: 10, current: 100)"]))
+ "maxIter: max number of iterations (>= 0) (default: 10, current: 100)",
+ "seed: random seed (default: 41, current: 43)"]))
+
+ def test_hasseed(self):
+ noSeedSpecd = TestParams()
+ withSeedSpecd = TestParams(seed=42)
+ other = OtherTestParams()
+ # Check that we no longer use 42 as the magic number
+ self.assertNotEqual(noSeedSpecd.getSeed(), 42)
+ origSeed = noSeedSpecd.getSeed()
+ # Check that we only compute the seed once
+ self.assertEqual(noSeedSpecd.getSeed(), origSeed)
+ # Check that a specified seed is honored
+ self.assertEqual(withSeedSpecd.getSeed(), 42)
+ # Check that a different class has a different seed
+ self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed())
class FeatureTests(PySparkTestCase):