diff options
Diffstat (limited to 'python/pyspark/sql/types.py')
-rw-r--r-- | python/pyspark/sql/types.py | 27 |
1 files changed, 18 insertions, 9 deletions
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 8f286b631f..23d9adb0da 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -655,12 +655,15 @@ def _need_python_to_sql_conversion(dataType): _need_python_to_sql_conversion(dataType.valueType) elif isinstance(dataType, UserDefinedType): return True - elif isinstance(dataType, TimestampType): + elif isinstance(dataType, (DateType, TimestampType)): return True else: return False +EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal() + + def _python_to_sql_converter(dataType): """ Returns a converter that converts a Python object into a SQL datum for the given type. @@ -698,26 +701,32 @@ def _python_to_sql_converter(dataType): return tuple(c(d.get(n)) for n, c in zip(names, converters)) else: return tuple(c(v) for c, v in zip(converters, obj)) - else: + elif obj is not None: raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType)) return converter elif isinstance(dataType, ArrayType): element_converter = _python_to_sql_converter(dataType.elementType) - return lambda a: [element_converter(v) for v in a] + return lambda a: a and [element_converter(v) for v in a] elif isinstance(dataType, MapType): key_converter = _python_to_sql_converter(dataType.keyType) value_converter = _python_to_sql_converter(dataType.valueType) - return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()]) + elif isinstance(dataType, UserDefinedType): - return lambda obj: dataType.serialize(obj) + return lambda obj: obj and dataType.serialize(obj) + + elif isinstance(dataType, DateType): + return lambda d: d and d.toordinal() - EPOCH_ORDINAL + elif isinstance(dataType, TimestampType): def to_posix_timstamp(dt): - if dt.tzinfo is None: - return int(time.mktime(dt.timetuple()) * 1e7 + dt.microsecond * 10) - else: - return int(calendar.timegm(dt.utctimetuple()) * 1e7 + dt.microsecond * 10) + if dt: + seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo + else time.mktime(dt.timetuple())) + return int(seconds * 1e7 + dt.microsecond * 10) return to_posix_timstamp + else: raise ValueError("Unexpected type %r" % dataType) |