about summary refs log tree commit diff
path: root/python/pyspark/sql.py
diff options
context:
space:
mode:
authorDavies Liu <davies.liu@gmail.com>2014-09-19 15:33:42 -0700
committerMichael Armbrust <michael@databricks.com>2014-09-19 15:33:42 -0700
commita95ad99e31c2d5980a3b8cd8e36ff968b1e6b201 (patch)
tree7700076ce88b331c17a524aa105199d90ff5656f /python/pyspark/sql.py
parent5522151eb14f4208798901f5c090868edd8e8dde (diff)
downloadspark-a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201.tar.gz
spark-a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201.tar.bz2
spark-a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201.zip
[SPARK-3592] [SQL] [PySpark] support applySchema to RDD of Row
Fix the issue when applySchema() is applied to an RDD of Row. Also add type mapping for BinaryType. Author: Davies Liu <davies.liu@gmail.com> Closes #2448 from davies/row and squashes the following commits: dd220cf [Davies Liu] fix test 3f3f188 [Davies Liu] add more test f559746 [Davies Liu] add tests, fix serialization 9688fd2 [Davies Liu] support applySchema to RDD of Row
Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r--python/pyspark/sql.py13
1 file changed, 10 insertions, 3 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 42a9920f10..653195ea43 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -440,6 +440,7 @@ _type_mappings = {
float: DoubleType,
str: StringType,
unicode: StringType,
+ bytearray: BinaryType,
decimal.Decimal: DecimalType,
datetime.datetime: TimestampType,
datetime.date: TimestampType,
@@ -690,11 +691,12 @@ _acceptable_types = {
ByteType: (int, long),
ShortType: (int, long),
IntegerType: (int, long),
- LongType: (long,),
+ LongType: (int, long),
FloatType: (float,),
DoubleType: (float,),
DecimalType: (decimal.Decimal,),
StringType: (str, unicode),
+ BinaryType: (bytearray,),
TimestampType: (datetime.datetime,),
ArrayType: (list, tuple, array),
MapType: (dict,),
@@ -728,9 +730,9 @@ def _verify_type(obj, dataType):
return
_type = type(dataType)
- if _type not in _acceptable_types:
- return
+ assert _type in _acceptable_types, "unkown datatype: %s" % dataType
+ # subclass of them can not be deserialized in JVM
if type(obj) not in _acceptable_types[_type]:
raise TypeError("%s can not accept abject in type %s"
% (dataType, type(obj)))
@@ -1121,6 +1123,11 @@ class SQLContext(object):
# take the first few rows to verify schema
rows = rdd.take(10)
+ # Row() cannot been deserialized by Pyrolite
+ if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
+ rdd = rdd.map(tuple)
+ rows = rdd.take(10)
+
for row in rows:
_verify_type(row, schema)