diff options
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- | python/pyspark/mllib/tests.py | 19 |
1 files changed, 19 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index 463faf7b6f..d6fb87b378 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -36,6 +36,8 @@ else: from pyspark.serializers import PickleSerializer from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector from pyspark.mllib.regression import LabeledPoint +from pyspark.mllib.random import RandomRDDs +from pyspark.mllib.stat import Statistics from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase @@ -202,6 +204,23 @@ class ListTests(PySparkTestCase): self.assertTrue(dt_model.predict(features[3]) > 0) +class StatTests(PySparkTestCase): + # SPARK-4023 + def test_col_with_different_rdds(self): + # numpy + data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10) + summary = Statistics.colStats(data) + self.assertEqual(1000, summary.count()) + # array + data = self.sc.parallelize([range(10)] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + # array + data = self.sc.parallelize([pyarray.array("d", range(10))] * 10) + summary = Statistics.colStats(data) + self.assertEqual(10, summary.count()) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): |