aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/regression.py
diff options
context:
space:
mode:
authorsethah <seth.hendrickson16@gmail.com>2016-03-23 11:20:44 -0700
committerJoseph K. Bradley <joseph@databricks.com>2016-03-23 11:20:44 -0700
commit30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2 (patch)
tree4d48b42ebe347fc40d5deeb3a77996db0c30eea1 /python/pyspark/ml/regression.py
parent48ee16d8012602c75d50aa2a85e26b7de3c48944 (diff)
downloadspark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.gz
spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.bz2
spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.zip
[SPARK-13068][PYSPARK][ML] Type conversion for Pyspark params
## What changes were proposed in this pull request? This patch adds type conversion functionality for parameters in Pyspark. A `typeConverter` field is added to the constructor of `Param` class. This argument is a function which converts values passed to this param to the appropriate type if possible. This is beneficial so that the params can fail at set time if they are given inappropriate values, but even more so because coherent error messages are now provided when Py4J cannot cast the python type to the appropriate Java type. This patch also adds a `TypeConverters` class with factory methods for common type conversions. Most of the changes involve adding these factory type converters to existing params. The previous solution to this issue, `expectedType`, is deprecated and can be removed in 2.1.0 as discussed on the Jira. ## How was this patch tested? Unit tests were added in python/pyspark/ml/tests.py to test parameter type conversion. These tests check that values that should be convertible are converted correctly, and that the appropriate errors are thrown when invalid values are provided. Author: sethah <seth.hendrickson16@gmail.com> Closes #11663 from sethah/SPARK-13068-tc.
Diffstat (limited to 'python/pyspark/ml/regression.py')
-rw-r--r--python/pyspark/ml/regression.py25
1 files changed, 16 insertions, 9 deletions
diff --git a/python/pyspark/ml/regression.py b/python/pyspark/ml/regression.py
index 664a44bc47..898260879d 100644
--- a/python/pyspark/ml/regression.py
+++ b/python/pyspark/ml/regression.py
@@ -189,10 +189,11 @@ class IsotonicRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
isotonic = \
Param(Params._dummy(), "isotonic",
"whether the output sequence should be isotonic/increasing (true) or" +
- "antitonic/decreasing (false).")
+ "antitonic/decreasing (false).", typeConverter=TypeConverters.toBoolean)
featureIndex = \
Param(Params._dummy(), "featureIndex",
- "The index of the feature if featuresCol is a vector column, no effect otherwise.")
+ "The index of the feature if featuresCol is a vector column, no effect otherwise.",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
@@ -278,7 +279,8 @@ class TreeEnsembleParams(DecisionTreeParams):
"""
subsamplingRate = Param(Params._dummy(), "subsamplingRate", "Fraction of the training data " +
- "used for learning each decision tree, in range (0, 1].")
+ "used for learning each decision tree, in range (0, 1].",
+ typeConverter=TypeConverters.toFloat)
def __init__(self):
super(TreeEnsembleParams, self).__init__()
@@ -335,11 +337,13 @@ class RandomForestParams(TreeEnsembleParams):
"""
supportedFeatureSubsetStrategies = ["auto", "all", "onethird", "sqrt", "log2"]
- numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).")
+ numTrees = Param(Params._dummy(), "numTrees", "Number of trees to train (>= 1).",
+ typeConverter=TypeConverters.toInt)
featureSubsetStrategy = \
Param(Params._dummy(), "featureSubsetStrategy",
"The number of features to consider for splits at each tree node. Supported " +
- "options: " + ", ".join(supportedFeatureSubsetStrategies))
+ "options: " + ", ".join(supportedFeatureSubsetStrategies),
+ typeConverter=TypeConverters.toString)
def __init__(self):
super(RandomForestParams, self).__init__()
@@ -653,7 +657,8 @@ class GBTRegressor(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol,
lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
- "Supported options: " + ", ".join(GBTParams.supportedLossTypes))
+ "Supported options: " + ", ".join(GBTParams.supportedLossTypes),
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
@@ -767,14 +772,16 @@ class AFTSurvivalRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredi
censorCol = Param(Params._dummy(), "censorCol",
"censor column name. The value of this column could be 0 or 1. " +
"If the value is 1, it means the event has occurred i.e. " +
- "uncensored; otherwise censored.")
+ "uncensored; otherwise censored.", typeConverter=TypeConverters.toString)
quantileProbabilities = \
Param(Params._dummy(), "quantileProbabilities",
"quantile probabilities array. Values of the quantile probabilities array " +
- "should be in the range (0, 1) and the array should be non-empty.")
+ "should be in the range (0, 1) and the array should be non-empty.",
+ typeConverter=TypeConverters.toListFloat)
quantilesCol = Param(Params._dummy(), "quantilesCol",
"quantiles column name. This column will output quantiles of " +
- "corresponding quantileProbabilities if it is set.")
+ "corresponding quantileProbabilities if it is set.",
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",