about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
author: Holden Karau <holden@pigscanfly.ca>, 2015-05-20 15:16:12 -0700
committer: Joseph K. Bradley <joseph@databricks.com>, 2015-05-20 15:16:12 -0700
commit: 191ee474527530246ac3164ae9631e01bdd1e647 (patch)
tree: 24a8b2d7991f39f478b05f2cb3a7b14c539a9eb9 /python/pyspark/ml/tests.py
parent: 6338c40da61de045485c51aa11a5b1e425d22144 (diff)
download: spark-191ee474527530246ac3164ae9631e01bdd1e647.tar.gz
spark-191ee474527530246ac3164ae9631e01bdd1e647.tar.bz2
spark-191ee474527530246ac3164ae9631e01bdd1e647.zip
[SPARK-7511] [MLLIB] pyspark ml seed param should be random by default or 42 is quite funny but not very random
Author: Holden Karau <holden@pigscanfly.ca>

Closes #6139 from holdenk/SPARK-7511-pyspark-ml-seed-param-should-be-random-by-default-or-42-is-quite-funny-but-not-very-random and squashes the following commits:

591f8e5 [Holden Karau] specify old seed for doc tests
2470004 [Holden Karau] Fix a bunch of seeds with default values to have None as the default which will then result in using the hash of the class name
cbad96d [Holden Karau] Add the setParams function that is used in the real code
423b8d7 [Holden Karau] Switch the test code to behave slightly more like production code. also don't check the param map value only check for key existence
140d25d [Holden Karau] remove extra space
926165a [Holden Karau] Add some missing newlines for pep8 style
8616751 [Holden Karau] merge in master
58532e6 [Holden Karau] it's the __name__ method, also treat None values as not set
56ef24a [Holden Karau] fix test and regenerate base
afdaa5c [Holden Karau] make sure different classes have different results
68eb528 [Holden Karau] switch default seed to hash of type of self
89c4611 [Holden Karau] Merge branch 'master' into SPARK-7511-pyspark-ml-seed-param-should-be-random-by-default-or-42-is-quite-funny-but-not-very-random
31cd96f [Holden Karau] specify the seed to randomforestregressor test
e1b947f [Holden Karau] Style fixes
ce90ec8 [Holden Karau] merge in master
bcdf3c9 [Holden Karau] update docstring seeds to none and some other default seeds from 42
65eba21 [Holden Karau] pep8 fixes
0e3797e [Holden Karau] Make seed default to random in more places
213a543 [Holden Karau] Simplify the generated code to only include set default if there is a default rather than having None is not None in the generated code
1ff17c2 [Holden Karau] Make the seed random for HasSeed in python
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- python/pyspark/ml/tests.py 67
1 file changed, 60 insertions, 7 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 10fe0ef8db..6adbf166f3 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -33,7 +33,8 @@ else:
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.sql import DataFrame, SQLContext
from pyspark.ml.param import Param, Params
-from pyspark.ml.param.shared import HasMaxIter, HasInputCol
+from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
+from pyspark.ml.util import keyword_only
from pyspark.ml import Estimator, Model, Pipeline, Transformer
from pyspark.ml.feature import *
from pyspark.mllib.linalg import DenseVector
@@ -111,14 +112,46 @@ class PipelineTests(PySparkTestCase):
self.assertEqual(6, dataset.index)
-class TestParams(HasMaxIter, HasInputCol):
+class TestParams(HasMaxIter, HasInputCol, HasSeed):
"""
- A subclass of Params mixed with HasMaxIter and HasInputCol.
+ A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
"""
-
- def __init__(self):
+ @keyword_only
+ def __init__(self, seed=None):
super(TestParams, self).__init__()
self._setDefault(maxIter=10)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, seed=None):
+ """
+ setParams(self, seed=None)
+ Sets params for this test.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+
+class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
+ """
+ A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
+ """
+ @keyword_only
+ def __init__(self, seed=None):
+ super(OtherTestParams, self).__init__()
+ self._setDefault(maxIter=10)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, seed=None):
+ """
+ setParams(self, seed=None)
+ Sets params for this test.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
class ParamTests(PySparkTestCase):
@@ -134,9 +167,10 @@ class ParamTests(PySparkTestCase):
testParams = TestParams()
maxIter = testParams.maxIter
inputCol = testParams.inputCol
+ seed = testParams.seed
params = testParams.params
- self.assertEqual(params, [inputCol, maxIter])
+ self.assertEqual(params, [inputCol, maxIter, seed])
self.assertTrue(testParams.hasParam(maxIter))
self.assertTrue(testParams.hasDefault(maxIter))
@@ -154,10 +188,29 @@ class ParamTests(PySparkTestCase):
with self.assertRaises(KeyError):
testParams.getInputCol()
+ # Since the default is normally random, set it to a known number for debug str
+ testParams._setDefault(seed=41)
+ testParams.setSeed(43)
+
self.assertEquals(
testParams.explainParams(),
"\n".join(["inputCol: input column name (undefined)",
- "maxIter: max number of iterations (>= 0) (default: 10, current: 100)"]))
+ "maxIter: max number of iterations (>= 0) (default: 10, current: 100)",
+ "seed: random seed (default: 41, current: 43)"]))
+
+ def test_hasseed(self):
+ noSeedSpecd = TestParams()
+ withSeedSpecd = TestParams(seed=42)
+ other = OtherTestParams()
+ # Check that we no longer use 42 as the magic number
+ self.assertNotEqual(noSeedSpecd.getSeed(), 42)
+ origSeed = noSeedSpecd.getSeed()
+ # Check that we only compute the seed once
+ self.assertEqual(noSeedSpecd.getSeed(), origSeed)
+ # Check that a specified seed is honored
+ self.assertEqual(withSeedSpecd.getSeed(), 42)
+ # Check that a different class has a different seed
+ self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed())
class FeatureTests(PySparkTestCase):