about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/sql/types.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql/types.py')
-rw-r--r-- python/pyspark/sql/types.py 27
1 files changed, 18 insertions, 9 deletions
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 8f286b631f..23d9adb0da 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -655,12 +655,15 @@ def _need_python_to_sql_conversion(dataType):
_need_python_to_sql_conversion(dataType.valueType)
elif isinstance(dataType, UserDefinedType):
return True
- elif isinstance(dataType, TimestampType):
+ elif isinstance(dataType, (DateType, TimestampType)):
return True
else:
return False
+EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()
+
+
def _python_to_sql_converter(dataType):
"""
Returns a converter that converts a Python object into a SQL datum for the given type.
@@ -698,26 +701,32 @@ def _python_to_sql_converter(dataType):
return tuple(c(d.get(n)) for n, c in zip(names, converters))
else:
return tuple(c(v) for c, v in zip(converters, obj))
- else:
+ elif obj is not None:
raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
return converter
elif isinstance(dataType, ArrayType):
element_converter = _python_to_sql_converter(dataType.elementType)
- return lambda a: [element_converter(v) for v in a]
+ return lambda a: a and [element_converter(v) for v in a]
elif isinstance(dataType, MapType):
key_converter = _python_to_sql_converter(dataType.keyType)
value_converter = _python_to_sql_converter(dataType.valueType)
- return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+ return lambda m: m and dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+
elif isinstance(dataType, UserDefinedType):
- return lambda obj: dataType.serialize(obj)
+ return lambda obj: obj and dataType.serialize(obj)
+
+ elif isinstance(dataType, DateType):
+ return lambda d: d and d.toordinal() - EPOCH_ORDINAL
+
elif isinstance(dataType, TimestampType):
def to_posix_timstamp(dt):
- if dt.tzinfo is None:
- return int(time.mktime(dt.timetuple()) * 1e7 + dt.microsecond * 10)
- else:
- return int(calendar.timegm(dt.utctimetuple()) * 1e7 + dt.microsecond * 10)
+ if dt:
+ seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
+ else time.mktime(dt.timetuple()))
+ return int(seconds * 1e7 + dt.microsecond * 10)
return to_posix_timstamp
+
else:
raise ValueError("Unexpected type %r" % dataType)