author      sethah <seth.hendrickson16@gmail.com>        2016-03-23 11:20:44 -0700
committer   Joseph K. Bradley <joseph@databricks.com>    2016-03-23 11:20:44 -0700
commit      30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2 (patch)
tree        4d48b42ebe347fc40d5deeb3a77996db0c30eea1 /python/pyspark/ml/tests.py
parent      48ee16d8012602c75d50aa2a85e26b7de3c48944 (diff)
download    spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.gz
            spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.tar.bz2
            spark-30bdb5cbd9aec191cf15cdc83c3fee375c04c2b2.zip
[SPARK-13068][PYSPARK][ML] Type conversion for Pyspark params
## What changes were proposed in this pull request?

This patch adds type conversion functionality for parameters in PySpark. A `typeConverter` field is added to the constructor of the `Param` class. This argument is a function which converts values passed to this param to the appropriate type if possible. This is beneficial so that the params can fail at set time if they are given inappropriate values, but even more so because coherent error messages are now provided when Py4J cannot cast the Python type to the appropriate Java type.

This patch also adds a `TypeConverters` class with factory methods for common type conversions. Most of the changes involve adding these factory type converters to existing params. The previous solution to this issue, `expectedType`, is deprecated and can be removed in 2.1.0 as discussed on the JIRA.

## How was this patch tested?

Unit tests were added in python/pyspark/ml/tests.py to test parameter type conversion. These tests check that values that should be convertible are converted correctly, and that the appropriate errors are thrown when invalid values are provided.

Author: sethah <seth.hendrickson16@gmail.com>

Closes #11663 from sethah/SPARK-13068-tc.
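As a rough sketch of the mechanism described above (not part of this patch's diff: the `threshold` param and its doc string are made up for illustration, and converter names such as `toInt`, `toFloat`, and `toListFloat` are inferred from the tests added below), a param declared with a `typeConverter` coerces convertible values eagerly and rejects everything else with a `TypeError`:

```python
from pyspark.ml.param import Param, Params, TypeConverters

# Hypothetical param declared with a type converter: the converter runs every
# time the param is set, so convertible values are coerced up front and
# unconvertible ones fail immediately with a TypeError instead of a cryptic
# Py4J cast error later on.
threshold = Param(Params._dummy(), "threshold", "decision threshold in [0, 1]",
                  typeConverter=TypeConverters.toFloat)

# The factory converters can also be called directly:
print(TypeConverters.toInt(5.0))           # 5 (an int)
print(TypeConverters.toFloat(1))           # 1.0 (a float)
print(TypeConverters.toListFloat([1, 4]))  # [1.0, 4.0]
try:
    TypeConverters.toInt("notAnInt")
except TypeError as e:
    print(e)                               # coherent error raised at set time
```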
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rw-r--r--    python/pyspark/ml/tests.py    83
1 file changed, 66 insertions, 17 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 211248e8b2..2fa5da7738 100644
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -18,8 +18,11 @@
"""
Unit tests for Spark ML Python APIs.
"""
-
+import array
import sys
+if sys.version > '3':
+    xrange = range
+
try:
    import xmlrunner
except ImportError:
@@ -36,19 +39,20 @@ else:
from shutil import rmtree
import tempfile
+import numpy as np
from pyspark.ml import Estimator, Model, Pipeline, PipelineModel, Transformer
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.clustering import KMeans
from pyspark.ml.evaluation import RegressionEvaluator
from pyspark.ml.feature import *
-from pyspark.ml.param import Param, Params
+from pyspark.ml.param import Param, Params, TypeConverters
from pyspark.ml.param.shared import HasMaxIter, HasInputCol, HasSeed
from pyspark.ml.regression import LinearRegression
from pyspark.ml.tuning import *
from pyspark.ml.util import keyword_only
from pyspark.ml.wrapper import JavaWrapper
-from pyspark.mllib.linalg import DenseVector
+from pyspark.mllib.linalg import DenseVector, SparseVector
from pyspark.sql import DataFrame, SQLContext, Row
from pyspark.sql.functions import rand
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -104,20 +108,65 @@ class ParamTypeConversionTests(PySparkTestCase):
    Test that param type conversion happens.
    """
-    def test_int_to_float(self):
-        from pyspark.mllib.linalg import Vectors
-        df = self.sc.parallelize([
-            Row(label=1.0, weight=2.0, features=Vectors.dense(1.0))]).toDF()
-        lr = LogisticRegression(elasticNetParam=0)
-        lr.fit(df)
-        lr.setElasticNetParam(0)
-        lr.fit(df)
-
-    def test_invalid_to_float(self):
-        from pyspark.mllib.linalg import Vectors
-        self.assertRaises(Exception, lambda: LogisticRegression(elasticNetParam="happy"))
-        lr = LogisticRegression(elasticNetParam=0)
-        self.assertRaises(Exception, lambda: lr.setElasticNetParam("panda"))
+    def test_int(self):
+        lr = LogisticRegression(maxIter=5.0)
+        self.assertEqual(lr.getMaxIter(), 5)
+        self.assertTrue(type(lr.getMaxIter()) == int)
+        self.assertRaises(TypeError, lambda: LogisticRegression(maxIter="notAnInt"))
+        self.assertRaises(TypeError, lambda: LogisticRegression(maxIter=5.1))
+
+    def test_float(self):
+        lr = LogisticRegression(tol=1)
+        self.assertEqual(lr.getTol(), 1.0)
+        self.assertTrue(type(lr.getTol()) == float)
+        self.assertRaises(TypeError, lambda: LogisticRegression(tol="notAFloat"))
+
+    def test_vector(self):
+        ewp = ElementwiseProduct(scalingVec=[1, 3])
+        self.assertEqual(ewp.getScalingVec(), DenseVector([1.0, 3.0]))
+        ewp = ElementwiseProduct(scalingVec=np.array([1.2, 3.4]))
+        self.assertEqual(ewp.getScalingVec(), DenseVector([1.2, 3.4]))
+        self.assertRaises(TypeError, lambda: ElementwiseProduct(scalingVec=["a", "b"]))
+
+    def test_list(self):
+        l = [0, 1]
+        for lst_like in [l, np.array(l), DenseVector(l), SparseVector(len(l), range(len(l)), l),
+                         array.array('l', l), xrange(2), tuple(l)]:
+            converted = TypeConverters.toList(lst_like)
+            self.assertEqual(type(converted), list)
+            self.assertListEqual(converted, l)
+
+    def test_list_int(self):
+        for indices in [[1.0, 2.0], np.array([1.0, 2.0]), DenseVector([1.0, 2.0]),
+                        SparseVector(2, {0: 1.0, 1: 2.0}), xrange(1, 3), (1.0, 2.0),
+                        array.array('d', [1.0, 2.0])]:
+            vs = VectorSlicer(indices=indices)
+            self.assertListEqual(vs.getIndices(), [1, 2])
+            self.assertTrue(all([type(v) == int for v in vs.getIndices()]))
+        self.assertRaises(TypeError, lambda: VectorSlicer(indices=["a", "b"]))
+
+    def test_list_float(self):
+        b = Bucketizer(splits=[1, 4])
+        self.assertEqual(b.getSplits(), [1.0, 4.0])
+        self.assertTrue(all([type(v) == float for v in b.getSplits()]))
+        self.assertRaises(TypeError, lambda: Bucketizer(splits=["a", 1.0]))
+
+    def test_list_string(self):
+        for labels in [np.array(['a', u'b']), ['a', u'b'], np.array(['a', 'b'])]:
+            idx_to_string = IndexToString(labels=labels)
+            self.assertListEqual(idx_to_string.getLabels(), ['a', 'b'])
+        self.assertRaises(TypeError, lambda: IndexToString(labels=['a', 2]))
+
+    def test_string(self):
+        lr = LogisticRegression()
+        for col in ['features', u'features', np.str_('features')]:
+            lr.setFeaturesCol(col)
+            self.assertEqual(lr.getFeaturesCol(), 'features')
+        self.assertRaises(TypeError, lambda: LogisticRegression(featuresCol=2.3))
+
+    def test_bool(self):
+        self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept=1))
+        self.assertRaises(TypeError, lambda: LogisticRegression(fitIntercept="false"))
class PipelineTests(PySparkTestCase):