aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/mllib/tests.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r--python/pyspark/mllib/tests.py39
1 files changed, 36 insertions, 3 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index d6fb87b378..9fa4d6f6a2 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -33,14 +33,14 @@ if sys.version_info[:2] <= (2, 6):
else:
import unittest
-from pyspark.serializers import PickleSerializer
-from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
+from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.serializers import PickleSerializer
+from pyspark.sql import SQLContext
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-
_have_scipy = False
try:
import scipy.sparse
@@ -221,6 +221,39 @@ class StatTests(PySparkTestCase):
self.assertEqual(10, summary.count())
+class VectorUDTTests(PySparkTestCase):
+
+ dv0 = DenseVector([])
+ dv1 = DenseVector([1.0, 2.0])
+ sv0 = SparseVector(2, [], [])
+ sv1 = SparseVector(2, [1], [2.0])
+ udt = VectorUDT()
+
+ def test_json_schema(self):
+ self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+ def test_serialization(self):
+ for v in [self.dv0, self.dv1, self.sv0, self.sv1]:
+ self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
+
+ def test_infer_schema(self):
+ sqlCtx = SQLContext(self.sc)
+ rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
+ srdd = sqlCtx.inferSchema(rdd)
+ schema = srdd.schema()
+ field = [f for f in schema.fields if f.name == "features"][0]
+ self.assertEqual(field.dataType, self.udt)
+ vectors = srdd.map(lambda p: p.features).collect()
+ self.assertEqual(len(vectors), 2)
+ for v in vectors:
+ if isinstance(v, SparseVector):
+ self.assertEqual(v, self.sv1)
+ elif isinstance(v, DenseVector):
+ self.assertEqual(v, self.dv1)
+ else:
+ raise ValueError("expecting a vector but got %r of type %r" % (v, type(v)))
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):