author     Xiangrui Meng <meng@databricks.com>    2015-04-15 23:49:42 -0700
committer  Xiangrui Meng <meng@databricks.com>    2015-04-15 23:49:42 -0700
commit     57cd1e86d1d450f85fc9e296aff498a940452113 (patch)
tree       10e973e431fc3ca3e92c823eed077dae5772f5f5 /python
parent     52c3439a8a107ce1fc10e4f0b59fd7881e851622 (diff)
[SPARK-6893][ML] default pipeline parameter handling in python
Same as #5431 but for Python. jkbradley

Author: Xiangrui Meng <meng@databricks.com>

Closes #5534 from mengxr/SPARK-6893 and squashes the following commits:

d3b519b [Xiangrui Meng] address comments
ebaccc6 [Xiangrui Meng] style update
fce244e [Xiangrui Meng] update explainParams with test
4d6b07a [Xiangrui Meng] add tests
5294500 [Xiangrui Meng] update default param handling in python
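In short, each Params object now keeps user-supplied values (paramMap) and default values (defaultParamMap) in separate maps, with new helpers isSet, hasDefault, isDefined, getOrDefault, and explainParams. A rough sketch of the resulting behavior, following the TestParams class added in python/pyspark/ml/tests.py below:

    from pyspark.ml.param.shared import HasMaxIter, HasInputCol

    class TestParams(HasMaxIter, HasInputCol):
        def __init__(self):
            super(TestParams, self).__init__()
            self._setDefault(maxIter=10)   # recorded in defaultParamMap, not paramMap

    tp = TestParams()
    tp.hasDefault(tp.maxIter)              # True  -- a default exists
    tp.isSet(tp.maxIter)                   # False -- nothing user-supplied yet
    tp.getMaxIter()                        # 10    -- getOrDefault falls back to the default
    tp.setMaxIter(100)                     # stored in paramMap
    tp.getMaxIter()                        # 100   -- user-supplied value shadows the default
    print(tp.explainParams())
    # inputCol: input column name (undefined)
    # maxIter: max number of iterations (default: 10, current: 100)
    # tp.getInputCol() would raise KeyError: inputCol has neither a value nor a default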
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/ml/classification.py                |   3
-rw-r--r--  python/pyspark/ml/feature.py                       |  19
-rw-r--r--  python/pyspark/ml/param/__init__.py                | 146
-rw-r--r--  python/pyspark/ml/param/_shared_params_code_gen.py (renamed from python/pyspark/ml/param/_gen_shared_params.py) |  42
-rw-r--r--  python/pyspark/ml/param/shared.py                  | 106
-rw-r--r--  python/pyspark/ml/pipeline.py                      |   6
-rw-r--r--  python/pyspark/ml/tests.py                         |  52
-rw-r--r--  python/pyspark/ml/util.py                          |   4
-rw-r--r--  python/pyspark/ml/wrapper.py                       |   2
9 files changed, 266 insertions(+), 114 deletions(-)
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 7f42de531f..d7bc09fd77 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -59,6 +59,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
maxIter=100, regParam=0.1)
"""
super(LogisticRegression, self).__init__()
+ self._setDefault(maxIter=100, regParam=0.1)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@@ -71,7 +72,7 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
Sets params for logistic regression.
"""
kwargs = self.setParams._input_kwargs
- return self._set_params(**kwargs)
+ return self._set(**kwargs)
def _create_model(self, java_model):
return LogisticRegressionModel(java_model)
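LogisticRegression now registers maxIter=100 and regParam=0.1 through _setDefault instead of baking them into the Param objects, so explainParams can tell defaults apart from user-supplied values. A small sketch of the intended effect (constructor arguments are illustrative):

    lr = LogisticRegression(regParam=0.01)
    lr.isSet(lr.regParam)       # True  -- supplied by the user
    lr.isSet(lr.maxIter)        # False -- only the default applies
    lr.getMaxIter()             # 100, served from defaultParamMap
    print(lr.explainParams())
    # ...
    # maxIter: max number of iterations (default: 100)
    # regParam: regularization constant (default: 0.1, current: 0.01)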
diff --git a/python/pyspark/ml/feature.py b/python/pyspark/ml/feature.py
index 1cfcd019df..263fe2a5bc 100644
--- a/python/pyspark/ml/feature.py
+++ b/python/pyspark/ml/feature.py
@@ -52,22 +52,22 @@ class Tokenizer(JavaTransformer, HasInputCol, HasOutputCol):
_java_class = "org.apache.spark.ml.feature.Tokenizer"
@keyword_only
- def __init__(self, inputCol="input", outputCol="output"):
+ def __init__(self, inputCol=None, outputCol=None):
"""
- __init__(self, inputCol="input", outputCol="output")
+ __init__(self, inputCol=None, outputCol=None)
"""
super(Tokenizer, self).__init__()
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
- def setParams(self, inputCol="input", outputCol="output"):
+ def setParams(self, inputCol=None, outputCol=None):
"""
setParams(self, inputCol="input", outputCol="output")
Sets params for this Tokenizer.
"""
kwargs = self.setParams._input_kwargs
- return self._set_params(**kwargs)
+ return self._set(**kwargs)
@inherit_doc
@@ -91,22 +91,23 @@ class HashingTF(JavaTransformer, HasInputCol, HasOutputCol, HasNumFeatures):
_java_class = "org.apache.spark.ml.feature.HashingTF"
@keyword_only
- def __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
+ def __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
"""
- __init__(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
+ __init__(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
"""
super(HashingTF, self).__init__()
+ self._setDefault(numFeatures=1 << 18)
kwargs = self.__init__._input_kwargs
self.setParams(**kwargs)
@keyword_only
- def setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output"):
+ def setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None):
"""
- setParams(self, numFeatures=1 << 18, inputCol="input", outputCol="output")
+ setParams(self, numFeatures=1 << 18, inputCol=None, outputCol=None)
Sets params for this HashingTF.
"""
kwargs = self.setParams._input_kwargs
- return self._set_params(**kwargs)
+ return self._set(**kwargs)
if __name__ == "__main__":
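Tokenizer and HashingTF drop the implicit "input"/"output" column defaults: an unset column is now reported as undefined instead of silently pointing at a made-up column, while HashingTF's numFeatures default moves into _setDefault. A sketch with illustrative column names:

    tf = HashingTF(inputCol="words", outputCol="features")
    tf.getNumFeatures()         # 262144 (1 << 18), still just the default
    tf.isSet(tf.numFeatures)    # False

    tok = Tokenizer()
    # tok.getInputCol() now raises KeyError instead of returning "input"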
diff --git a/python/pyspark/ml/param/__init__.py b/python/pyspark/ml/param/__init__.py
index e3a53dd780..5c62620562 100644
--- a/python/pyspark/ml/param/__init__.py
+++ b/python/pyspark/ml/param/__init__.py
@@ -25,23 +25,21 @@ __all__ = ['Param', 'Params']
class Param(object):
"""
- A param with self-contained documentation and optionally default value.
+ A param with self-contained documentation.
"""
- def __init__(self, parent, name, doc, defaultValue=None):
- if not isinstance(parent, Identifiable):
- raise ValueError("Parent must be identifiable but got type %s." % type(parent).__name__)
+ def __init__(self, parent, name, doc):
+ if not isinstance(parent, Params):
+ raise ValueError("Parent must be a Params but got type %s." % type(parent).__name__)
self.parent = parent
self.name = str(name)
self.doc = str(doc)
- self.defaultValue = defaultValue
def __str__(self):
- return str(self.parent) + "-" + self.name
+ return str(self.parent) + "__" + self.name
def __repr__(self):
- return "Param(parent=%r, name=%r, doc=%r, defaultValue=%r)" % \
- (self.parent, self.name, self.doc, self.defaultValue)
+ return "Param(parent=%r, name=%r, doc=%r)" % (self.parent, self.name, self.doc)
class Params(Identifiable):
@@ -52,26 +50,128 @@ class Params(Identifiable):
__metaclass__ = ABCMeta
- def __init__(self):
- super(Params, self).__init__()
- #: embedded param map
- self.paramMap = {}
+ #: internal param map for user-supplied values
+ paramMap = {}
+
+ #: internal param map for default values
+ defaultParamMap = {}
@property
def params(self):
"""
- Returns all params. The default implementation uses
- :py:func:`dir` to get all attributes of type
+ Returns all params ordered by name. The default implementation
+ uses :py:func:`dir` to get all attributes of type
:py:class:`Param`.
"""
return filter(lambda attr: isinstance(attr, Param),
[getattr(self, x) for x in dir(self) if x != "params"])
- def _merge_params(self, params):
- paramMap = self.paramMap.copy()
- paramMap.update(params)
+ def _explain(self, param):
+ """
+ Explains a single param and returns its name, doc, and optional
+ default value and user-supplied value in a string.
+ """
+ param = self._resolveParam(param)
+ values = []
+ if self.isDefined(param):
+ if param in self.defaultParamMap:
+ values.append("default: %s" % self.defaultParamMap[param])
+ if param in self.paramMap:
+ values.append("current: %s" % self.paramMap[param])
+ else:
+ values.append("undefined")
+ valueStr = "(" + ", ".join(values) + ")"
+ return "%s: %s %s" % (param.name, param.doc, valueStr)
+
+ def explainParams(self):
+ """
+ Returns the documentation of all params with their optional
+ default values and user-supplied values.
+ """
+ return "\n".join([self._explain(param) for param in self.params])
+
+ def getParam(self, paramName):
+ """
+ Gets a param by its name.
+ """
+ param = getattr(self, paramName)
+ if isinstance(param, Param):
+ return param
+ else:
+ raise ValueError("Cannot find param with name %s." % paramName)
+
+ def isSet(self, param):
+ """
+ Checks whether a param is explicitly set by user.
+ """
+ param = self._resolveParam(param)
+ return param in self.paramMap
+
+ def hasDefault(self, param):
+ """
+ Checks whether a param has a default value.
+ """
+ param = self._resolveParam(param)
+ return param in self.defaultParamMap
+
+ def isDefined(self, param):
+ """
+ Checks whether a param is explicitly set by user or has a default value.
+ """
+ return self.isSet(param) or self.hasDefault(param)
+
+ def getOrDefault(self, param):
+ """
+ Gets the value of a param in the user-supplied param map or its
+ default value. Raises an error if neither is set.
+ """
+ if isinstance(param, Param):
+ if param in self.paramMap:
+ return self.paramMap[param]
+ else:
+ return self.defaultParamMap[param]
+ elif isinstance(param, str):
+ return self.getOrDefault(self.getParam(param))
+ else:
+ raise KeyError("Cannot recognize %r as a param." % param)
+
+ def extractParamMap(self, extraParamMap={}):
+ """
+ Extracts the embedded default param values and user-supplied
+ values, and then merges them with extra values from input into
+ a flat param map, where the latter value is used if there exist
+ conflicts, i.e., with ordering: default param values <
+ user-supplied values < extraParamMap.
+ :param extraParamMap: extra param values
+ :return: merged param map
+ """
+ paramMap = self.defaultParamMap.copy()
+ paramMap.update(self.paramMap)
+ paramMap.update(extraParamMap)
return paramMap
+ def _shouldOwn(self, param):
+ """
+ Validates that the input param belongs to this Params instance.
+ """
+ if param.parent is not self:
+ raise ValueError("Param %r does not belong to %r." % (param, self))
+
+ def _resolveParam(self, param):
+ """
+ Resolves a param and validates the ownership.
+ :param param: param name or the param instance, which must
+ belong to this Params instance
+ :return: resolved param instance
+ """
+ if isinstance(param, Param):
+ self._shouldOwn(param)
+ return param
+ elif isinstance(param, str):
+ return self.getParam(param)
+ else:
+ raise ValueError("Cannot resolve %r as a param." % param)
+
@staticmethod
def _dummy():
"""
@@ -81,10 +181,18 @@ class Params(Identifiable):
dummy.uid = "undefined"
return dummy
- def _set_params(self, **kwargs):
+ def _set(self, **kwargs):
"""
- Sets params.
+ Sets user-supplied params.
"""
for param, value in kwargs.iteritems():
self.paramMap[getattr(self, param)] = value
return self
+
+ def _setDefault(self, **kwargs):
+ """
+ Sets default params.
+ """
+ for param, value in kwargs.iteritems():
+ self.defaultParamMap[getattr(self, param)] = value
+ return self
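extractParamMap merges the three layers with the ordering stated in its docstring: defaults < user-supplied values < the extra map passed in. A precedence sketch, reusing the TestParams class from the tests below:

    tp = TestParams()                                  # defaultParamMap: {maxIter: 10}
    tp.setMaxIter(50)                                  # paramMap:        {maxIter: 50}
    tp.extractParamMap()[tp.maxIter]                   # 50 -- user-supplied beats the default
    tp.extractParamMap({tp.maxIter: 5})[tp.maxIter]    # 5  -- the extra map beats both
    tp.getOrDefault("maxIter")                         # 50 -- string names resolve via getParam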
diff --git a/python/pyspark/ml/param/_gen_shared_params.py b/python/pyspark/ml/param/_shared_params_code_gen.py
index 5eb81106f1..55f4224976 100644
--- a/python/pyspark/ml/param/_gen_shared_params.py
+++ b/python/pyspark/ml/param/_shared_params_code_gen.py
@@ -32,29 +32,34 @@ header = """#
# limitations under the License.
#"""
+# Code generator for shared params (shared.py). Run under this folder with:
+# python _shared_params_code_gen.py > shared.py
-def _gen_param_code(name, doc, defaultValue):
+
+def _gen_param_code(name, doc, defaultValueStr):
"""
Generates Python code for a shared param class.
:param name: param name
:param doc: param doc
- :param defaultValue: string representation of the param
+ :param defaultValueStr: string representation of the default value
:return: code string
"""
# TODO: How to correctly inherit instance attributes?
template = '''class Has$Name(Params):
"""
- Params with $name.
+ Mixin for param $name: $doc.
"""
# a placeholder to make it appear in the generated doc
- $name = Param(Params._dummy(), "$name", "$doc", $defaultValue)
+ $name = Param(Params._dummy(), "$name", "$doc")
def __init__(self):
super(Has$Name, self).__init__()
#: param for $doc
- self.$name = Param(self, "$name", "$doc", $defaultValue)
+ self.$name = Param(self, "$name", "$doc")
+ if $defaultValueStr is not None:
+ self._setDefault($name=$defaultValueStr)
def set$Name(self, value):
"""
@@ -67,32 +72,29 @@ def _gen_param_code(name, doc, defaultValue):
"""
Gets the value of $name or its default value.
"""
- if self.$name in self.paramMap:
- return self.paramMap[self.$name]
- else:
- return self.$name.defaultValue'''
+ return self.getOrDefault(self.$name)'''
- upperCamelName = name[0].upper() + name[1:]
+ Name = name[0].upper() + name[1:]
return template \
.replace("$name", name) \
- .replace("$Name", upperCamelName) \
+ .replace("$Name", Name) \
.replace("$doc", doc) \
- .replace("$defaultValue", defaultValue)
+ .replace("$defaultValueStr", str(defaultValueStr))
if __name__ == "__main__":
print header
- print "\n# DO NOT MODIFY. The code is generated by _gen_shared_params.py.\n"
+ print "\n# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.\n"
print "from pyspark.ml.param import Param, Params\n\n"
shared = [
- ("maxIter", "max number of iterations", "100"),
- ("regParam", "regularization constant", "0.1"),
+ ("maxIter", "max number of iterations", None),
+ ("regParam", "regularization constant", None),
("featuresCol", "features column name", "'features'"),
("labelCol", "label column name", "'label'"),
("predictionCol", "prediction column name", "'prediction'"),
- ("inputCol", "input column name", "'input'"),
- ("outputCol", "output column name", "'output'"),
- ("numFeatures", "number of features", "1 << 18")]
+ ("inputCol", "input column name", None),
+ ("outputCol", "output column name", None),
+ ("numFeatures", "number of features", None)]
code = []
- for name, doc, defaultValue in shared:
- code.append(_gen_param_code(name, doc, defaultValue))
+ for name, doc, defaultValueStr in shared:
+ code.append(_gen_param_code(name, doc, defaultValueStr))
print "\n\n\n".join(code)
diff --git a/python/pyspark/ml/param/shared.py b/python/pyspark/ml/param/shared.py
index 586822f2de..13b6749998 100644
--- a/python/pyspark/ml/param/shared.py
+++ b/python/pyspark/ml/param/shared.py
@@ -15,23 +15,25 @@
# limitations under the License.
#
-# DO NOT MODIFY. The code is generated by _gen_shared_params.py.
+# DO NOT MODIFY THIS FILE! It was generated by _shared_params_code_gen.py.
from pyspark.ml.param import Param, Params
class HasMaxIter(Params):
"""
- Params with maxIter.
+ Mixin for param maxIter: max number of iterations.
"""
# a placeholder to make it appear in the generated doc
- maxIter = Param(Params._dummy(), "maxIter", "max number of iterations", 100)
+ maxIter = Param(Params._dummy(), "maxIter", "max number of iterations")
def __init__(self):
super(HasMaxIter, self).__init__()
#: param for max number of iterations
- self.maxIter = Param(self, "maxIter", "max number of iterations", 100)
+ self.maxIter = Param(self, "maxIter", "max number of iterations")
+ if None is not None:
+ self._setDefault(maxIter=None)
def setMaxIter(self, value):
"""
@@ -44,24 +46,23 @@ class HasMaxIter(Params):
"""
Gets the value of maxIter or its default value.
"""
- if self.maxIter in self.paramMap:
- return self.paramMap[self.maxIter]
- else:
- return self.maxIter.defaultValue
+ return self.getOrDefault(self.maxIter)
class HasRegParam(Params):
"""
- Params with regParam.
+ Mixin for param regParam: regularization constant.
"""
# a placeholder to make it appear in the generated doc
- regParam = Param(Params._dummy(), "regParam", "regularization constant", 0.1)
+ regParam = Param(Params._dummy(), "regParam", "regularization constant")
def __init__(self):
super(HasRegParam, self).__init__()
#: param for regularization constant
- self.regParam = Param(self, "regParam", "regularization constant", 0.1)
+ self.regParam = Param(self, "regParam", "regularization constant")
+ if None is not None:
+ self._setDefault(regParam=None)
def setRegParam(self, value):
"""
@@ -74,24 +75,23 @@ class HasRegParam(Params):
"""
Gets the value of regParam or its default value.
"""
- if self.regParam in self.paramMap:
- return self.paramMap[self.regParam]
- else:
- return self.regParam.defaultValue
+ return self.getOrDefault(self.regParam)
class HasFeaturesCol(Params):
"""
- Params with featuresCol.
+ Mixin for param featuresCol: features column name.
"""
# a placeholder to make it appear in the generated doc
- featuresCol = Param(Params._dummy(), "featuresCol", "features column name", 'features')
+ featuresCol = Param(Params._dummy(), "featuresCol", "features column name")
def __init__(self):
super(HasFeaturesCol, self).__init__()
#: param for features column name
- self.featuresCol = Param(self, "featuresCol", "features column name", 'features')
+ self.featuresCol = Param(self, "featuresCol", "features column name")
+ if 'features' is not None:
+ self._setDefault(featuresCol='features')
def setFeaturesCol(self, value):
"""
@@ -104,24 +104,23 @@ class HasFeaturesCol(Params):
"""
Gets the value of featuresCol or its default value.
"""
- if self.featuresCol in self.paramMap:
- return self.paramMap[self.featuresCol]
- else:
- return self.featuresCol.defaultValue
+ return self.getOrDefault(self.featuresCol)
class HasLabelCol(Params):
"""
- Params with labelCol.
+ Mixin for param labelCol: label column name.
"""
# a placeholder to make it appear in the generated doc
- labelCol = Param(Params._dummy(), "labelCol", "label column name", 'label')
+ labelCol = Param(Params._dummy(), "labelCol", "label column name")
def __init__(self):
super(HasLabelCol, self).__init__()
#: param for label column name
- self.labelCol = Param(self, "labelCol", "label column name", 'label')
+ self.labelCol = Param(self, "labelCol", "label column name")
+ if 'label' is not None:
+ self._setDefault(labelCol='label')
def setLabelCol(self, value):
"""
@@ -134,24 +133,23 @@ class HasLabelCol(Params):
"""
Gets the value of labelCol or its default value.
"""
- if self.labelCol in self.paramMap:
- return self.paramMap[self.labelCol]
- else:
- return self.labelCol.defaultValue
+ return self.getOrDefault(self.labelCol)
class HasPredictionCol(Params):
"""
- Params with predictionCol.
+ Mixin for param predictionCol: prediction column name.
"""
# a placeholder to make it appear in the generated doc
- predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name", 'prediction')
+ predictionCol = Param(Params._dummy(), "predictionCol", "prediction column name")
def __init__(self):
super(HasPredictionCol, self).__init__()
#: param for prediction column name
- self.predictionCol = Param(self, "predictionCol", "prediction column name", 'prediction')
+ self.predictionCol = Param(self, "predictionCol", "prediction column name")
+ if 'prediction' is not None:
+ self._setDefault(predictionCol='prediction')
def setPredictionCol(self, value):
"""
@@ -164,24 +162,23 @@ class HasPredictionCol(Params):
"""
Gets the value of predictionCol or its default value.
"""
- if self.predictionCol in self.paramMap:
- return self.paramMap[self.predictionCol]
- else:
- return self.predictionCol.defaultValue
+ return self.getOrDefault(self.predictionCol)
class HasInputCol(Params):
"""
- Params with inputCol.
+ Mixin for param inputCol: input column name.
"""
# a placeholder to make it appear in the generated doc
- inputCol = Param(Params._dummy(), "inputCol", "input column name", 'input')
+ inputCol = Param(Params._dummy(), "inputCol", "input column name")
def __init__(self):
super(HasInputCol, self).__init__()
#: param for input column name
- self.inputCol = Param(self, "inputCol", "input column name", 'input')
+ self.inputCol = Param(self, "inputCol", "input column name")
+ if None is not None:
+ self._setDefault(inputCol=None)
def setInputCol(self, value):
"""
@@ -194,24 +191,23 @@ class HasInputCol(Params):
"""
Gets the value of inputCol or its default value.
"""
- if self.inputCol in self.paramMap:
- return self.paramMap[self.inputCol]
- else:
- return self.inputCol.defaultValue
+ return self.getOrDefault(self.inputCol)
class HasOutputCol(Params):
"""
- Params with outputCol.
+ Mixin for param outputCol: output column name.
"""
# a placeholder to make it appear in the generated doc
- outputCol = Param(Params._dummy(), "outputCol", "output column name", 'output')
+ outputCol = Param(Params._dummy(), "outputCol", "output column name")
def __init__(self):
super(HasOutputCol, self).__init__()
#: param for output column name
- self.outputCol = Param(self, "outputCol", "output column name", 'output')
+ self.outputCol = Param(self, "outputCol", "output column name")
+ if None is not None:
+ self._setDefault(outputCol=None)
def setOutputCol(self, value):
"""
@@ -224,24 +220,23 @@ class HasOutputCol(Params):
"""
Gets the value of outputCol or its default value.
"""
- if self.outputCol in self.paramMap:
- return self.paramMap[self.outputCol]
- else:
- return self.outputCol.defaultValue
+ return self.getOrDefault(self.outputCol)
class HasNumFeatures(Params):
"""
- Params with numFeatures.
+ Mixin for param numFeatures: number of features.
"""
# a placeholder to make it appear in the generated doc
- numFeatures = Param(Params._dummy(), "numFeatures", "number of features", 1 << 18)
+ numFeatures = Param(Params._dummy(), "numFeatures", "number of features")
def __init__(self):
super(HasNumFeatures, self).__init__()
#: param for number of features
- self.numFeatures = Param(self, "numFeatures", "number of features", 1 << 18)
+ self.numFeatures = Param(self, "numFeatures", "number of features")
+ if None is not None:
+ self._setDefault(numFeatures=None)
def setNumFeatures(self, value):
"""
@@ -254,7 +249,4 @@ class HasNumFeatures(Params):
"""
Gets the value of numFeatures or its default value.
"""
- if self.numFeatures in self.paramMap:
- return self.paramMap[self.numFeatures]
- else:
- return self.numFeatures.defaultValue
+ return self.getOrDefault(self.numFeatures)
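Each generated getter now delegates to getOrDefault instead of reading Param.defaultValue, so column defaults such as 'features' live in defaultParamMap and can still be overridden per instance. A sketch using LogisticRegression, which mixes in these shared params:

    est = LogisticRegression()
    est.getFeaturesCol()        # 'features' -- default registered by HasFeaturesCol.__init__
    est.setFeaturesCol("f")     # user-supplied value, stored in paramMap
    est.getFeaturesCol()        # 'f' -- shadows the default without discarding it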
diff --git a/python/pyspark/ml/pipeline.py b/python/pyspark/ml/pipeline.py
index 83880a5afc..d94ecfff09 100644
--- a/python/pyspark/ml/pipeline.py
+++ b/python/pyspark/ml/pipeline.py
@@ -124,10 +124,10 @@ class Pipeline(Estimator):
Sets params for Pipeline.
"""
kwargs = self.setParams._input_kwargs
- return self._set_params(**kwargs)
+ return self._set(**kwargs)
def fit(self, dataset, params={}):
- paramMap = self._merge_params(params)
+ paramMap = self.extractParamMap(params)
stages = paramMap[self.stages]
for stage in stages:
if not (isinstance(stage, Estimator) or isinstance(stage, Transformer)):
@@ -164,7 +164,7 @@ class PipelineModel(Transformer):
self.transformers = transformers
def transform(self, dataset, params={}):
- paramMap = self._merge_params(params)
+ paramMap = self.extractParamMap(params)
for t in self.transformers:
dataset = t.transform(dataset, paramMap)
return dataset
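Because fit and transform now go through extractParamMap, a param map passed for a single call layers on top of instance values and defaults without mutating the stages. A sketch with illustrative stage and dataset names:

    pipeline = Pipeline(stages=[tokenizer, hashingTF, lr])
    model = pipeline.fit(training)                       # defaults + instance params
    model2 = pipeline.fit(training, {lr.maxIter: 20})    # 20 wins for this fit only; lr itself is unchanged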
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index b627c2b4e9..3a42bcf723 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -33,6 +33,7 @@ else:
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
from pyspark.sql import DataFrame
from pyspark.ml.param import Param
+from pyspark.ml.param.shared import HasMaxIter, HasInputCol
from pyspark.ml.pipeline import Transformer, Estimator, Pipeline
@@ -46,7 +47,7 @@ class MockTransformer(Transformer):
def __init__(self):
super(MockTransformer, self).__init__()
- self.fake = Param(self, "fake", "fake", None)
+ self.fake = Param(self, "fake", "fake")
self.dataset_index = None
self.fake_param_value = None
@@ -62,7 +63,7 @@ class MockEstimator(Estimator):
def __init__(self):
super(MockEstimator, self).__init__()
- self.fake = Param(self, "fake", "fake", None)
+ self.fake = Param(self, "fake", "fake")
self.dataset_index = None
self.fake_param_value = None
self.model = None
@@ -111,5 +112,52 @@ class PipelineTests(PySparkTestCase):
self.assertEqual(6, dataset.index)
+class TestParams(HasMaxIter, HasInputCol):
+ """
+ A subclass of Params mixed with HasMaxIter and HasInputCol.
+ """
+
+ def __init__(self):
+ super(TestParams, self).__init__()
+ self._setDefault(maxIter=10)
+
+
+class ParamTests(PySparkTestCase):
+
+ def test_param(self):
+ testParams = TestParams()
+ maxIter = testParams.maxIter
+ self.assertEqual(maxIter.name, "maxIter")
+ self.assertEqual(maxIter.doc, "max number of iterations")
+ self.assertTrue(maxIter.parent is testParams)
+
+ def test_params(self):
+ testParams = TestParams()
+ maxIter = testParams.maxIter
+ inputCol = testParams.inputCol
+
+ params = testParams.params
+ self.assertEqual(params, [inputCol, maxIter])
+
+ self.assertTrue(testParams.hasDefault(maxIter))
+ self.assertFalse(testParams.isSet(maxIter))
+ self.assertTrue(testParams.isDefined(maxIter))
+ self.assertEqual(testParams.getMaxIter(), 10)
+ testParams.setMaxIter(100)
+ self.assertTrue(testParams.isSet(maxIter))
+ self.assertEquals(testParams.getMaxIter(), 100)
+
+ self.assertFalse(testParams.hasDefault(inputCol))
+ self.assertFalse(testParams.isSet(inputCol))
+ self.assertFalse(testParams.isDefined(inputCol))
+ with self.assertRaises(KeyError):
+ testParams.getInputCol()
+
+ self.assertEquals(
+ testParams.explainParams(),
+ "\n".join(["inputCol: input column name (undefined)",
+ "maxIter: max number of iterations (default: 10, current: 100)"]))
+
+
if __name__ == "__main__":
unittest.main()
diff --git a/python/pyspark/ml/util.py b/python/pyspark/ml/util.py
index 6f7f39c40e..d3cb100a9e 100644
--- a/python/pyspark/ml/util.py
+++ b/python/pyspark/ml/util.py
@@ -40,8 +40,8 @@ class Identifiable(object):
def __init__(self):
#: A unique id for the object. The default implementation
- #: concatenates the class name, "-", and 8 random hex chars.
- self.uid = type(self).__name__ + "-" + uuid.uuid4().hex[:8]
+ #: concatenates the class name, "_", and 8 random hex chars.
+ self.uid = type(self).__name__ + "_" + uuid.uuid4().hex[:8]
def __repr__(self):
return self.uid
diff --git a/python/pyspark/ml/wrapper.py b/python/pyspark/ml/wrapper.py
index 31a66b3d2f..394f23c5e9 100644
--- a/python/pyspark/ml/wrapper.py
+++ b/python/pyspark/ml/wrapper.py
@@ -64,7 +64,7 @@ class JavaWrapper(Params):
:param params: additional params (overwriting embedded values)
:param java_obj: Java object to receive the params
"""
- paramMap = self._merge_params(params)
+ paramMap = self.extractParamMap(params)
for param in self.params:
if param in paramMap:
java_obj.set(param.name, paramMap[param])