aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py67
1 files changed, 60 insertions, 7 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 10fe0ef8db..6adbf166f3 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -33,7 +33,8 @@ else:
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.sql import DataFrame, SQLContext
from pyspark.ml.param import Param, Params
-from pyspark.ml.param.shared import HasMaxIter, HasInputCol
+from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
+from pyspark.ml.util import keyword_only
from pyspark.ml import Estimator, Model, Pipeline, Transformer
from pyspark.ml.feature import *
from pyspark.mllib.linalg import DenseVector
@@ -111,14 +112,46 @@ class PipelineTests(PySparkTestCase):
self.assertEqual(6, dataset.index)
-class TestParams(HasMaxIter, HasInputCol):
+class TestParams(HasMaxIter, HasInputCol, HasSeed):
"""
- A subclass of Params mixed with HasMaxIter and HasInputCol.
+ A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
"""
-
- def __init__(self):
+ @keyword_only
+ def __init__(self, seed=None):
super(TestParams, self).__init__()
self._setDefault(maxIter=10)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, seed=None):
+ """
+ setParams(self, seed=None)
+ Sets params for this test.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
+
+
+class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
+ """
+ A subclass of Params mixed with HasMaxIter, HasInputCol and HasSeed.
+ """
+ @keyword_only
+ def __init__(self, seed=None):
+ super(OtherTestParams, self).__init__()
+ self._setDefault(maxIter=10)
+ kwargs = self.__init__._input_kwargs
+ self.setParams(**kwargs)
+
+ @keyword_only
+ def setParams(self, seed=None):
+ """
+ setParams(self, seed=None)
+ Sets params for this test.
+ """
+ kwargs = self.setParams._input_kwargs
+ return self._set(**kwargs)
class ParamTests(PySparkTestCase):
@@ -134,9 +167,10 @@ class ParamTests(PySparkTestCase):
testParams = TestParams()
maxIter = testParams.maxIter
inputCol = testParams.inputCol
+ seed = testParams.seed
params = testParams.params
- self.assertEqual(params, [inputCol, maxIter])
+ self.assertEqual(params, [inputCol, maxIter, seed])
self.assertTrue(testParams.hasParam(maxIter))
self.assertTrue(testParams.hasDefault(maxIter))
@@ -154,10 +188,29 @@ class ParamTests(PySparkTestCase):
with self.assertRaises(KeyError):
testParams.getInputCol()
+ # Since the default is normally random, set it to a known number for debug str
+ testParams._setDefault(seed=41)
+ testParams.setSeed(43)
+
self.assertEquals(
testParams.explainParams(),
"\n".join(["inputCol: input column name (undefined)",
- "maxIter: max number of iterations (>= 0) (default: 10, current: 100)"]))
+ "maxIter: max number of iterations (>= 0) (default: 10, current: 100)",
+ "seed: random seed (default: 41, current: 43)"]))
+
+ def test_hasseed(self):
+ noSeedSpecd = TestParams()
+ withSeedSpecd = TestParams(seed=42)
+ other = OtherTestParams()
+ # Check that we no longer use 42 as the magic number
+ self.assertNotEqual(noSeedSpecd.getSeed(), 42)
+ origSeed = noSeedSpecd.getSeed()
+ # Check that we only compute the seed once
+ self.assertEqual(noSeedSpecd.getSeed(), origSeed)
+ # Check that a specified seed is honored
+ self.assertEqual(withSeedSpecd.getSeed(), 42)
+ # Check that a different class has a different seed
+ self.assertNotEqual(other.getSeed(), noSeedSpecd.getSeed())
class FeatureTests(PySparkTestCase):