aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/tests.py
diff options
context:
space:
mode:
authorXiangrui Meng <meng@databricks.com>2014-11-03 19:29:11 -0800
committerXiangrui Meng <meng@databricks.com>2014-11-03 19:29:11 -0800
commit04450d11548cfb25d4fb77d4a33e3a7cd4254183 (patch)
tree13b5c6fd1ac8c400cf59a51fe0b84b60c64d400f /python/pyspark/tests.py
parentc5912ecc7b392a13089ae735c07c2d7256de36c6 (diff)
downloadspark-04450d11548cfb25d4fb77d4a33e3a7cd4254183.tar.gz
spark-04450d11548cfb25d4fb77d4a33e3a7cd4254183.tar.bz2
spark-04450d11548cfb25d4fb77d4a33e3a7cd4254183.zip
[SPARK-4192][SQL] Internal API for Python UDT
Following #2919, this PR adds Python UDT (for internal use only) with tests under "pyspark.tests". Before `SQLContext.applySchema`, we check whether we need to convert user-type instances into SQL recognizable data. In the current implementation, a Python UDT must be paired with a Scala UDT for serialization on the JVM side. A following PR will add VectorUDT in MLlib for both Scala and Python. marmbrus jkbradley davies Author: Xiangrui Meng <meng@databricks.com> Closes #3068 from mengxr/SPARK-4192-sql and squashes the following commits: acff637 [Xiangrui Meng] merge master dba5ea7 [Xiangrui Meng] only use pyClass for Python UDT output sqlType as well 2c9d7e4 [Xiangrui Meng] move import to global setup; update needsConversion 7c4a6a9 [Xiangrui Meng] address comments 75223db [Xiangrui Meng] minor update f740379 [Xiangrui Meng] remove UDT from default imports e98d9d0 [Xiangrui Meng] fix py style 4e84fce [Xiangrui Meng] remove local hive tests and add more tests 39f19e0 [Xiangrui Meng] add tests b7f666d [Xiangrui Meng] add Python UDT
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r--python/pyspark/tests.py93
1 files changed, 92 insertions, 1 deletions
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 68fd756876..e947b09468 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -49,7 +49,8 @@ from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
+ UserDefinedType, DoubleType
from pyspark import shuffle
_have_scipy = False
@@ -694,8 +695,65 @@ class ProfilerTests(PySparkTestCase):
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
+class ExamplePointUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for ExamplePoint.
+ """
+
+ @classmethod
+ def sqlType(self):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return 'pyspark.tests'
+
+ @classmethod
+ def scalaUDT(cls):
+ return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+ """
+ An example class to demonstrate UDT in Scala, Java, and Python.
+ """
+
+ __UDT__ = ExamplePointUDT()
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __repr__(self):
+ return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+ def __str__(self):
+ return "(%s,%s)" % (self.x, self.y)
+
+ def __eq__(self, other):
+ return isinstance(other, ExamplePoint) and \
+ other.x == self.x and other.y == self.y
+
+
class SQLTests(ReusedPySparkTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name)
+
def setUp(self):
self.sqlCtx = SQLContext(self.sc)
@@ -824,6 +882,39 @@ class SQLTests(ReusedPySparkTestCase):
row = self.sqlCtx.sql("select l[0].a AS la from test").first()
self.assertEqual(1, row.asDict()["la"])
+ def test_infer_schema_with_udt(self):
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ srdd = self.sqlCtx.inferSchema(rdd)
+ schema = srdd.schema()
+ field = [f for f in schema.fields if f.name == "point"][0]
+ self.assertEqual(type(field.dataType), ExamplePointUDT)
+ srdd.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+ self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+ def test_apply_schema_with_udt(self):
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
+ row = (1.0, ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ srdd = self.sqlCtx.applySchema(rdd, schema)
+ point = srdd.first().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+ def test_parquet_with_udt(self):
+ from pyspark.tests import ExamplePoint
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ srdd0 = self.sqlCtx.inferSchema(rdd)
+ output_dir = os.path.join(self.tempdir.name, "labeled_point")
+ srdd0.saveAsParquetFile(output_dir)
+ srdd1 = self.sqlCtx.parquetFile(output_dir)
+ point = srdd1.first().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
class InputFormatTests(ReusedPySparkTestCase):