From a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 19 Sep 2014 15:33:42 -0700 Subject: [SPARK-3592] [SQL] [PySpark] support applySchema to RDD of Row Fix the issue when applySchema() to an RDD of Row. Also add type mapping for BinaryType. Author: Davies Liu Closes #2448 from davies/row and squashes the following commits: dd220cf [Davies Liu] fix test 3f3f188 [Davies Liu] add more test f559746 [Davies Liu] add tests, fix serialization 9688fd2 [Davies Liu] support applySchema to RDD of Row --- python/pyspark/sql.py | 13 ++++++++++--- python/pyspark/tests.py | 11 ++++++++++- 2 files changed, 20 insertions(+), 4 deletions(-) (limited to 'python') diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 42a9920f10..653195ea43 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -440,6 +440,7 @@ _type_mappings = { float: DoubleType, str: StringType, unicode: StringType, + bytearray: BinaryType, decimal.Decimal: DecimalType, datetime.datetime: TimestampType, datetime.date: TimestampType, @@ -690,11 +691,12 @@ _acceptable_types = { ByteType: (int, long), ShortType: (int, long), IntegerType: (int, long), - LongType: (long,), + LongType: (int, long), FloatType: (float,), DoubleType: (float,), DecimalType: (decimal.Decimal,), StringType: (str, unicode), + BinaryType: (bytearray,), TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), @@ -728,9 +730,9 @@ def _verify_type(obj, dataType): return _type = type(dataType) - if _type not in _acceptable_types: - return + assert _type in _acceptable_types, "unkown datatype: %s" % dataType + # subclass of them can not be deserialized in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept abject in type %s" % (dataType, type(obj))) @@ -1121,6 +1123,11 @@ class SQLContext(object): # take the first few rows to verify schema rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + for row in rows: _verify_type(row, schema) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7301966e48..a94eb0f429 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -45,7 +45,7 @@ from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType +from pyspark.sql import SQLContext, IntegerType, Row from pyspark import shuffle _have_scipy = False @@ -659,6 +659,15 @@ class TestSQL(PySparkTestCase): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_apply_schema_to_row(self): + srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema()) + self.assertEqual(srdd.collect(), srdd2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) + self.assertEqual(10, srdd3.count()) + class TestIO(PySparkTestCase): -- cgit v1.2.3