diff options
author | Davies Liu <davies.liu@gmail.com> | 2014-09-19 15:33:42 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-09-19 15:33:42 -0700 |
commit | a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201 (patch) | |
tree | 7700076ce88b331c17a524aa105199d90ff5656f | |
parent | 5522151eb14f4208798901f5c090868edd8e8dde (diff) | |
download | spark-a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201.tar.gz spark-a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201.tar.bz2 spark-a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201.zip |
[SPARK-3592] [SQL] [PySpark] support applySchema to RDD of Row
Fix the issue when applying applySchema() to an RDD of Row.
Also add type mapping for BinaryType.
Author: Davies Liu <davies.liu@gmail.com>
Closes #2448 from davies/row and squashes the following commits:
dd220cf [Davies Liu] fix test
3f3f188 [Davies Liu] add more test
f559746 [Davies Liu] add tests, fix serialization
9688fd2 [Davies Liu] support applySchema to RDD of Row
-rw-r--r-- | python/pyspark/sql.py | 13 | ||||
-rw-r--r-- | python/pyspark/tests.py | 11 |
2 files changed, 20 insertions, 4 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 42a9920f10..653195ea43 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -440,6 +440,7 @@ _type_mappings = { float: DoubleType, str: StringType, unicode: StringType, + bytearray: BinaryType, decimal.Decimal: DecimalType, datetime.datetime: TimestampType, datetime.date: TimestampType, @@ -690,11 +691,12 @@ _acceptable_types = { ByteType: (int, long), ShortType: (int, long), IntegerType: (int, long), - LongType: (long,), + LongType: (int, long), FloatType: (float,), DoubleType: (float,), DecimalType: (decimal.Decimal,), StringType: (str, unicode), + BinaryType: (bytearray,), TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), @@ -728,9 +730,9 @@ def _verify_type(obj, dataType): return _type = type(dataType) - if _type not in _acceptable_types: - return + assert _type in _acceptable_types, "unkown datatype: %s" % dataType + # subclass of them can not be deserialized in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept abject in type %s" % (dataType, type(obj))) @@ -1121,6 +1123,11 @@ class SQLContext(object): # take the first few rows to verify schema rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + for row in rows: _verify_type(row, schema) diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py index 7301966e48..a94eb0f429 100644 --- a/python/pyspark/tests.py +++ b/python/pyspark/tests.py @@ -45,7 +45,7 @@ from pyspark.files import SparkFiles from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \ CloudPickleSerializer from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter -from pyspark.sql import SQLContext, IntegerType +from pyspark.sql import SQLContext, IntegerType, Row from pyspark 
import shuffle _have_scipy = False @@ -659,6 +659,15 @@ class TestSQL(PySparkTestCase): self.assertEquals(result.getNumPartitions(), 5) self.assertEquals(result.count(), 3) + def test_apply_schema_to_row(self): + srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""])) + srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema()) + self.assertEqual(srdd.collect(), srdd2.collect()) + + rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x)) + srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema()) + self.assertEqual(10, srdd3.count()) + class TestIO(PySparkTestCase): |