diff options
author | Xiangrui Meng <meng@databricks.com> | 2014-11-03 19:29:11 -0800 |
---|---|---|
committer | Xiangrui Meng <meng@databricks.com> | 2014-11-03 19:29:11 -0800 |
commit | 04450d11548cfb25d4fb77d4a33e3a7cd4254183 (patch) | |
tree | 13b5c6fd1ac8c400cf59a51fe0b84b60c64d400f /python/pyspark/tests.py | |
parent | c5912ecc7b392a13089ae735c07c2d7256de36c6 (diff) | |
download | spark-04450d11548cfb25d4fb77d4a33e3a7cd4254183.tar.gz spark-04450d11548cfb25d4fb77d4a33e3a7cd4254183.tar.bz2 spark-04450d11548cfb25d4fb77d4a33e3a7cd4254183.zip |
[SPARK-4192][SQL] Internal API for Python UDT
Following #2919, this PR adds Python UDT (for internal use only) with tests under "pyspark.tests". Before `SQLContext.applySchema`, we check whether we need to convert user-type instances into SQL recognizable data. In the current implementation, a Python UDT must be paired with a Scala UDT for serialization on the JVM side. A following PR will add VectorUDT in MLlib for both Scala and Python.
marmbrus jkbradley davies
Author: Xiangrui Meng <meng@databricks.com>
Closes #3068 from mengxr/SPARK-4192-sql and squashes the following commits:
acff637 [Xiangrui Meng] merge master
dba5ea7 [Xiangrui Meng] only use pyClass for Python UDT output sqlType as well
2c9d7e4 [Xiangrui Meng] move import to global setup; update needsConversion
7c4a6a9 [Xiangrui Meng] address comments
75223db [Xiangrui Meng] minor update
f740379 [Xiangrui Meng] remove UDT from default imports
e98d9d0 [Xiangrui Meng] fix py style
4e84fce [Xiangrui Meng] remove local hive tests and add more tests
39f19e0 [Xiangrui Meng] add tests
b7f666d [Xiangrui Meng] add Python UDT
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- | python/pyspark/tests.py | 93 |
1 files changed, 92 insertions, 1 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 68fd756876..e947b09468 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -49,7 +49,8 @@ from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType, Row, ArrayType +from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \ + UserDefinedType, DoubleType from pyspark import shuffle _have_scipy = False @@ -694,8 +695,65 @@ class ProfilerTests(PySparkTestCase): self.assertTrue("rdd_%d.pstats" % id in os.listdir(d)) +class ExamplePointUDT(UserDefinedType): + """ + User-defined type (UDT) for ExamplePoint. + """ + + @classmethod + def sqlType(self): + return ArrayType(DoubleType(), False) + + @classmethod + def module(cls): + return 'pyspark.tests' + + @classmethod + def scalaUDT(cls): + return 'org.apache.spark.sql.test.ExamplePointUDT' + + def serialize(self, obj): + return [obj.x, obj.y] + + def deserialize(self, datum): + return ExamplePoint(datum[0], datum[1]) + + +class ExamplePoint: + """ + An example class to demonstrate UDT in Scala, Java, and Python. + """ + + __UDT__ = ExamplePointUDT() + + def __init__(self, x, y): + self.x = x + self.y = y + + def __repr__(self): + return "ExamplePoint(%s,%s)" % (self.x, self.y) + + def __str__(self): + return "(%s,%s)" % (self.x, self.y) + + def __eq__(self, other): + return isinstance(other, ExamplePoint) and \ + other.x == self.x and other.y == self.y + + class SQLTests(ReusedPySparkTestCase): + @classmethod + def setUpClass(cls): + ReusedPySparkTestCase.setUpClass() + cls.tempdir = tempfile.NamedTemporaryFile(delete=False) + os.unlink(cls.tempdir.name) + + @classmethod + def tearDownClass(cls): + ReusedPySparkTestCase.tearDownClass() + shutil.rmtree(cls.tempdir.name) + def setUp(self): self.sqlCtx = SQLContext(self.sc) @@ -824,6 +882,39 @@ class SQLTests(ReusedPySparkTestCase): row = self.sqlCtx.sql("select l[0].a AS la from test").first() self.assertEqual(1, row.asDict()["la"]) + def test_infer_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd = self.sqlCtx.inferSchema(rdd) + schema = srdd.schema() + field = [f for f in schema.fields if f.name == "point"][0] + self.assertEqual(type(field.dataType), ExamplePointUDT) + srdd.registerTempTable("labeled_point") + point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point + self.assertEqual(point, ExamplePoint(1.0, 2.0)) + + def test_apply_schema_with_udt(self): + from pyspark.tests import ExamplePoint, ExamplePointUDT + row = (1.0, ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + schema = StructType([StructField("label", DoubleType(), False), + StructField("point", ExamplePointUDT(), False)]) + srdd = self.sqlCtx.applySchema(rdd, schema) + point = srdd.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + + def test_parquet_with_udt(self): + from pyspark.tests import ExamplePoint + row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) + rdd = self.sc.parallelize([row]) + srdd0 = self.sqlCtx.inferSchema(rdd) + output_dir = os.path.join(self.tempdir.name, "labeled_point") + srdd0.saveAsParquetFile(output_dir) + srdd1 = self.sqlCtx.parquetFile(output_dir) + point = srdd1.first().point + self.assertEquals(point, ExamplePoint(1.0, 2.0)) + class InputFormatTests(ReusedPySparkTestCase): |