path: root/python/pyspark/sql/types.py
author     Davies Liu <davies@databricks.com>    2015-07-09 14:43:38 -0700
committer  Davies Liu <davies.liu@gmail.com>     2015-07-09 14:43:38 -0700
commit     c9e2ef52bb54f35a904427389dc492d61f29b018 (patch)
tree       90887ae7055aa78751561119083bd09ac099e0f4 /python/pyspark/sql/types.py
parent     3ccebf36c5abe04702d4cf223552a94034d980fb (diff)
download   spark-c9e2ef52bb54f35a904427389dc492d61f29b018.tar.gz
spark-c9e2ef52bb54f35a904427389dc492d61f29b018.tar.bz2
spark-c9e2ef52bb54f35a904427389dc492d61f29b018.zip
[SPARK-7902] [SPARK-6289] [SPARK-8685] [SQL] [PYSPARK] Refactor of serialization for Python DataFrame
This PR fixes the long-standing issue of serialization between Python RDDs and DataFrames. It switches to a customized Pickler for InternalRow, which enables customized unpickling (type conversion, especially for UDTs), so UDTs can now be used in UDFs. cc mengxr. There is no generated `Row` class anymore.

Author: Davies Liu <davies@databricks.com>

Closes #7301 from davies/sql_ser and squashes the following commits:

81bef71 [Davies Liu] address comments
e9217bd [Davies Liu] add regression tests
db34167 [Davies Liu] Refactor of serialization for Python DataFrame
Diffstat (limited to 'python/pyspark/sql/types.py')
-rw-r--r--  python/pyspark/sql/types.py  419
1 file changed, 147 insertions(+), 272 deletions(-)
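As a rough sketch of the conversion API this patch introduces (a hypothetical round trip using names from pyspark.sql.types, not part of the patch; the internal timestamp value depends on the local timezone):

>>> from pyspark.sql.types import StructType, StructField, DoubleType, TimestampType
>>> import datetime
>>> schema = StructType([StructField("label", DoubleType(), False),
...                      StructField("time", TimestampType(), False)])
>>> internal = schema.toInternal((1.0, datetime.datetime(2015, 7, 9, 14, 43, 38)))
>>> # internal is now a plain tuple: (1.0, <microseconds since the epoch>)
>>> schema.fromInternal(internal)
Row(label=1.0, time=datetime.datetime(2015, 7, 9, 14, 43, 38))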
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index fecfe6d71e..d638576916 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -20,13 +20,9 @@ import decimal
import time
import datetime
import calendar
-import keyword
-import warnings
import json
import re
-import weakref
from array import array
-from operator import itemgetter
if sys.version >= "3":
long = int
@@ -71,6 +67,26 @@ class DataType(object):
separators=(',', ':'),
sort_keys=True)
+ def needConversion(self):
+ """
+ Does this type need conversion between a Python object and an internal SQL object?
+
+ This is used to avoid unnecessary conversions for ArrayType/MapType/StructType.
+ """
+ return False
+
+ def toInternal(self, obj):
+ """
+ Converts a Python object into an internal SQL object.
+ """
+ return obj
+
+ def fromInternal(self, obj):
+ """
+ Converts an internal SQL object into a native Python object.
+ """
+ return obj
+
# This singleton pattern does not work with pickle, you will get
# another object after pickle and unpickle
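The base-class hooks are deliberately no-ops, so atomic types that do not override them pass values through unchanged; a quick sanity check (doctest-style, with this module's names in scope):

>>> StringType().needConversion()
False
>>> StringType().toInternal("abc")
'abc'
>>> LongType().fromInternal(5)
5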
@@ -143,6 +159,17 @@ class DateType(AtomicType):
__metaclass__ = DataTypeSingleton
+ EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
+
+ def needConversion(self):
+ return True
+
+ def toInternal(self, d):
+ return d and d.toordinal() - self.EPOCH_ORDINAL
+
+ def fromInternal(self, v):
+ return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL)
+
class TimestampType(AtomicType):
"""Timestamp (datetime.datetime) data type.
@@ -150,6 +177,19 @@ class TimestampType(AtomicType):
__metaclass__ = DataTypeSingleton
+ def needConversion(self):
+ return True
+
+ def toInternal(self, dt):
+ if dt is not None:
+ seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
+ else time.mktime(dt.timetuple()))
+ return int(seconds * 1e6 + dt.microsecond)
+
+ def fromInternal(self, ts):
+ if ts is not None:
+ return datetime.datetime.fromtimestamp(ts / 1e6)
+
class DecimalType(FractionalType):
"""Decimal (decimal.Decimal) data type.
@@ -259,6 +299,19 @@ class ArrayType(DataType):
return ArrayType(_parse_datatype_json_value(json["elementType"]),
json["containsNull"])
+ def needConversion(self):
+ return self.elementType.needConversion()
+
+ def toInternal(self, obj):
+ if not self.needConversion():
+ return obj
+ return obj and [self.elementType.toInternal(v) for v in obj]
+
+ def fromInternal(self, obj):
+ if not self.needConversion():
+ return obj
+ return obj and [self.elementType.fromInternal(v) for v in obj]
+
class MapType(DataType):
"""Map data type.
@@ -304,6 +357,21 @@ class MapType(DataType):
_parse_datatype_json_value(json["valueType"]),
json["valueContainsNull"])
+ def needConversion(self):
+ return self.keyType.needConversion() or self.valueType.needConversion()
+
+ def toInternal(self, obj):
+ if not self.needConversion():
+ return obj
+ return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v))
+ for k, v in obj.items())
+
+ def fromInternal(self, obj):
+ if not self.needConversion():
+ return obj
+ return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v))
+ for k, v in obj.items())
+
class StructField(DataType):
"""A field in :class:`StructType`.
@@ -311,7 +379,7 @@ class StructField(DataType):
:param name: string, name of the field.
:param dataType: :class:`DataType` of the field.
:param nullable: boolean, whether the field can be null (None) or not.
- :param metadata: a dict from string to simple type that can be serialized to JSON automatically
+ :param metadata: a dict from string to simple type that can be automatically serialized to JSON
"""
def __init__(self, name, dataType, nullable=True, metadata=None):
@@ -351,6 +419,15 @@ class StructField(DataType):
json["nullable"],
json["metadata"])
+ def needConversion(self):
+ return self.dataType.needConversion()
+
+ def toInternal(self, obj):
+ return self.dataType.toInternal(obj)
+
+ def fromInternal(self, obj):
+ return self.dataType.fromInternal(obj)
+
class StructType(DataType):
"""Struct type, consisting of a list of :class:`StructField`.
@@ -371,10 +448,13 @@ class StructType(DataType):
"""
if not fields:
self.fields = []
+ self.names = []
else:
self.fields = fields
+ self.names = [f.name for f in fields]
assert all(isinstance(f, StructField) for f in fields),\
"fields should be a list of StructField"
+ self._needSerializeFields = None
def add(self, field, data_type=None, nullable=True, metadata=None):
"""
@@ -406,6 +486,7 @@ class StructType(DataType):
"""
if isinstance(field, StructField):
self.fields.append(field)
+ self.names.append(field.name)
else:
if isinstance(field, str) and data_type is None:
raise ValueError("Must specify DataType if passing name of struct_field to create.")
@@ -415,6 +496,7 @@ class StructType(DataType):
else:
data_type_f = data_type
self.fields.append(StructField(field, data_type_f, nullable, metadata))
+ self.names.append(field)
return self
def simpleString(self):
@@ -432,6 +514,41 @@ class StructType(DataType):
def fromJson(cls, json):
return StructType([StructField.fromJson(f) for f in json["fields"]])
+ def needConversion(self):
+ # We need to convert Row()/namedtuple into tuple()
+ return True
+
+ def toInternal(self, obj):
+ if obj is None:
+ return
+
+ if self._needSerializeFields is None:
+ self._needSerializeFields = any(f.needConversion() for f in self.fields)
+
+ if self._needSerializeFields:
+ if isinstance(obj, dict):
+ return tuple(f.toInternal(obj.get(n)) for n, f in zip(self.names, self.fields))
+ elif isinstance(obj, (tuple, list)):
+ return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
+ else:
+ raise ValueError("Unexpected value %r for StructType" % obj)
+ else:
+ if isinstance(obj, dict):
+ return tuple(obj.get(n) for n in self.names)
+ elif isinstance(obj, (list, tuple)):
+ return tuple(obj)
+ else:
+ raise ValueError("Unexpected value %r for StructType" % obj)
+
+ def fromInternal(self, obj):
+ if obj is None:
+ return
+ if isinstance(obj, Row):
+ # it's already converted by pickler
+ return obj
+ values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)]
+ return _create_row(self.names, values)
+
class UserDefinedType(DataType):
"""User-defined type (UDT).
@@ -464,17 +581,35 @@ class UserDefinedType(DataType):
"""
raise NotImplementedError("UDT must have a paired Scala UDT.")
+ def needConversion(self):
+ return True
+
+ @classmethod
+ def _cachedSqlType(cls):
+ """
+ Cache the sqlType() on the class, because it is heavily used in `toInternal`.
+ """
+ if not hasattr(cls, "_cached_sql_type"):
+ cls._cached_sql_type = cls.sqlType()
+ return cls._cached_sql_type
+
+ def toInternal(self, obj):
+ return self._cachedSqlType().toInternal(self.serialize(obj))
+
+ def fromInternal(self, obj):
+ return self.deserialize(self._cachedSqlType().fromInternal(obj))
+
def serialize(self, obj):
"""
Converts a user-type object into a SQL datum.
"""
- raise NotImplementedError("UDT must implement serialize().")
+ raise NotImplementedError("UDT must implement serialize().")
def deserialize(self, datum):
"""
Converts a SQL datum into a user-type object.
"""
- raise NotImplementedError("UDT must implement deserialize().")
+ raise NotImplementedError("UDT must implement deserialize().")
def simpleString(self):
return 'udt'
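For illustration only, a minimal hypothetical UDT that follows this contract (Point and PointUDT are not part of the patch, and a real UDT would also implement module() and scalaUDT()):

    class Point(object):
        """A plain 2-D point, used only for this sketch."""
        def __init__(self, x, y):
            self.x, self.y = x, y

    class PointUDT(UserDefinedType):
        """Hypothetical UDT that stores a Point as an array of doubles."""
        @classmethod
        def sqlType(cls):
            return ArrayType(DoubleType(), False)

        def serialize(self, obj):
            # user object -> SQL datum matching sqlType()
            return [obj.x, obj.y]

        def deserialize(self, datum):
            # SQL datum -> user object
            return Point(datum[0], datum[1])

With this, PointUDT().toInternal(Point(1.0, 2.0)) should produce [1.0, 2.0], and fromInternal([1.0, 2.0]) should rebuild the Point, without the UDT author ever touching the internal row format.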
@@ -671,117 +806,6 @@ def _infer_schema(row):
return StructType(fields)
-def _need_python_to_sql_conversion(dataType):
- """
- Checks whether we need python to sql conversion for the given type.
- For now, only UDTs need this conversion.
-
- >>> _need_python_to_sql_conversion(DoubleType())
- False
- >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
- ... StructField("values", ArrayType(DoubleType(), False), False)])
- >>> _need_python_to_sql_conversion(schema0)
- True
- >>> _need_python_to_sql_conversion(ExamplePointUDT())
- True
- >>> schema1 = ArrayType(ExamplePointUDT(), False)
- >>> _need_python_to_sql_conversion(schema1)
- True
- >>> schema2 = StructType([StructField("label", DoubleType(), False),
- ... StructField("point", ExamplePointUDT(), False)])
- >>> _need_python_to_sql_conversion(schema2)
- True
- """
- if isinstance(dataType, StructType):
- # convert namedtuple or Row into tuple
- return True
- elif isinstance(dataType, ArrayType):
- return _need_python_to_sql_conversion(dataType.elementType)
- elif isinstance(dataType, MapType):
- return _need_python_to_sql_conversion(dataType.keyType) or \
- _need_python_to_sql_conversion(dataType.valueType)
- elif isinstance(dataType, UserDefinedType):
- return True
- elif isinstance(dataType, (DateType, TimestampType)):
- return True
- else:
- return False
-
-
-EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
-
-
-def _python_to_sql_converter(dataType):
- """
- Returns a converter that converts a Python object into a SQL datum for the given type.
-
- >>> conv = _python_to_sql_converter(DoubleType())
- >>> conv(1.0)
- 1.0
- >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
- >>> conv([1.0, 2.0])
- [1.0, 2.0]
- >>> conv = _python_to_sql_converter(ExamplePointUDT())
- >>> conv(ExamplePoint(1.0, 2.0))
- [1.0, 2.0]
- >>> schema = StructType([StructField("label", DoubleType(), False),
- ... StructField("point", ExamplePointUDT(), False)])
- >>> conv = _python_to_sql_converter(schema)
- >>> conv((1.0, ExamplePoint(1.0, 2.0)))
- (1.0, [1.0, 2.0])
- """
- if not _need_python_to_sql_conversion(dataType):
- return lambda x: x
-
- if isinstance(dataType, StructType):
- names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
- if any(_need_python_to_sql_conversion(t) for t in types):
- converters = [_python_to_sql_converter(t) for t in types]
-
- def converter(obj):
- if isinstance(obj, dict):
- return tuple(c(obj.get(n)) for n, c in zip(names, converters))
- elif isinstance(obj, tuple):
- if hasattr(obj, "__fields__") or hasattr(obj, "_fields"):
- return tuple(c(v) for c, v in zip(converters, obj))
- else:
- return tuple(c(v) for c, v in zip(converters, obj))
- elif obj is not None:
- raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
- else:
- def converter(obj):
- if isinstance(obj, dict):
- return tuple(obj.get(n) for n in names)
- else:
- return tuple(obj)
- return converter
- elif isinstance(dataType, ArrayType):
- element_converter = _python_to_sql_converter(dataType.elementType)
- return lambda a: a and [element_converter(v) for v in a]
- elif isinstance(dataType, MapType):
- key_converter = _python_to_sql_converter(dataType.keyType)
- value_converter = _python_to_sql_converter(dataType.valueType)
- return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
-
- elif isinstance(dataType, UserDefinedType):
- return lambda obj: obj and dataType.serialize(obj)
-
- elif isinstance(dataType, DateType):
- return lambda d: d and d.toordinal() - EPOCH_ORDINAL
-
- elif isinstance(dataType, TimestampType):
-
- def to_posix_timstamp(dt):
- if dt:
- seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
- else time.mktime(dt.timetuple()))
- return int(seconds * 1e6 + dt.microsecond)
- return to_posix_timstamp
-
- else:
- raise ValueError("Unexpected type %r" % dataType)
-
-
def _has_nulltype(dt):
""" Return whether there is NullType in `dt` or not """
if isinstance(dt, StructType):
@@ -1076,7 +1100,7 @@ def _verify_type(obj, dataType):
if isinstance(dataType, UserDefinedType):
if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
raise ValueError("%r is not an instance of type %r" % (obj, dataType))
- _verify_type(dataType.serialize(obj), dataType.sqlType())
+ _verify_type(dataType.toInternal(obj), dataType.sqlType())
return
_type = type(dataType)
@@ -1086,7 +1110,7 @@ def _verify_type(obj, dataType):
if not isinstance(obj, (tuple, list)):
raise TypeError("StructType can not accept object in type %s" % type(obj))
else:
- # subclass of them can not be deserialized in JVM
+ # subclasses of them can not be deserialized in JVM
if type(obj) not in _acceptable_types[_type]:
raise TypeError("%s can not accept object in type %s" % (dataType, type(obj)))
@@ -1106,159 +1130,10 @@ def _verify_type(obj, dataType):
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType)
-_cached_cls = weakref.WeakValueDictionary()
-
-
-def _restore_object(dataType, obj):
- """ Restore object during unpickling. """
- # use id(dataType) as key to speed up lookup in dict
- # Because of batched pickling, dataType will be the
- # same object in most cases.
- k = id(dataType)
- cls = _cached_cls.get(k)
- if cls is None or cls.__datatype is not dataType:
- # use dataType as key to avoid create multiple class
- cls = _cached_cls.get(dataType)
- if cls is None:
- cls = _create_cls(dataType)
- _cached_cls[dataType] = cls
- cls.__datatype = dataType
- _cached_cls[k] = cls
- return cls(obj)
-
-
-def _create_object(cls, v):
- """ Create an customized object with class `cls`. """
- # datetime.date would be deserialized as datetime.datetime
- # from java type, so we need to set it back.
- if cls is datetime.date and isinstance(v, datetime.datetime):
- return v.date()
- return cls(v) if v is not None else v
-
-
-def _create_getter(dt, i):
- """ Create a getter for item `i` with schema """
- cls = _create_cls(dt)
-
- def getter(self):
- return _create_object(cls, self[i])
-
- return getter
-
-
-def _has_struct_or_date(dt):
- """Return whether `dt` is or has StructType/DateType in it"""
- if isinstance(dt, StructType):
- return True
- elif isinstance(dt, ArrayType):
- return _has_struct_or_date(dt.elementType)
- elif isinstance(dt, MapType):
- return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType)
- elif isinstance(dt, DateType):
- return True
- elif isinstance(dt, UserDefinedType):
- return True
- return False
-
-
-def _create_properties(fields):
- """Create properties according to fields"""
- ps = {}
- for i, f in enumerate(fields):
- name = f.name
- if (name.startswith("__") and name.endswith("__")
- or keyword.iskeyword(name)):
- warnings.warn("field name %s can not be accessed in Python,"
- "use position to access it instead" % name)
- if _has_struct_or_date(f.dataType):
- # delay creating object until accessing it
- getter = _create_getter(f.dataType, i)
- else:
- getter = itemgetter(i)
- ps[name] = property(getter)
- return ps
-
-
-def _create_cls(dataType):
- """
- Create an class by dataType
-
- The created class is similar to namedtuple, but can have nested schema.
-
- >>> schema = _parse_schema_abstract("a b c")
- >>> row = (1, 1.0, "str")
- >>> schema = _infer_schema_type(row, schema)
- >>> obj = _create_cls(schema)(row)
- >>> import pickle
- >>> pickle.loads(pickle.dumps(obj))
- Row(a=1, b=1.0, c='str')
-
- >>> row = [[1], {"key": (1, 2.0)}]
- >>> schema = _parse_schema_abstract("a[] b{c d}")
- >>> schema = _infer_schema_type(row, schema)
- >>> obj = _create_cls(schema)(row)
- >>> pickle.loads(pickle.dumps(obj))
- Row(a=[1], b={'key': Row(c=1, d=2.0)})
- >>> pickle.loads(pickle.dumps(obj.a))
- [1]
- >>> pickle.loads(pickle.dumps(obj.b))
- {'key': Row(c=1, d=2.0)}
- """
-
- if isinstance(dataType, ArrayType):
- cls = _create_cls(dataType.elementType)
-
- def List(l):
- if l is None:
- return
- return [_create_object(cls, v) for v in l]
-
- return List
-
- elif isinstance(dataType, MapType):
- kcls = _create_cls(dataType.keyType)
- vcls = _create_cls(dataType.valueType)
-
- def Dict(d):
- if d is None:
- return
- return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items())
-
- return Dict
-
- elif isinstance(dataType, DateType):
- return datetime.date
-
- elif isinstance(dataType, UserDefinedType):
- return lambda datum: dataType.deserialize(datum)
-
- elif not isinstance(dataType, StructType):
- # no wrapper for atomic types
- return lambda x: x
-
- class Row(tuple):
-
- """ Row in DataFrame """
- __datatype = dataType
- __fields__ = tuple(f.name for f in dataType.fields)
- __slots__ = ()
-
- # create property for fast access
- locals().update(_create_properties(dataType.fields))
-
- def asDict(self):
- """ Return as a dict """
- return dict((n, getattr(self, n)) for n in self.__fields__)
-
- def __repr__(self):
- # call collect __repr__ for nested objects
- return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
- for n in self.__fields__))
-
- def __reduce__(self):
- return (_restore_object, (self.__datatype, tuple(self)))
- return Row
+# This is used to unpickle a Row from JVM
+def _create_row_inbound_converter(dataType):
+ return lambda *a: dataType.fromInternal(a)
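A sketch of what the returned callable does once the JVM-side pickler hands it the already-unpickled field values (the pickler wiring itself lives outside this file):

>>> schema = StructType([StructField("a", LongType(), True),
...                      StructField("d", DateType(), True)])
>>> conv = _create_row_inbound_converter(schema)
>>> conv(1, 10)
Row(a=1, d=datetime.date(1970, 1, 11))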
def _create_row(fields, values):