diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-11-03 22:29:48 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-03 22:29:48 -0800 |
commit | 1a9c6cddadebdc53d083ac3e0da276ce979b5d1f (patch) | |
tree | b485818ba52a9287ae7124e57ef55f1d974f3a1f /python/pyspark/mllib/tests.py | |
parent | 04450d11548cfb25d4fb77d4a33e3a7cd4254183 (diff) | |
download | spark-1a9c6cddadebdc53d083ac3e0da276ce979b5d1f.tar.gz spark-1a9c6cddadebdc53d083ac3e0da276ce979b5d1f.tar.bz2 spark-1a9c6cddadebdc53d083ac3e0da276ce979b5d1f.zip |
[SPARK-3573][MLLIB] Make MLlib's Vector compatible with SQL's SchemaRDD
Register MLlib's Vector as a SQL user-defined type (UDT) in both Scala and Python. With this PR, we can easily map a RDD[LabeledPoint] to a SchemaRDD, and then select columns or save to a Parquet file. Examples in Scala/Python are attached. The Scala code was copied from jkbradley.
~~This PR contains the changes from #3068 . I will rebase after #3068 is merged.~~
marmbrus jkbradley
Author: Xiangrui Meng <meng@databricks.com>
Closes #3070 from mengxr/SPARK-3573 and squashes the following commits:
3a0b6e5 [Xiangrui Meng] organize imports
236f0a0 [Xiangrui Meng] register vector as UDT and provide dataset examples
Diffstat (limited to 'python/pyspark/mllib/tests.py')
-rw-r--r-- | python/pyspark/mllib/tests.py | 39 |
1 files changed, 36 insertions, 3 deletions
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py index d6fb87b378..9fa4d6f6a2 100644 --- a/python/pyspark/mllib/tests.py +++ b/python/pyspark/mllib/tests.py @@ -33,14 +33,14 @@ if sys.version_info[:2] <= (2, 6): else: import unittest -from pyspark.serializers import PickleSerializer -from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector +from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector from pyspark.mllib.regression import LabeledPoint from pyspark.mllib.random import RandomRDDs from pyspark.mllib.stat import Statistics +from pyspark.serializers import PickleSerializer +from pyspark.sql import SQLContext from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase - _have_scipy = False try: import scipy.sparse @@ -221,6 +221,39 @@ class StatTests(PySparkTestCase): self.assertEqual(10, summary.count()) +class VectorUDTTests(PySparkTestCase): + + dv0 = DenseVector([]) + dv1 = DenseVector([1.0, 2.0]) + sv0 = SparseVector(2, [], []) + sv1 = SparseVector(2, [1], [2.0]) + udt = VectorUDT() + + def test_json_schema(self): + self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt) + + def test_serialization(self): + for v in [self.dv0, self.dv1, self.sv0, self.sv1]: + self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v))) + + def test_infer_schema(self): + sqlCtx = SQLContext(self.sc) + rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)]) + srdd = sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "features"][0] + self.assertEqual(field.dataType, self.udt) + vectors = srdd.map(lambda p: p.features).collect() + self.assertEqual(len(vectors), 2) + for v in vectors: + if isinstance(v, SparseVector): + self.assertEqual(v, self.sv1) + elif isinstance(v, DenseVector): + self.assertEqual(v, self.dv1) + else: + raise ValueError("expecting a vector but got %r of type %r" % (v, type(v))) + + @unittest.skipIf(not _have_scipy, "SciPy not installed") class SciPyTests(PySparkTestCase): |