From c9e2ef52bb54f35a904427389dc492d61f29b018 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Thu, 9 Jul 2015 14:43:38 -0700 Subject: [SPARK-7902] [SPARK-6289] [SPARK-8685] [SQL] [PYSPARK] Refactor of serialization for Python DataFrame This PR fix the long standing issue of serialization between Python RDD and DataFrame, it change to using a customized Pickler for InternalRow to enable customized unpickling (type conversion, especially for UDT), now we can support UDT for UDF, cc mengxr . There is no generated `Row` anymore. Author: Davies Liu Closes #7301 from davies/sql_ser and squashes the following commits: 81bef71 [Davies Liu] address comments e9217bd [Davies Liu] add regression tests db34167 [Davies Liu] Refactor of serialization for Python DataFrame --- python/pyspark/sql/types.py | 419 ++++++++++++++++---------------------------- 1 file changed, 147 insertions(+), 272 deletions(-) (limited to 'python/pyspark/sql/types.py') diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index fecfe6d71e..d638576916 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -20,13 +20,9 @@ import decimal import time import datetime import calendar -import keyword -import warnings import json import re -import weakref from array import array -from operator import itemgetter if sys.version >= "3": long = int @@ -71,6 +67,26 @@ class DataType(object): separators=(',', ':'), sort_keys=True) + def needConversion(self): + """ + Does this type need to conversion between Python object and internal SQL object. + + This is used to avoid the unnecessary conversion for ArrayType/MapType/StructType. + """ + return False + + def toInternal(self, obj): + """ + Converts a Python object into an internal SQL object. + """ + return obj + + def fromInternal(self, obj): + """ + Converts an internal SQL object into a native Python object. + """ + return obj + # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle @@ -143,6 +159,17 @@ class DateType(AtomicType): __metaclass__ = DataTypeSingleton + EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() + + def needConversion(self): + return True + + def toInternal(self, d): + return d and d.toordinal() - self.EPOCH_ORDINAL + + def fromInternal(self, v): + return v and datetime.date.fromordinal(v + self.EPOCH_ORDINAL) + class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. @@ -150,6 +177,19 @@ class TimestampType(AtomicType): __metaclass__ = DataTypeSingleton + def needConversion(self): + return True + + def toInternal(self, dt): + if dt is not None: + seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo + else time.mktime(dt.timetuple())) + return int(seconds * 1e6 + dt.microsecond) + + def fromInternal(self, ts): + if ts is not None: + return datetime.datetime.fromtimestamp(ts / 1e6) + class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. @@ -259,6 +299,19 @@ class ArrayType(DataType): return ArrayType(_parse_datatype_json_value(json["elementType"]), json["containsNull"]) + def needConversion(self): + return self.elementType.needConversion() + + def toInternal(self, obj): + if not self.needConversion(): + return obj + return obj and [self.elementType.toInternal(v) for v in obj] + + def fromInternal(self, obj): + if not self.needConversion(): + return obj + return obj and [self.elementType.fromInternal(v) for v in obj] + class MapType(DataType): """Map data type. @@ -304,6 +357,21 @@ class MapType(DataType): _parse_datatype_json_value(json["valueType"]), json["valueContainsNull"]) + def needConversion(self): + return self.keyType.needConversion() or self.valueType.needConversion() + + def toInternal(self, obj): + if not self.needConversion(): + return obj + return obj and dict((self.keyType.toInternal(k), self.valueType.toInternal(v)) + for k, v in obj.items()) + + def fromInternal(self, obj): + if not self.needConversion(): + return obj + return obj and dict((self.keyType.fromInternal(k), self.valueType.fromInternal(v)) + for k, v in obj.items()) + class StructField(DataType): """A field in :class:`StructType`. @@ -311,7 +379,7 @@ class StructField(DataType): :param name: string, name of the field. :param dataType: :class:`DataType` of the field. :param nullable: boolean, whether the field can be null (None) or not. - :param metadata: a dict from string to simple type that can be serialized to JSON automatically + :param metadata: a dict from string to simple type that can be toInternald to JSON automatically """ def __init__(self, name, dataType, nullable=True, metadata=None): @@ -351,6 +419,15 @@ class StructField(DataType): json["nullable"], json["metadata"]) + def needConversion(self): + return self.dataType.needConversion() + + def toInternal(self, obj): + return self.dataType.toInternal(obj) + + def fromInternal(self, obj): + return self.dataType.fromInternal(obj) + class StructType(DataType): """Struct type, consisting of a list of :class:`StructField`. @@ -371,10 +448,13 @@ class StructType(DataType): """ if not fields: self.fields = [] + self.names = [] else: self.fields = fields + self.names = [f.name for f in fields] assert all(isinstance(f, StructField) for f in fields),\ "fields should be a list of StructField" + self._needSerializeFields = None def add(self, field, data_type=None, nullable=True, metadata=None): """ @@ -406,6 +486,7 @@ class StructType(DataType): """ if isinstance(field, StructField): self.fields.append(field) + self.names.append(field.name) else: if isinstance(field, str) and data_type is None: raise ValueError("Must specify DataType if passing name of struct_field to create.") @@ -415,6 +496,7 @@ class StructType(DataType): else: data_type_f = data_type self.fields.append(StructField(field, data_type_f, nullable, metadata)) + self.names.append(field) return self def simpleString(self): @@ -432,6 +514,41 @@ class StructType(DataType): def fromJson(cls, json): return StructType([StructField.fromJson(f) for f in json["fields"]]) + def needConversion(self): + # We need convert Row()/namedtuple into tuple() + return True + + def toInternal(self, obj): + if obj is None: + return + + if self._needSerializeFields is None: + self._needSerializeFields = any(f.needConversion() for f in self.fields) + + if self._needSerializeFields: + if isinstance(obj, dict): + return tuple(f.toInternal(obj.get(n)) for n, f in zip(names, self.fields)) + elif isinstance(obj, (tuple, list)): + return tuple(f.toInternal(v) for f, v in zip(self.fields, obj)) + else: + raise ValueError("Unexpected tuple %r with StructType" % obj) + else: + if isinstance(obj, dict): + return tuple(obj.get(n) for n in self.names) + elif isinstance(obj, (list, tuple)): + return tuple(obj) + else: + raise ValueError("Unexpected tuple %r with StructType" % obj) + + def fromInternal(self, obj): + if obj is None: + return + if isinstance(obj, Row): + # it's already converted by pickler + return obj + values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)] + return _create_row(self.names, values) + class UserDefinedType(DataType): """User-defined type (UDT). @@ -464,17 +581,35 @@ class UserDefinedType(DataType): """ raise NotImplementedError("UDT must have a paired Scala UDT.") + def needConversion(self): + return True + + @classmethod + def _cachedSqlType(cls): + """ + Cache the sqlType() into class, because it's heavy used in `toInternal`. + """ + if not hasattr(cls, "_cached_sql_type"): + cls._cached_sql_type = cls.sqlType() + return cls._cached_sql_type + + def toInternal(self, obj): + return self._cachedSqlType().toInternal(self.serialize(obj)) + + def fromInternal(self, obj): + return self.deserialize(self._cachedSqlType().fromInternal(obj)) + def serialize(self, obj): """ Converts the a user-type object into a SQL datum. """ - raise NotImplementedError("UDT must implement serialize().") + raise NotImplementedError("UDT must implement toInternal().") def deserialize(self, datum): """ Converts a SQL datum into a user-type object. """ - raise NotImplementedError("UDT must implement deserialize().") + raise NotImplementedError("UDT must implement fromInternal().") def simpleString(self): return 'udt' @@ -671,117 +806,6 @@ def _infer_schema(row): return StructType(fields) -def _need_python_to_sql_conversion(dataType): - """ - Checks whether we need python to sql conversion for the given type. - For now, only UDTs need this conversion. - - >>> _need_python_to_sql_conversion(DoubleType()) - False - >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False), - ... StructField("values", ArrayType(DoubleType(), False), False)]) - >>> _need_python_to_sql_conversion(schema0) - True - >>> _need_python_to_sql_conversion(ExamplePointUDT()) - True - >>> schema1 = ArrayType(ExamplePointUDT(), False) - >>> _need_python_to_sql_conversion(schema1) - True - >>> schema2 = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> _need_python_to_sql_conversion(schema2) - True - """ - if isinstance(dataType, StructType): - # convert namedtuple or Row into tuple - return True - elif isinstance(dataType, ArrayType): - return _need_python_to_sql_conversion(dataType.elementType) - elif isinstance(dataType, MapType): - return _need_python_to_sql_conversion(dataType.keyType) or \ - _need_python_to_sql_conversion(dataType.valueType) - elif isinstance(dataType, UserDefinedType): - return True - elif isinstance(dataType, (DateType, TimestampType)): - return True - else: - return False - - -EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() - - -def _python_to_sql_converter(dataType): - """ - Returns a converter that converts a Python object into a SQL datum for the given type. - - >>> conv = _python_to_sql_converter(DoubleType()) - >>> conv(1.0) - 1.0 - >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False)) - >>> conv([1.0, 2.0]) - [1.0, 2.0] - >>> conv = _python_to_sql_converter(ExamplePointUDT()) - >>> conv(ExamplePoint(1.0, 2.0)) - [1.0, 2.0] - >>> schema = StructType([StructField("label", DoubleType(), False), - ... StructField("point", ExamplePointUDT(), False)]) - >>> conv = _python_to_sql_converter(schema) - >>> conv((1.0, ExamplePoint(1.0, 2.0))) - (1.0, [1.0, 2.0]) - """ - if not _need_python_to_sql_conversion(dataType): - return lambda x: x - - if isinstance(dataType, StructType): - names, types = zip(*[(f.name, f.dataType) for f in dataType.fields]) - if any(_need_python_to_sql_conversion(t) for t in types): - converters = [_python_to_sql_converter(t) for t in types] - - def converter(obj): - if isinstance(obj, dict): - return tuple(c(obj.get(n)) for n, c in zip(names, converters)) - elif isinstance(obj, tuple): - if hasattr(obj, "__fields__") or hasattr(obj, "_fields"): - return tuple(c(v) for c, v in zip(converters, obj)) - else: - return tuple(c(v) for c, v in zip(converters, obj)) - elif obj is not None: - raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) - else: - def converter(obj): - if isinstance(obj, dict): - return tuple(obj.get(n) for n in names) - else: - return tuple(obj) - return converter - elif isinstance(dataType, ArrayType): - element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: a and [element_converter(v) for v in a] - elif isinstance(dataType, MapType): - key_converter = _python_to_sql_converter(dataType.keyType) - value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) - - elif isinstance(dataType, UserDefinedType): - return lambda obj: obj and dataType.serialize(obj) - - elif isinstance(dataType, DateType): - return lambda d: d and d.toordinal() - EPOCH_ORDINAL - - elif isinstance(dataType, TimestampType): - - def to_posix_timstamp(dt): - if dt: - seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo - else time.mktime(dt.timetuple())) - return int(seconds * 1e6 + dt.microsecond) - return to_posix_timstamp - - else: - raise ValueError("Unexpected type %r" % dataType) - - def _has_nulltype(dt): """ Return whether there is NullType in `dt` or not """ if isinstance(dt, StructType): @@ -1076,7 +1100,7 @@ def _verify_type(obj, dataType): if isinstance(dataType, UserDefinedType): if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType): raise ValueError("%r is not an instance of type %r" % (obj, dataType)) - _verify_type(dataType.serialize(obj), dataType.sqlType()) + _verify_type(dataType.toInternal(obj), dataType.sqlType()) return _type = type(dataType) @@ -1086,7 +1110,7 @@ def _verify_type(obj, dataType): if not isinstance(obj, (tuple, list)): raise TypeError("StructType can not accept object in type %s" % type(obj)) else: - # subclass of them can not be deserialized in JVM + # subclass of them can not be fromInternald in JVM if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) @@ -1106,159 +1130,10 @@ def _verify_type(obj, dataType): for v, f in zip(obj, dataType.fields): _verify_type(v, f.dataType) -_cached_cls = weakref.WeakValueDictionary() - - -def _restore_object(dataType, obj): - """ Restore object during unpickling. """ - # use id(dataType) as key to speed up lookup in dict - # Because of batched pickling, dataType will be the - # same object in most cases. - k = id(dataType) - cls = _cached_cls.get(k) - if cls is None or cls.__datatype is not dataType: - # use dataType as key to avoid create multiple class - cls = _cached_cls.get(dataType) - if cls is None: - cls = _create_cls(dataType) - _cached_cls[dataType] = cls - cls.__datatype = dataType - _cached_cls[k] = cls - return cls(obj) - - -def _create_object(cls, v): - """ Create an customized object with class `cls`. """ - # datetime.date would be deserialized as datetime.datetime - # from java type, so we need to set it back. - if cls is datetime.date and isinstance(v, datetime.datetime): - return v.date() - return cls(v) if v is not None else v - - -def _create_getter(dt, i): - """ Create a getter for item `i` with schema """ - cls = _create_cls(dt) - - def getter(self): - return _create_object(cls, self[i]) - - return getter - - -def _has_struct_or_date(dt): - """Return whether `dt` is or has StructType/DateType in it""" - if isinstance(dt, StructType): - return True - elif isinstance(dt, ArrayType): - return _has_struct_or_date(dt.elementType) - elif isinstance(dt, MapType): - return _has_struct_or_date(dt.keyType) or _has_struct_or_date(dt.valueType) - elif isinstance(dt, DateType): - return True - elif isinstance(dt, UserDefinedType): - return True - return False - - -def _create_properties(fields): - """Create properties according to fields""" - ps = {} - for i, f in enumerate(fields): - name = f.name - if (name.startswith("__") and name.endswith("__") - or keyword.iskeyword(name)): - warnings.warn("field name %s can not be accessed in Python," - "use position to access it instead" % name) - if _has_struct_or_date(f.dataType): - # delay creating object until accessing it - getter = _create_getter(f.dataType, i) - else: - getter = itemgetter(i) - ps[name] = property(getter) - return ps - - -def _create_cls(dataType): - """ - Create an class by dataType - - The created class is similar to namedtuple, but can have nested schema. - - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> import pickle - >>> pickle.loads(pickle.dumps(obj)) - Row(a=1, b=1.0, c='str') - - >>> row = [[1], {"key": (1, 2.0)}] - >>> schema = _parse_schema_abstract("a[] b{c d}") - >>> schema = _infer_schema_type(row, schema) - >>> obj = _create_cls(schema)(row) - >>> pickle.loads(pickle.dumps(obj)) - Row(a=[1], b={'key': Row(c=1, d=2.0)}) - >>> pickle.loads(pickle.dumps(obj.a)) - [1] - >>> pickle.loads(pickle.dumps(obj.b)) - {'key': Row(c=1, d=2.0)} - """ - - if isinstance(dataType, ArrayType): - cls = _create_cls(dataType.elementType) - - def List(l): - if l is None: - return - return [_create_object(cls, v) for v in l] - - return List - - elif isinstance(dataType, MapType): - kcls = _create_cls(dataType.keyType) - vcls = _create_cls(dataType.valueType) - - def Dict(d): - if d is None: - return - return dict((_create_object(kcls, k), _create_object(vcls, v)) for k, v in d.items()) - - return Dict - - elif isinstance(dataType, DateType): - return datetime.date - - elif isinstance(dataType, UserDefinedType): - return lambda datum: dataType.deserialize(datum) - - elif not isinstance(dataType, StructType): - # no wrapper for atomic types - return lambda x: x - - class Row(tuple): - - """ Row in DataFrame """ - __datatype = dataType - __fields__ = tuple(f.name for f in dataType.fields) - __slots__ = () - - # create property for fast access - locals().update(_create_properties(dataType.fields)) - - def asDict(self): - """ Return as a dict """ - return dict((n, getattr(self, n)) for n in self.__fields__) - - def __repr__(self): - # call collect __repr__ for nested objects - return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n)) - for n in self.__fields__)) - - def __reduce__(self): - return (_restore_object, (self.__datatype, tuple(self))) - return Row +# This is used to unpickle a Row from JVM +def _create_row_inbound_converter(dataType): + return lambda *a: dataType.fromInternal(a) def _create_row(fields, values): -- cgit v1.2.3