author: Holden Karau <holden@us.ibm.com> 2016-01-06 10:43:03 -0800
committer: Joseph K. Bradley <joseph@databricks.com> 2016-01-06 10:43:03 -0800
commit: 3b29004d2439c03a7d9bfdf7c2edd757d3d8c240
tree: 66fb557cad74f3e1507ecbcd1113bc913fac8fbc /python/pyspark/ml/tests.py
parent: 9061e777fdcd5767718808e325e8953d484aa761
[SPARK-7675][ML][PYSPARK] sparkml params type conversion
From JIRA:

Currently, PySpark wrappers for spark.ml Scala classes are brittle when accepting Param types. E.g., Normalizer's "p" param cannot be set to "2" (an integer); it must be set to "2.0" (a float). Fixing this is not trivial since there does not appear to be a natural place to insert the conversion before Python wrappers call Java's Params setter method.

A possible fix would be to add a method "_checkType" to PySpark's Param class which checks the type, prints an error if needed, and converts types when relevant (e.g., int to float, or scipy matrix to array). The Java wrapper method which copies params to Scala can call this method when available.

This fix instead checks the types at set time, since I think failing sooner is better, but I can switch it around to check at copy time if that would be better. So far this only converts int to float; other conversions (like scipy matrix to array) are left for the future.

Author: Holden Karau <holden@us.ibm.com>

Closes #9581 from holdenk/SPARK-7675-PySpark-sparkml-Params-type-conversion.
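To illustrate the "check at set time" idea described above, here is a minimal, self-contained sketch. It is not the code from this commit and does not use PySpark: the names `SketchParam`, `SketchParams`, and `convert_param_value` are hypothetical and exist only for this example. It assumes each param declares an expected type and that int to float is the only conversion performed.

```python
class SketchParam:
    """Hypothetical stand-in for a param that declares an expected type."""

    def __init__(self, name, expected_type=None):
        self.name = name
        self.expected_type = expected_type


def convert_param_value(param, value):
    """Return value converted to the param's expected type, or raise TypeError."""
    expected = param.expected_type
    # No declared type, no value, or an exact match: nothing to convert.
    if expected is None or value is None or type(value) is expected:
        return value
    # Only the lossless int -> float widening is handled in this sketch.
    # bool is excluded explicitly because it is a subclass of int.
    if expected is float and isinstance(value, int) and not isinstance(value, bool):
        return float(value)
    raise TypeError(
        "Param %r expects %s but got %r"
        % (param.name, expected.__name__, value)
    )


class SketchParams:
    """Hypothetical param holder that validates values when they are set."""

    def __init__(self):
        self._param_map = {}

    def set(self, param, value):
        # Converting here means a bad value fails immediately, at set time,
        # rather than later when params would be copied to the JVM side.
        self._param_map[param.name] = convert_param_value(param, value)
        return self

    def get(self, param):
        return self._param_map[param.name]
```

With this sketch, setting a float-typed param to the integer `0` stores `0.0`, while setting it to a string such as `"happy"` raises `TypeError` right away. That mirrors the behavior the tests in this diff exercise against `LogisticRegression`'s `elasticNetParam`.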
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r-- python/pyspark/ml/tests.py | 22
1 files changed, 22 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 7a16cf52cc..4eb17bfdcc 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -37,6 +37,7 @@ else:
 from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
 from pyspark.sql import DataFrame, SQLContext, Row
 from pyspark.sql.functions import rand
+from pyspark.ml.classification import LogisticRegression
 from pyspark.ml.evaluation import RegressionEvaluator
 from pyspark.ml.param import Param, Params
 from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
@@ -92,6 +93,27 @@ class MockModel(MockTransformer, Model, HasFake):
     pass
+class ParamTypeConversionTests(PySparkTestCase):
+    """
+    Test that param type conversion happens.
+    """
+
+    def test_int_to_float(self):
+        from pyspark.mllib.linalg import Vectors
+        df = self.sc.parallelize([
+            Row(label=1.0, weight=2.0, features=Vectors.dense(1.0))]).toDF()
+        lr = LogisticRegression(elasticNetParam=0)
+        lr.fit(df)
+        lr.setElasticNetParam(0)
+        lr.fit(df)
+
+    def test_invalid_to_float(self):
+        from pyspark.mllib.linalg import Vectors
+        self.assertRaises(Exception, lambda: LogisticRegression(elasticNetParam="happy"))
+        lr = LogisticRegression(elasticNetParam=0)
+        self.assertRaises(Exception, lambda: lr.setElasticNetParam("panda"))
+
+
 class PipelineTests(PySparkTestCase):
     def test_pipeline(self):