author     Davies Liu <davies.liu@gmail.com>          2014-09-19 15:33:42 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-09-19 15:33:42 -0700
commit     a95ad99e31c2d5980a3b8cd8e36ff968b1e6b201 (patch)
tree       7700076ce88b331c17a524aa105199d90ff5656f /python
parent     5522151eb14f4208798901f5c090868edd8e8dde (diff)
[SPARK-3592] [SQL] [PySpark] support applySchema to RDD of Row
Fix the issue when applySchema() is applied to an RDD of Row. Also add the type mapping for BinaryType.

Author: Davies Liu <davies.liu@gmail.com>

Closes #2448 from davies/row and squashes the following commits:

dd220cf [Davies Liu] fix test
3f3f188 [Davies Liu] add more test
f559746 [Davies Liu] add tests, fix serialization
9688fd2 [Davies Liu] support applySchema to RDD of Row
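A minimal sketch of what the patch enables, assuming a Spark 1.1-era PySpark shell that already provides a SparkContext named `sc`; the schema and data are illustrative and mirror the new test below:

from pyspark.sql import SQLContext, Row, StructType, StructField, IntegerType

sqlCtx = SQLContext(sc)
schema = StructType([StructField("a", IntegerType(), True)])

# Before this patch applySchema() could fail on an RDD of Row;
# now Row records are converted to plain tuples before verification.
rdd = sc.parallelize(range(10)).map(lambda x: Row(a=x))
srdd = sqlCtx.applySchema(rdd, schema)
print(srdd.count())  # expected: 10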
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql.py    13
-rw-r--r--  python/pyspark/tests.py  11
2 files changed, 20 insertions, 4 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 42a9920f10..653195ea43 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -440,6 +440,7 @@ _type_mappings = {
float: DoubleType,
str: StringType,
unicode: StringType,
+ bytearray: BinaryType,
decimal.Decimal: DecimalType,
datetime.datetime: TimestampType,
datetime.date: TimestampType,
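With the new mapping, schema inference resolves a bytearray field to BinaryType. A hedged sketch (the field name `payload` is hypothetical; `sc` is an existing SparkContext as above):

from pyspark.sql import SQLContext, Row

sqlCtx = SQLContext(sc)
rdd = sc.parallelize([Row(payload=bytearray(b"\x00\x01"))])
srdd = sqlCtx.inferSchema(rdd)
print(srdd.schema())  # the payload field should come back as BinaryType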
@@ -690,11 +691,12 @@ _acceptable_types = {
ByteType: (int, long),
ShortType: (int, long),
IntegerType: (int, long),
- LongType: (long,),
+ LongType: (int, long),
FloatType: (float,),
DoubleType: (float,),
DecimalType: (decimal.Decimal,),
StringType: (str, unicode),
+ BinaryType: (bytearray,),
TimestampType: (datetime.datetime,),
ArrayType: (list, tuple, array),
MapType: (dict,),
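Relaxing the LongType entry means a plain Python int now verifies against a LongType column; this is what lets a schema inferred from JSON, where integers become LongType, be re-applied to an RDD of Python ints. A sketch under the same assumptions as above:

from pyspark.sql import SQLContext, StructType, StructField, LongType

sqlCtx = SQLContext(sc)
schema = StructType([StructField("id", LongType(), False)])
# A plain int passes verification for LongType now; previously only
# `long` was accepted and this raised TypeError.
srdd = sqlCtx.applySchema(sc.parallelize([(1,), (2,)]), schema)
print(srdd.collect())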
@@ -728,9 +730,9 @@ def _verify_type(obj, dataType):
return
_type = type(dataType)
- if _type not in _acceptable_types:
- return
+ assert _type in _acceptable_types, "unknown datatype: %s" % dataType
+ # subclasses of these types cannot be deserialized in the JVM
if type(obj) not in _acceptable_types[_type]:
raise TypeError("%s cannot accept object of type %s"
% (dataType, type(obj)))
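The stricter check can be seen by poking the private helper directly; this is a sketch against an internal function, not a supported API:

from pyspark.sql import IntegerType, _verify_type

_verify_type(5, IntegerType())        # passes: int is acceptable here
try:
    _verify_type("x", IntegerType())  # wrong value type
except TypeError as e:
    print(e)
try:
    _verify_type(5, object())         # not a known datatype: previously a
except AssertionError as e:           # silent no-op, now an AssertionError
    print(e)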
@@ -1121,6 +1123,11 @@ class SQLContext(object):
# take the first few rows to verify schema
rows = rdd.take(10)
+ # Row() cannot be deserialized by Pyrolite
+ if rows and isinstance(rows[0], tuple) and rows[0].__class__.__name__ == 'Row':
+ rdd = rdd.map(tuple)
+ rows = rdd.take(10)
+
for row in rows:
_verify_type(row, schema)
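The guard above keys off two facts about PySpark's Row: it subclasses tuple, and its class is literally named 'Row'. A small illustration (Pyrolite is the pickler bridging Python and the JVM):

from pyspark.sql import Row

row = Row(a=2)
print(isinstance(row, tuple))   # True  -- Row subclasses tuple
print(row.__class__.__name__)   # 'Row' -- what the guard matches on
print(tuple(row))               # (2,)  -- the plain tuple actually shipped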
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 7301966e48..a94eb0f429 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -45,7 +45,7 @@ from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
CloudPickleSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType
+from pyspark.sql import SQLContext, IntegerType, Row
from pyspark import shuffle
_have_scipy = False
@@ -659,6 +659,15 @@ class TestSQL(PySparkTestCase):
self.assertEquals(result.getNumPartitions(), 5)
self.assertEquals(result.count(), 3)
+ def test_apply_schema_to_row(self):
+ srdd = self.sqlCtx.jsonRDD(self.sc.parallelize(["""{"a":2}"""]))
+ srdd2 = self.sqlCtx.applySchema(srdd.map(lambda x: x), srdd.schema())
+ self.assertEqual(srdd.collect(), srdd2.collect())
+
+ rdd = self.sc.parallelize(range(10)).map(lambda x: Row(a=x))
+ srdd3 = self.sqlCtx.applySchema(rdd, srdd.schema())
+ self.assertEqual(10, srdd3.count())
+
class TestIO(PySparkTestCase):