diff options
Diffstat (limited to 'python/pyspark/tests.py')
-rw-r--r-- | python/pyspark/tests.py | 93 |
1 file changed, 92 insertions, 1 deletion
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 68fd756876..e947b09468 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -49,7 +49,8 @@ from pyspark.files import SparkFiles
 from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
     CloudPickleSerializer
 from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
+    UserDefinedType, DoubleType
 from pyspark import shuffle
 
 _have_scipy = False
@@ -694,8 +695,65 @@ class ProfilerTests(PySparkTestCase):
         self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
 
 
+class ExamplePointUDT(UserDefinedType):
+    """
+    User-defined type (UDT) for ExamplePoint.
+    """
+
+    @classmethod
+    def sqlType(self):
+        return ArrayType(DoubleType(), False)
+
+    @classmethod
+    def module(cls):
+        return 'pyspark.tests'
+
+    @classmethod
+    def scalaUDT(cls):
+        return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+    def serialize(self, obj):
+        return [obj.x, obj.y]
+
+    def deserialize(self, datum):
+        return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+    """
+    An example class to demonstrate UDT in Scala, Java, and Python.
+    """
+
+    __UDT__ = ExamplePointUDT()
+
+    def __init__(self, x, y):
+        self.x = x
+        self.y = y
+
+    def __repr__(self):
+        return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+    def __str__(self):
+        return "(%s,%s)" % (self.x, self.y)
+
+    def __eq__(self, other):
+        return isinstance(other, ExamplePoint) and \
+            other.x == self.x and other.y == self.y
+
+
 class SQLTests(ReusedPySparkTestCase):
 
+    @classmethod
+    def setUpClass(cls):
+        ReusedPySparkTestCase.setUpClass()
+        cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+        os.unlink(cls.tempdir.name)
+
+    @classmethod
+    def tearDownClass(cls):
+        ReusedPySparkTestCase.tearDownClass()
+        shutil.rmtree(cls.tempdir.name)
+
     def setUp(self):
         self.sqlCtx = SQLContext(self.sc)
 
@@ -824,6 +882,39 @@ class SQLTests(ReusedPySparkTestCase):
         row = self.sqlCtx.sql("select l[0].a AS la from test").first()
         self.assertEqual(1, row.asDict()["la"])
 
+    def test_infer_schema_with_udt(self):
+        from pyspark.tests import ExamplePoint, ExamplePointUDT
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        srdd = self.sqlCtx.inferSchema(rdd)
+        schema = srdd.schema()
+        field = [f for f in schema.fields if f.name == "point"][0]
+        self.assertEqual(type(field.dataType), ExamplePointUDT)
+        srdd.registerTempTable("labeled_point")
+        point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+        self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+    def test_apply_schema_with_udt(self):
+        from pyspark.tests import ExamplePoint, ExamplePointUDT
+        row = (1.0, ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        schema = StructType([StructField("label", DoubleType(), False),
+                             StructField("point", ExamplePointUDT(), False)])
+        srdd = self.sqlCtx.applySchema(rdd, schema)
+        point = srdd.first().point
+        self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+    def test_parquet_with_udt(self):
+        from pyspark.tests import ExamplePoint
+        row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+        rdd = self.sc.parallelize([row])
+        srdd0 = self.sqlCtx.inferSchema(rdd)
+        output_dir = os.path.join(self.tempdir.name, "labeled_point")
+        srdd0.saveAsParquetFile(output_dir)
+        srdd1 = self.sqlCtx.parquetFile(output_dir)
+        point = srdd1.first().point
+        self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
 
 class InputFormatTests(ReusedPySparkTestCase):
 