diff options
author | Davies Liu <davies.liu@gmail.com> | 2014-09-19 15:33:42 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-09-19 15:33:42 -0700 |
commit | a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201 (patch) | |
tree | 7700076ce88b331c17a524aa105199d90ff5656f /python/pyspark/sql.py | |
parent | 5522151eb14f4208798901f5c090868edd8e8dde (diff) | |
download | spark-a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201.tar.gz spark-a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201.tar.bz2 spark-a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201.zip |
[SPARK-3592] [SQL] [PySpark] support applySchema to RDD of Row
Fix the issue when applying applySchema() to an RDD of Row.
Also add type mapping for BinaryType.
Author: Davies Liu <davies.liu@gmail.com>
Closes #2448 from davies/row and squashes the following commits:
dd220cf [Davies Liu] fix test
3f3f188 [Davies Liu] add more test
f559746 [Davies Liu] add tests, fix serialization
9688fd2 [Davies Liu] support applySchema to RDD of Row
Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r-- | python/pyspark/sql.py | 13 |
1 file changed, 10 insertions, 3 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 42a9920f10..653195ea43 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -440,6 +440,7 @@ _type_mappings = { float: DoubleType, str: StringType, unicode: StringType, + bytearray: BinaryType, decimal.Decimal: DecimalType, datetime.datetime: TimestampType, datetime.date: TimestampType, @@ -690,11 +691,12 @@ _acceptable_types = { ByteType: (int, long), ShortType: (int, long), IntegerType: (int, long), - LongType: (long,), + LongType: (int, long), FloatType: (float,), DoubleType: (float,), DecimalType: (decimal.Decimal,), StringType: (str, unicode), + BinaryType: (bytearray,), TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), @@ -728,9 +730,9 @@ def _verify_type(obj, dataType): return _type = type(dataType) - if _type not in _acceptable_types: - return + assert _type in _acceptable_types, "unkown datatype: %s" % dataType + # subclass of them can not be deserialized in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept abject in type %s" % (dataType, type(obj))) @@ -1121,6 +1123,11 @@ class SQLContext(object): # take the first few rows to verify schema rows = rdd.take(10) + # Row() cannot been deserialized by Pyrolite + if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row': + rdd = rdd.map(tuple) + rows = rdd.take(10) + for row in rows: _verify_type(row, schema) |