author     Holden Karau <holden@us.ibm.com>  2016-01-06 10:43:03 -0800
committer  Joseph K. Bradley <joseph@databricks.com>  2016-01-06 10:43:03 -0800
commit     3b29004d2439c03a7d9bfdf7c2edd757d3d8c240 (patch)
tree       66fb557cad74f3e1507ecbcd1113bc913fac8fbc /python/pyspark/ml/tests.py
parent     9061e777fdcd5767718808e325e8953d484aa761 (diff)
[SPARK-7675][ML][PYSPARK] sparkml params type conversion
From JIRA:
Currently, PySpark wrappers for spark.ml Scala classes are brittle when accepting Param types. E.g., Normalizer's "p" param cannot be set to "2" (an integer); it must be set to "2.0" (a float). Fixing this is not trivial since there does not appear to be a natural place to insert the conversion before Python wrappers call Java's Params setter method.
A possible fix would be to add a method "_checkType" to PySpark's Param class which checks the type, raises an error if needed, and converts types when relevant (e.g., int to float, or scipy matrix to array). The Java wrapper method which copies params to Scala could call this method when available.
This fix instead checks the types at set time, since failing sooner is better, but it could be switched around to check at copy time if that would be better. So far this only converts int to float; other conversions (like scipy matrix to array) are left for the future.
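The set-time conversion described above can be sketched in plain Python as follows. This is an illustrative sketch only: `to_float`, `FloatParamHolder`, and the method names here are hypothetical stand-ins, not the actual PySpark API.

```python
def to_float(value):
    """Coerce a numeric param value to float at set time.

    Mirrors the int-to-float conversion described above; raises
    TypeError for values that cannot be safely converted.
    """
    if isinstance(value, bool):
        # bool is a subclass of int in Python, but passing True/False
        # for a float param is almost certainly a caller mistake.
        raise TypeError("Could not convert %r to float" % (value,))
    if isinstance(value, (int, float)):
        return float(value)
    raise TypeError("Could not convert %r to float" % (value,))


class FloatParamHolder(object):
    """Minimal stand-in for a Params class with one float-typed param."""

    def __init__(self, elastic_net_param=0.0):
        self.elastic_net_param = to_float(elastic_net_param)

    def set_elastic_net_param(self, value):
        # Converting (and failing) at set time surfaces type errors
        # sooner than deferring the check to when params are copied
        # to the JVM, which is the trade-off the commit message argues.
        self.elastic_net_param = to_float(value)
        return self
```

With this shape, `FloatParamHolder(elastic_net_param=0)` silently stores `0.0`, while `set_elastic_net_param("panda")` raises immediately, matching the behavior the tests below exercise against `LogisticRegression`.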
Author: Holden Karau <holden@us.ibm.com>
Closes #9581 from holdenk/SPARK-7675-PySpark-sparkml-Params-type-conversion.
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--  python/pyspark/ml/tests.py  22
1 file changed, 22 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 7a16cf52cc..4eb17bfdcc 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -37,6 +37,7 @@ else:
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
+from pyspark.ml.classification import LogisticRegression
 from pyspark.ml.evaluation import RegressionEvaluator
 from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
@@ -92,6 +93,27 @@ class MockModel(MockTransformer, Model, HasFake):
     pass
 
 
+class ParamTypeConversionTests(PySparkTestCase):
+    """
+    Test that param type conversion happens.
+    """
+
+    def test_int_to_float(self):
+        from pyspark.mllib.linalg import Vectors
+        df = self.sc.parallelize([
+            Row(label=1.0, weight=2.0, features=Vectors.dense(1.0))]).toDF()
+        lr = LogisticRegression(elasticNetParam=0)
+        lr.fit(df)
+        lr.setElasticNetParam(0)
+        lr.fit(df)
+
+    def test_invalid_to_float(self):
+        from pyspark.mllib.linalg import Vectors
+        self.assertRaises(Exception, lambda: LogisticRegression(elasticNetParam="happy"))
+        lr = LogisticRegression(elasticNetParam=0)
+        self.assertRaises(Exception, lambda: lr.setElasticNetParam("panda"))
+
+
 class PipelineTests(PySparkTestCase):
 
     def test_pipeline(self):