aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2014-10-21 09:29:45 -0700
committerXiangrui Meng <meng@databricks.com>2014-10-21 09:29:45 -0700
commit85708168341a9406c451df20af3374c0850ce166 (patch)
tree186ef9febf3e61c1c9a6825dc0b6378f203d7e8f /python/pyspark/mllib/tests.py
parent5a8f64f33632fbf89d16cade2e0e66c5ed60760b (diff)
downloadspark-85708168341a9406c451df20af3374c0850ce166.tar.gz
spark-85708168341a9406c451df20af3374c0850ce166.tar.bz2
spark-85708168341a9406c451df20af3374c0850ce166.zip
[SPARK-4023] [MLlib] [PySpark] convert rdd into RDD of Vector
Convert the input rdd to RDD of Vector.

cc mengxr

Author: Davies Liu <davies@databricks.com>

Closes #2870 from davies/fix4023 and squashes the following commits:

1eac767 [Davies Liu] address comments
0871576 [Davies Liu] convert rdd into RDD of Vector
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--  python/pyspark/mllib/tests.py | 19
1 file changed, 19 insertions, 0 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index 463faf7b6f..d6fb87b378 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -36,6 +36,8 @@ else:
from pyspark.serializers import PickleSerializer
from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
+from pyspark.mllib.random import RandomRDDs
+from pyspark.mllib.stat import Statistics
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
@@ -202,6 +204,23 @@ class ListTests(PySparkTestCase):
self.assertTrue(dt_model.predict(features[3]) > 0)
+class StatTests(PySparkTestCase):
+ # SPARK-4023
+ def test_col_with_different_rdds(self):
+ # numpy
+ data = RandomRDDs.normalVectorRDD(self.sc, 1000, 10, 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(1000, summary.count())
+ # array
+ data = self.sc.parallelize([range(10)] * 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(10, summary.count())
+ # array
+ data = self.sc.parallelize([pyarray.array("d", range(10))] * 10)
+ summary = Statistics.colStats(data)
+ self.assertEqual(10, summary.count())
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):