aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
authorHolden Karau <holden@us.ibm.com>2016-01-26 15:53:48 -0800
committerJoseph K. Bradley <joseph@databricks.com>2016-01-26 15:53:48 -0800
commiteb917291ca1a2d68ca0639cb4b1464a546603eba (patch)
tree380dcaa33273baa68beaf089387bd498d5ee88e8 /python/pyspark/ml/tests.py
parent19fdb21afbf0eae4483cf6d4ef32daffd1994b89 (diff)
downloadspark-eb917291ca1a2d68ca0639cb4b1464a546603eba.tar.gz
spark-eb917291ca1a2d68ca0639cb4b1464a546603eba.tar.bz2
spark-eb917291ca1a2d68ca0639cb4b1464a546603eba.zip
[SPARK-10509][PYSPARK] Reduce excessive param boiler plate code
The current python ml params require cut-and-pasting the param setup and description between the class & ```__init__``` methods. Remove this possible case of errors & simplify use of custom params by adding a ```_copy_new_parent``` method to param so as to avoid cut and pasting (and cut and pasting at different indentation levels urgh). Author: Holden Karau <holden@us.ibm.com> Closes #10216 from holdenk/SPARK-10509-excessive-param-boiler-plate-code.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--python/pyspark/ml/tests.py12
1 files changed, 12 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 9ea639dc4f..c45a159c46 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -185,6 +185,18 @@ class OtherTestParams(HasMaxIter, HasInputCol, HasSeed):
class ParamTests(PySparkTestCase):
+ def test_copy_new_parent(self):
+ testParams = TestParams()
+ # Copying an instantiated param should fail
+ with self.assertRaises(ValueError):
+ testParams.maxIter._copy_new_parent(testParams)
+ # Copying a dummy param should succeed
+ TestParams.maxIter._copy_new_parent(testParams)
+ maxIter = testParams.maxIter
+ self.assertEqual(maxIter.name, "maxIter")
+ self.assertEqual(maxIter.doc, "max number of iterations (>= 0).")
+ self.assertTrue(maxIter.parent == testParams.uid)
+
def test_param(self):
testParams = TestParams()
maxIter = testParams.maxIter