aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/ml/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/ml/tests.py')
-rwxr-xr-xpython/pyspark/ml/tests.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/python/pyspark/ml/tests.py b/python/pyspark/ml/tests.py
index 3524160557..f052f5bb77 100755
--- a/python/pyspark/ml/tests.py
+++ b/python/pyspark/ml/tests.py
@@ -1223,6 +1223,26 @@ class HashingTFTest(SparkSessionTestCase):
": expected " + str(expected[i]) + ", got " + str(features[i]))
+class GeneralizedLinearRegressionTest(SparkSessionTestCase):
+
+ def test_tweedie_distribution(self):
+
+ df = self.spark.createDataFrame(
+ [(1.0, Vectors.dense(0.0, 0.0)),
+ (1.0, Vectors.dense(1.0, 2.0)),
+ (2.0, Vectors.dense(0.0, 0.0)),
+ (2.0, Vectors.dense(1.0, 1.0)), ], ["label", "features"])
+
+ glr = GeneralizedLinearRegression(family="tweedie", variancePower=1.6)
+ model = glr.fit(df)
+ self.assertTrue(np.allclose(model.coefficients.toArray(), [-0.4645, 0.3402], atol=1E-4))
+ self.assertTrue(np.isclose(model.intercept, 0.7841, atol=1E-4))
+
+ model2 = glr.setLinkPower(-1.0).fit(df)
+ self.assertTrue(np.allclose(model2.coefficients.toArray(), [-0.6667, 0.5], atol=1E-4))
+ self.assertTrue(np.isclose(model2.intercept, 0.6667, atol=1E-4))
+
+
class ALSTest(SparkSessionTestCase):
def test_storage_levels(self):