From 30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2 Mon Sep 17 00:00:00 2001 From: sethah Date: Wed, 23 Mar 2016 11:20:44 -0700 Subject: [SPARK-13068][PYSPARK][ML] Type conversion for Pyspark params ## What changes were proposed in this pull request? This patch adds type conversion functionality for parameters in Pyspark. A `typeConverter` field is added to the constructor of `Param` class. This argument is a function which converts values passed to this param to the appropriate type if possible. This is beneficial so that the params can fail at set time if they are given inappropriate values, but even more so because coherent error messages are now provided when Py4J cannot cast the python type to the appropriate Java type. This patch also adds a `TypeConverters` class with factory methods for common type conversions. Most of the changes involve adding these factory type converters to existing params. The previous solution to this issue, `expectedType`, is deprecated and can be removed in 2.1.0 as discussed on the Jira. ## How was this patch tested? Unit tests were added in python/pyspark/ml/tests.py to test parameter type conversion. These tests check that values that should be convertible are converted correctly, and that the appropriate errors are thrown when invalid values are provided. Author: sethah Closes #11663 from sethah/SPARK-13068-tc. --- python/pyspark/ml/tuning.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) (limited to 'python/pyspark/ml/tuning.py') diff --git a/python/pyspark/ml/tuning.py b/python/pyspark/ml/tuning.py index 77af0094df..a528d22e18 100644 --- a/python/pyspark/ml/tuning.py +++ b/python/pyspark/ml/tuning.py @@ -20,7 +20,7 @@ import numpy as np from pyspark import since from pyspark.ml import Estimator, Model -from pyspark.ml.param import Params, Param +from pyspark.ml.param import Params, Param, TypeConverters from pyspark.ml.param.shared import HasSeed from pyspark.ml.util import keyword_only from pyspark.sql.functions import rand @@ -121,7 +121,8 @@ class CrossValidator(Estimator, HasSeed): evaluator = Param( Params._dummy(), "evaluator", "evaluator used to select hyper-parameters that maximize the cross-validated metric") - numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation") + numFolds = Param(Params._dummy(), "numFolds", "number of folds for cross validation", + typeConverter=TypeConverters.toInt) @keyword_only def __init__(self, estimator=None, estimatorParamMaps=None, evaluator=None, numFolds=3, -- cgit v1.2.3