path: root/python/pyspark/ml/classification.py
author     sethah <seth.hendrickson16@gmail.com>        2016-03-23 11:20:44 -0700
committer  Joseph K. Bradley <joseph@databricks.com>    2016-03-23 11:20:44 -0700
commit     30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2 (patch)
tree       4d48b42ebe347fc40d5deeb3a77996db0c30eea1 /python/pyspark/ml/classification.py
parent     48ee16d8012602c75d50aa2a85e26b7de3c48944 (diff)
download   spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.gz
           spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.bz2
           spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.zip
[SPARK-13068][PYSPARK][ML] Type conversion for Pyspark params
## What changes were proposed in this pull request?

This patch adds type conversion functionality for parameters in PySpark. A `typeConverter` field is added to the constructor of the `Param` class. This argument is a function that converts values passed to the param to the appropriate type, when possible. This is beneficial so that params can fail at set time if they are given inappropriate values, but even more so because coherent error messages are now provided when Py4J cannot cast the Python type to the appropriate Java type. This patch also adds a `TypeConverters` class with factory methods for common type conversions. Most of the changes involve attaching these factory converters to existing params. The previous solution to this issue, `expectedType`, is deprecated and can be removed in 2.1.0, as discussed on the JIRA.

## How was this patch tested?

Unit tests were added in python/pyspark/ml/tests.py to test parameter type conversion. These tests check that values that should be convertible are converted correctly, and that the appropriate errors are raised when invalid values are provided.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #11663 from sethah/SPARK-13068-tc.
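As a rough sketch of how the new mechanism behaves (illustrative only, not part of the patch): the snippet below declares a `Param` with a converter and shows both the coercion and the fail-fast error. The `HasThreshold` holder class is hypothetical, and it uses the internal `Params._dummy()` and `_set()` helpers the same way PySpark's own param mixins and tests do.

```python
from pyspark.ml.param import Param, Params, TypeConverters


class HasThreshold(Params):
    """Hypothetical params holder, in the style of pyspark.ml.param.shared."""

    threshold = Param(Params._dummy(), "threshold",
                      "Threshold in binary classification prediction, in range [0, 1].",
                      typeConverter=TypeConverters.toFloat)

    def __init__(self):
        super(HasThreshold, self).__init__()


params = HasThreshold()
params._set(threshold=1)  # the int is coerced to float at set time
assert params.getOrDefault(params.threshold) == 1.0

try:
    params._set(threshold="high")  # not convertible to float
except TypeError as e:
    print(e)  # a coherent Python-side error, instead of a Py4J cast failure
```

Converters such as `TypeConverters.toFloat`, `TypeConverters.toString`, `TypeConverters.toInt`, and `TypeConverters.toListInt` are exactly the factory methods being attached to params throughout the diff below.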
Diffstat (limited to 'python/pyspark/ml/classification.py')
-rw-r--r--  python/pyspark/ml/classification.py  20
1 file changed, 13 insertions, 7 deletions
diff --git a/python/pyspark/ml/classification.py b/python/pyspark/ml/classification.py
index 8075108114..fdeccf822c 100644
--- a/python/pyspark/ml/classification.py
+++ b/python/pyspark/ml/classification.py
@@ -20,6 +20,7 @@ import warnings
from pyspark import since
from pyspark.ml.util import *
from pyspark.ml.wrapper import JavaEstimator, JavaModel
+from pyspark.ml.param import TypeConverters
from pyspark.ml.param.shared import *
from pyspark.ml.regression import (
RandomForestParams, TreeEnsembleParams, DecisionTreeModel, TreeEnsembleModels)
@@ -87,7 +88,8 @@ class LogisticRegression(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredicti
threshold = Param(Params._dummy(), "threshold",
"Threshold in binary classification prediction, in range [0, 1]." +
- " If threshold and thresholds are both set, they must match.")
+ " If threshold and thresholds are both set, they must match.",
+ typeConverter=TypeConverters.toFloat)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
@@ -243,7 +245,7 @@ class TreeClassifierParams(object):
impurity = Param(Params._dummy(), "impurity",
"Criterion used for information gain calculation (case-insensitive). " +
"Supported options: " +
- ", ".join(supportedImpurities))
+ ", ".join(supportedImpurities), typeConverter=TypeConverters.toString)
def __init__(self):
super(TreeClassifierParams, self).__init__()
@@ -534,7 +536,8 @@ class GBTClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol
lossType = Param(Params._dummy(), "lossType",
"Loss function which GBT tries to minimize (case-insensitive). " +
- "Supported options: " + ", ".join(GBTParams.supportedLossTypes))
+ "Supported options: " + ", ".join(GBTParams.supportedLossTypes),
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
@@ -652,9 +655,10 @@ class NaiveBayes(JavaEstimator, HasFeaturesCol, HasLabelCol, HasPredictionCol, H
"""
smoothing = Param(Params._dummy(), "smoothing", "The smoothing parameter, should be >= 0, " +
- "default is 1.0")
+ "default is 1.0", typeConverter=TypeConverters.toFloat)
modelType = Param(Params._dummy(), "modelType", "The model type which is a string " +
- "(case-sensitive). Supported options: multinomial (default) and bernoulli.")
+ "(case-sensitive). Supported options: multinomial (default) and bernoulli.",
+ typeConverter=TypeConverters.toString)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
@@ -782,11 +786,13 @@ class MultilayerPerceptronClassifier(JavaEstimator, HasFeaturesCol, HasLabelCol,
layers = Param(Params._dummy(), "layers", "Sizes of layers from input layer to output layer " +
"E.g., Array(780, 100, 10) means 780 inputs, one hidden layer with 100 " +
- "neurons and output layer of 10 neurons, default is [1, 1].")
+ "neurons and output layer of 10 neurons, default is [1, 1].",
+ typeConverter=TypeConverters.toListInt)
blockSize = Param(Params._dummy(), "blockSize", "Block size for stacking input data in " +
"matrices. Data is stacked within partitions. If block size is more than " +
"remaining data in a partition then it is adjusted to the size of this " +
- "data. Recommended size is between 10 and 1000, default is 128.")
+ "data. Recommended size is between 10 and 1000, default is 128.",
+ typeConverter=TypeConverters.toInt)
@keyword_only
def __init__(self, featuresCol="features", labelCol="label", predictionCol="prediction",
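The excerpt above ends mid-hunk, but the user-facing effect of these annotations is straightforward to demonstrate. The session below is an illustrative sketch, assuming a local Spark build with this patch applied; `NaiveBayes`, its `smoothing` param, and its `getSmoothing` getter come from the file being patched, while the `SparkContext` setup is just scaffolding.

```python
from pyspark import SparkContext
from pyspark.ml.classification import NaiveBayes

# NaiveBayes wraps a Java estimator, so a JVM-backed SparkContext must exist.
sc = SparkContext("local[1]", "typeconverter-demo")

nb = NaiveBayes(smoothing=1)                 # int coerced via TypeConverters.toFloat
assert isinstance(nb.getSmoothing(), float)  # stored as 1.0, ready for the Java side

try:
    NaiveBayes(smoothing="lots")             # rejected at set time, on the Python side
except TypeError as e:
    print(e)                                 # clear message, not a Py4J cast error

sc.stop()
```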