author    Cheng Lian <lian.cs.zju@gmail.com>    2014-10-08 17:04:49 -0700
committer Michael Armbrust <michael@databricks.com>    2014-10-08 17:04:49 -0700
commit    a42cc08d219c579019f613faa8d310e6069c06fe (patch)
tree      47adb5abf147cd477a88e33524de43b29379b990 /python
parent    a85f24accd3266e0f97ee04d03c22b593d99c062 (diff)
[SPARK-3713][SQL] Uses JSON to serialize DataType objects
This PR uses JSON instead of `toString` to serialize `DataType`s. The latter is not only hard to parse but also flaky in many cases.

Since we already write schema information to Parquet metadata in the old style, we have to preserve the old `DataType` parser and ensure backward compatibility. The old parser is now renamed to `CaseClassStringParser` and moved into `object DataType`.

JoshRosen davies: please help review the PySpark-related changes, thanks!

Author: Cheng Lian <lian.cs.zju@gmail.com>

Closes #2563 from liancheng/datatype-to-json and squashes the following commits:

fc92eb3 [Cheng Lian] Reverts debugging code, simplifies primitive type JSON representation
438c75f [Cheng Lian] Refactors PySpark DataType JSON SerDe per comments
6b6387b [Cheng Lian] Removes debugging code
6a3ee3a [Cheng Lian] Addresses review comments
dc158b5 [Cheng Lian] Addresses PEP8 issues
99ab4ee [Cheng Lian] Adds compatibility test case for Parquet type conversion
a983a6c [Cheng Lian] Adds PySpark support
f608c6e [Cheng Lian] De/serializes DataType objects from/to JSON
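For illustration only (not part of the patch): a minimal sketch of the round trip that the new `jsonValue`/`json`/`fromJson` methods and the module-level `_parse_datatype_json_string` helper enable on the Python side. The schema and field names below are made up.

```python
# Hedged sketch: exercising the JSON SerDe added to python/pyspark/sql.py.
from pyspark.sql import (StructType, StructField, IntegerType, StringType,
                         _parse_datatype_json_string)

# Build a small schema and serialize it with the new DataType.json() method.
schema = StructType([
    StructField("id", IntegerType(), False),
    StructField("name", StringType(), True),
])
json_string = schema.json()
# json_string is a compact, sorted-key JSON document, roughly:
# {"fields":[{"name":"id","nullable":false,"type":"integer"}, ...],"type":"struct"}

# Parsing it back through the module-level helper yields an equal DataType,
# which is what the patched doctests also verify against the Scala side.
assert _parse_datatype_json_string(json_string) == schema
```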
Diffstat (limited to 'python')
-rw-r--r--    python/pyspark/sql.py    153
1 file changed, 75 insertions(+), 78 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 3d5a281239..d3d36eb995 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -34,6 +34,7 @@ import decimal
import datetime
import keyword
import warnings
+import json
from array import array
from operator import itemgetter
from itertools import imap
@@ -71,6 +72,18 @@ class DataType(object):
def __ne__(self, other):
return not self.__eq__(other)
+ @classmethod
+ def typeName(cls):
+ return cls.__name__[:-4].lower()
+
+ def jsonValue(self):
+ return self.typeName()
+
+ def json(self):
+ return json.dumps(self.jsonValue(),
+ separators=(',', ':'),
+ sort_keys=True)
+
class PrimitiveTypeSingleton(type):
@@ -214,6 +227,16 @@ class ArrayType(DataType):
return "ArrayType(%s,%s)" % (self.elementType,
str(self.containsNull).lower())
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "elementType": self.elementType.jsonValue(),
+ "containsNull": self.containsNull}
+
+ @classmethod
+ def fromJson(cls, json):
+ return ArrayType(_parse_datatype_json_value(json["elementType"]),
+ json["containsNull"])
+
class MapType(DataType):
@@ -254,6 +277,18 @@ class MapType(DataType):
return "MapType(%s,%s,%s)" % (self.keyType, self.valueType,
str(self.valueContainsNull).lower())
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "keyType": self.keyType.jsonValue(),
+ "valueType": self.valueType.jsonValue(),
+ "valueContainsNull": self.valueContainsNull}
+
+ @classmethod
+ def fromJson(cls, json):
+ return MapType(_parse_datatype_json_value(json["keyType"]),
+ _parse_datatype_json_value(json["valueType"]),
+ json["valueContainsNull"])
+
class StructField(DataType):
@@ -292,6 +327,17 @@ class StructField(DataType):
return "StructField(%s,%s,%s)" % (self.name, self.dataType,
str(self.nullable).lower())
+ def jsonValue(self):
+ return {"name": self.name,
+ "type": self.dataType.jsonValue(),
+ "nullable": self.nullable}
+
+ @classmethod
+ def fromJson(cls, json):
+ return StructField(json["name"],
+ _parse_datatype_json_value(json["type"]),
+ json["nullable"])
+
class StructType(DataType):
@@ -321,42 +367,30 @@ class StructType(DataType):
return ("StructType(List(%s))" %
",".join(str(field) for field in self.fields))
+ def jsonValue(self):
+ return {"type": self.typeName(),
+ "fields": [f.jsonValue() for f in self.fields]}
-def _parse_datatype_list(datatype_list_string):
- """Parses a list of comma separated data types."""
- index = 0
- datatype_list = []
- start = 0
- depth = 0
- while index < len(datatype_list_string):
- if depth == 0 and datatype_list_string[index] == ",":
- datatype_string = datatype_list_string[start:index].strip()
- datatype_list.append(_parse_datatype_string(datatype_string))
- start = index + 1
- elif datatype_list_string[index] == "(":
- depth += 1
- elif datatype_list_string[index] == ")":
- depth -= 1
+ @classmethod
+ def fromJson(cls, json):
+ return StructType([StructField.fromJson(f) for f in json["fields"]])
- index += 1
- # Handle the last data type
- datatype_string = datatype_list_string[start:index].strip()
- datatype_list.append(_parse_datatype_string(datatype_string))
- return datatype_list
+_all_primitive_types = dict((v.typeName(), v)
+ for v in globals().itervalues()
+ if type(v) is PrimitiveTypeSingleton and
+ v.__base__ == PrimitiveType)
-_all_primitive_types = dict((k, v) for k, v in globals().iteritems()
- if type(v) is PrimitiveTypeSingleton and v.__base__ == PrimitiveType)
+_all_complex_types = dict((v.typeName(), v)
+ for v in [ArrayType, MapType, StructType])
-def _parse_datatype_string(datatype_string):
- """Parses the given data type string.
-
+def _parse_datatype_json_string(json_string):
+ """Parses the given data type JSON string.
>>> def check_datatype(datatype):
- ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(str(datatype))
- ... python_datatype = _parse_datatype_string(
- ... scala_datatype.toString())
+ ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
+ ... python_datatype = _parse_datatype_json_string(scala_datatype.json())
... return datatype == python_datatype
>>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
True
@@ -394,51 +428,14 @@ def _parse_datatype_string(datatype_string):
>>> check_datatype(complex_maptype)
True
"""
- index = datatype_string.find("(")
- if index == -1:
- # It is a primitive type.
- index = len(datatype_string)
- type_or_field = datatype_string[:index]
- rest_part = datatype_string[index + 1:len(datatype_string) - 1].strip()
-
- if type_or_field in _all_primitive_types:
- return _all_primitive_types[type_or_field]()
-
- elif type_or_field == "ArrayType":
- last_comma_index = rest_part.rfind(",")
- containsNull = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- containsNull = False
- elementType = _parse_datatype_string(
- rest_part[:last_comma_index].strip())
- return ArrayType(elementType, containsNull)
-
- elif type_or_field == "MapType":
- last_comma_index = rest_part.rfind(",")
- valueContainsNull = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- valueContainsNull = False
- keyType, valueType = _parse_datatype_list(
- rest_part[:last_comma_index].strip())
- return MapType(keyType, valueType, valueContainsNull)
-
- elif type_or_field == "StructField":
- first_comma_index = rest_part.find(",")
- name = rest_part[:first_comma_index].strip()
- last_comma_index = rest_part.rfind(",")
- nullable = True
- if rest_part[last_comma_index + 1:].strip().lower() == "false":
- nullable = False
- dataType = _parse_datatype_string(
- rest_part[first_comma_index + 1:last_comma_index].strip())
- return StructField(name, dataType, nullable)
-
- elif type_or_field == "StructType":
- # rest_part should be in the format like
- # List(StructField(field1,IntegerType,false)).
- field_list_string = rest_part[rest_part.find("(") + 1:-1]
- fields = _parse_datatype_list(field_list_string)
- return StructType(fields)
+ return _parse_datatype_json_value(json.loads(json_string))
+
+
+def _parse_datatype_json_value(json_value):
+ if type(json_value) is unicode and json_value in _all_primitive_types.keys():
+ return _all_primitive_types[json_value]()
+ else:
+ return _all_complex_types[json_value["type"]].fromJson(json_value)
# Mapping Python types to Spark SQL DateType
@@ -992,7 +989,7 @@ class SQLContext(object):
self._sc.pythonExec,
broadcast_vars,
self._sc._javaAccumulator,
- str(returnType))
+ returnType.json())
def inferSchema(self, rdd):
"""Infer and apply a schema to an RDD of L{Row}.
@@ -1128,7 +1125,7 @@ class SQLContext(object):
batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
jrdd = self._pythonToJava(rdd._jrdd, batched)
- srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), str(schema))
+ srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
def registerRDDAsTable(self, rdd, tableName):
@@ -1218,7 +1215,7 @@ class SQLContext(object):
if schema is None:
srdd = self._ssql_ctx.jsonFile(path)
else:
- scala_datatype = self._ssql_ctx.parseDataType(str(schema))
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
@@ -1288,7 +1285,7 @@ class SQLContext(object):
if schema is None:
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
else:
- scala_datatype = self._ssql_ctx.parseDataType(str(schema))
+ scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
@@ -1623,7 +1620,7 @@ class SchemaRDD(RDD):
def schema(self):
"""Returns the schema of this SchemaRDD (represented by
a L{StructType})."""
- return _parse_datatype_string(self._jschema_rdd.baseSchemaRDD().schema().toString())
+ return _parse_datatype_json_string(self._jschema_rdd.baseSchemaRDD().schema().json())
def schemaString(self):
"""Returns the output schema in the tree format."""