aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
author jbencook <jbenjamincook@gmail.com> 2014-12-16 11:37:23 -0800
committer Xiangrui Meng <meng@databricks.com> 2014-12-16 11:37:23 -0800
commit cb484474934d664000df3d63a326bcd6b12f2f09 (patch)
tree 4d8f3b76fd8e0216aa4f5df8d58f378ebddf3485 /python/pyspark
parent ed362008f0a317729f8404e86e57d8a6ceb60f21 (diff)
download spark-cb484474934d664000df3d63a326bcd6b12f2f09.tar.gz
spark-cb484474934d664000df3d63a326bcd6b12f2f09.tar.bz2
spark-cb484474934d664000df3d63a326bcd6b12f2f09.zip
[SPARK-4855][mllib] testing the Chi-squared hypothesis test
This PR tests the PySpark chi-squared hypothesis test added in commit c8abddc5164d8cf11cdede6ab3d5d1ea08028708 and moves some of the error messaging into Python. It is a port of the Scala tests here: [HypothesisTestSuite.scala](https://github.com/apache/spark/blob/master/mllib/src/test/scala/org/apache/spark/mllib/stat/HypothesisTestSuite.scala). Hopefully, SPARK-2980 can be closed. Author: jbencook <jbenjamincook@gmail.com> Closes #3679 from jbencook/master and squashes the following commits: 44078e0 [jbencook] checking that bad input throws the correct exceptions; f12ee10 [jbencook] removing checks for ValueError since input tests are on the Scala side; 7536cf1 [jbencook] removing python checks for invalid input; a17ee84 [jbencook] [SPARK-2980][mllib] adding unit tests for the pyspark chi-squared test; 3aeb0d9 [jbencook] [SPARK-2980][mllib] bringing Chi-squared error messages to the python side
Diffstat (limited to 'python/pyspark')
-rw-r--r-- python/pyspark/mllib/tests.py 100
1 files changed, 99 insertions, 1 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 8332f8e061..5034f229e8 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -23,6 +23,7 @@ import sys
import array as pyarray
from numpy import array, array_equal
+from py4j.protocol import Py4JJavaError
if sys.version_info[:2] <= (2, 6):
try:
@@ -34,7 +35,7 @@ else:
import unittest
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
- DenseMatrix
+ DenseMatrix, Vectors, Matrices
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
@@ -400,6 +401,103 @@ class SciPyTests(PySparkTestCase):
self.assertTrue(dt_model.predict(features[3]) > 0)
+class ChiSqTestTests(PySparkTestCase):  # Exercises Statistics.chiSqTest on vector, matrix, and RDD-of-LabeledPoint inputs.
+ def test_goodness_of_fit(self):  # Vector inputs: reference statistics/p-values plus invalid-input error cases.
+ from numpy import inf  # compared against the statistic in the zero-expected case below
+
+ observed = Vectors.dense([4, 6, 5])
+ pearson = Statistics.chiSqTest(observed)  # no expected vector is passed; the R reference below uses uniform p
+
+ # Validated against the R command `chisq.test(c(4, 6, 5), p=c(1/3, 1/3, 1/3))`
+ self.assertEqual(pearson.statistic, 0.4)  # NOTE(review): exact float equality; assertAlmostEqual would be more tolerant
+ self.assertEqual(pearson.degreesOfFreedom, 2)
+ self.assertAlmostEqual(pearson.pValue, 0.8187, 4)
+
+ # Different expected and observed sum
+ observed1 = Vectors.dense([21, 38, 43, 80])
+ expected1 = Vectors.dense([3, 5, 7, 20])  # sums to 35, matching the /35-scaled proportions in the R command below
+ pearson1 = Statistics.chiSqTest(observed1, expected1)
+
+ # Results validated against the R command
+ # `chisq.test(c(21, 38, 43, 80), p=c(3/35, 1/7, 1/5, 4/7))`
+ self.assertAlmostEqual(pearson1.statistic, 14.1429, 4)
+ self.assertEqual(pearson1.degreesOfFreedom, 3)
+ self.assertAlmostEqual(pearson1.pValue, 0.002717, 4)
+
+ # Vectors with different sizes
+ observed3 = Vectors.dense([1.0, 2.0, 3.0])
+ expected3 = Vectors.dense([1.0, 2.0, 3.0, 4.0])
+ self.assertRaises(ValueError, Statistics.chiSqTest, observed3, expected3)  # size mismatch (3 vs 4) must raise ValueError
+
+ # Negative counts in observed
+ neg_obs = Vectors.dense([1.0, 2.0, 3.0, -4.0])
+ self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_obs, expected1)  # lengths match (4 vs 4); the negative count must surface as Py4JJavaError
+
+ # Count = 0.0 in expected but not observed
+ zero_expected = Vectors.dense([1.0, 0.0, 3.0])
+ pearson_inf = Statistics.chiSqTest(observed, zero_expected)  # reuses `observed` [4, 6, 5]: index 1 has expected 0.0 but observed 6
+ self.assertEqual(pearson_inf.statistic, inf)
+ self.assertEqual(pearson_inf.degreesOfFreedom, 2)
+ self.assertEqual(pearson_inf.pValue, 0.0)
+
+ # 0.0 in expected and observed simultaneously
+ zero_observed = Vectors.dense([2.0, 0.0, 1.0])
+ self.assertRaises(Py4JJavaError, Statistics.chiSqTest, zero_observed, zero_expected)  # index 1 is 0.0 in both vectors
+
+ def test_matrix_independence(self):  # Matrix input: reference result for a 3x4 table plus invalid-table error cases.
+ data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0]  # each consecutive triple is one column of the 3x4 table (cf. the rbind rows below)
+ chi = Statistics.chiSqTest(Matrices.dense(3, 4, data))
+
+ # Results validated against R command
+ # `chisq.test(rbind(c(40, 56, 31, 30),c(24, 32, 10, 15), c(29, 42, 0, 12)))`
+ self.assertAlmostEqual(chi.statistic, 21.9958, 4)
+ self.assertEqual(chi.degreesOfFreedom, 6)
+ self.assertAlmostEqual(chi.pValue, 0.001213, 4)
+
+ # Negative counts
+ neg_counts = Matrices.dense(2, 2, [4.0, 5.0, 3.0, -3.0])
+ self.assertRaises(Py4JJavaError, Statistics.chiSqTest, neg_counts)
+
+ # Row sum = 0.0
+ row_zero = Matrices.dense(2, 2, [0.0, 1.0, 0.0, 2.0])  # assuming the same column-by-column layout, the first row is [0.0, 0.0]
+ self.assertRaises(Py4JJavaError, Statistics.chiSqTest, row_zero)
+
+ # Column sum = 0.0
+ col_zero = Matrices.dense(2, 2, [0.0, 0.0, 2.0, 2.0])  # assuming the same column-by-column layout, the first column is [0.0, 0.0]
+ self.assertRaises(Py4JJavaError, Statistics.chiSqTest, col_zero)
+
+ def test_chi_sq_pearson(self):  # RDD-of-LabeledPoint input: results are indexed per feature.
+ data = [  # six points, two features each, labels 0.0 / 1.0
+ LabeledPoint(0.0, Vectors.dense([0.5, 10.0])),
+ LabeledPoint(0.0, Vectors.dense([1.5, 20.0])),
+ LabeledPoint(1.0, Vectors.dense([1.5, 30.0])),
+ LabeledPoint(0.0, Vectors.dense([3.5, 30.0])),
+ LabeledPoint(0.0, Vectors.dense([3.5, 40.0])),
+ LabeledPoint(1.0, Vectors.dense([3.5, 40.0]))
+ ]
+
+ for numParts in [2, 4, 6, 8]:  # the same expected values are asserted for every partition count
+ chi = Statistics.chiSqTest(self.sc.parallelize(data, numParts))
+ feature1 = chi[0]
+ self.assertEqual(feature1.statistic, 0.75)  # NOTE(review): exact float equality; assertAlmostEqual would be more tolerant
+ self.assertEqual(feature1.degreesOfFreedom, 2)
+ self.assertAlmostEqual(feature1.pValue, 0.6873, 4)
+
+ feature2 = chi[1]
+ self.assertEqual(feature2.statistic, 1.5)  # NOTE(review): exact float equality, as above
+ self.assertEqual(feature2.degreesOfFreedom, 3)
+ self.assertAlmostEqual(feature2.pValue, 0.6823, 4)
+
+ def test_right_number_of_results(self):  # Sparse 1001-column input must yield 1001 results.
+ num_cols = 1001
+ sparse_data = [  # two points, each with a single non-zero entry (indices 100 and 200)
+ LabeledPoint(0.0, Vectors.sparse(num_cols, [(100, 2.0)])),
+ LabeledPoint(0.1, Vectors.sparse(num_cols, [(200, 1.0)]))  # NOTE(review): label is 0.1, not 1.0 -- presumably intentional; confirm against the Scala suite
+ ]
+ chi = Statistics.chiSqTest(self.sc.parallelize(sparse_data))
+ self.assertEqual(len(chi), num_cols)
+ self.assertIsNotNone(chi[1000])  # index 1000 is the last of the 1001 columns
+
if __name__ == "__main__":
if not _have_scipy:
print "NOTE: Skipping SciPy tests as it does not seem to be installed"