author     Daoyuan Wang <daoyuan.wang@intel.com>  2014-10-28 13:43:25 -0700
committer  Michael Armbrust <michael@databricks.com>  2014-10-28 13:43:25 -0700
commit     47a40f60d62ea69b659959994918d4c640f39d5b (patch)
tree       67582dfaec3140d5e247a9170278e729b6af41c4 /python
parent     5807cb40ae178f0395c71b967f02aee853ef8bc9 (diff)
[SPARK-3988][SQL] add public API for date type
Add json and python api for date type. By using Pickle, `java.sql.Date` was serialized as calendar, and recognized in python as `datetime.datetime`.

Author: Daoyuan Wang <daoyuan.wang@intel.com>

Closes #2901 from adrian-wang/spark3988 and squashes the following commits:

c51a24d [Daoyuan Wang] convert datetime to date
5670626 [Daoyuan Wang] minor line combine
f760d8e [Daoyuan Wang] fix indent
444f100 [Daoyuan Wang] fix a typo
1d74448 [Daoyuan Wang] fix scala style
8d7dd22 [Daoyuan Wang] add json and python api for date type
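As the message notes, a `java.sql.Date` pickled on the JVM side arrives in Python as a `datetime.datetime`, and the patch truncates it back to a `datetime.date` when rows are rebuilt. A minimal standalone sketch of that conversion is shown here; the helper name `_restore_date_value` is hypothetical and only mirrors the `_create_object` change in `python/pyspark/sql.py` below:

    import datetime

    def _restore_date_value(cls, v):
        # Hypothetical stand-in for PySpark's _create_object change:
        # DateType values come back from the JVM pickler as
        # datetime.datetime, so truncate them back to datetime.date.
        if cls is datetime.date and isinstance(v, datetime.datetime):
            return v.date()
        return cls(v) if v is not None else v

    # Example: a value deserialized from java.sql.Date
    print(_restore_date_value(datetime.date, datetime.datetime(2010, 1, 1, 0, 0)))
    # prints: 2010-01-01

Truncating the deserialized timestamp rather than re-parsing it is what the squashed commit "convert datetime to date" above refers to.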
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql.py  57
1 file changed, 39 insertions(+), 18 deletions(-)
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 7daf306f68..93fd9d4909 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -49,7 +49,7 @@ from pyspark.traceback_utils import SCCallSiteSync
__all__ = [
- "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType",
+ "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType",
"DoubleType", "FloatType", "ByteType", "IntegerType", "LongType",
"ShortType", "ArrayType", "MapType", "StructField", "StructType",
"SQLContext", "HiveContext", "SchemaRDD", "Row"]
@@ -132,6 +132,14 @@ class BooleanType(PrimitiveType):
"""
+class DateType(PrimitiveType):
+
+ """Spark SQL DateType
+
+ The data type representing datetime.date values.
+ """
+
+
class TimestampType(PrimitiveType):
"""Spark SQL TimestampType
@@ -438,7 +446,7 @@ def _parse_datatype_json_value(json_value):
return _all_complex_types[json_value["type"]].fromJson(json_value)
-# Mapping Python types to Spark SQL DateType
+# Mapping Python types to Spark SQL DataType
_type_mappings = {
bool: BooleanType,
int: IntegerType,
@@ -448,8 +456,8 @@ _type_mappings = {
unicode: StringType,
bytearray: BinaryType,
decimal.Decimal: DecimalType,
+ datetime.date: DateType,
datetime.datetime: TimestampType,
- datetime.date: TimestampType,
datetime.time: TimestampType,
}
@@ -656,10 +664,10 @@ def _infer_schema_type(obj, dataType):
"""
Fill the dataType with types infered from obj
- >>> schema = _parse_schema_abstract("a b c")
- >>> row = (1, 1.0, "str")
+ >>> schema = _parse_schema_abstract("a b c d")
+ >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
>>> _infer_schema_type(row, schema)
- StructType...IntegerType...DoubleType...StringType...
+ StructType...IntegerType...DoubleType...StringType...DateType...
>>> row = [[1], {"key": (1, 2.0)}]
>>> schema = _parse_schema_abstract("a[] b{c d}")
>>> _infer_schema_type(row, schema)
@@ -703,6 +711,7 @@ _acceptable_types = {
DecimalType: (decimal.Decimal,),
StringType: (str, unicode),
BinaryType: (bytearray,),
+ DateType: (datetime.date,),
TimestampType: (datetime.datetime,),
ArrayType: (list, tuple, array),
MapType: (dict,),
@@ -740,7 +749,7 @@ def _verify_type(obj, dataType):
# subclass of them can not be deserialized in JVM
if type(obj) not in _acceptable_types[_type]:
- raise TypeError("%s can not accept abject in type %s"
+ raise TypeError("%s can not accept object in type %s"
% (dataType, type(obj)))
if isinstance(dataType, ArrayType):
@@ -767,7 +776,7 @@ def _restore_object(dataType, obj):
""" Restore object during unpickling. """
# use id(dataType) as key to speed up lookup in dict
# Because of batched pickling, dataType will be the
- # same object in mose cases.
+ # same object in most cases.
k = id(dataType)
cls = _cached_cls.get(k)
if cls is None:
@@ -782,6 +791,10 @@ def _restore_object(dataType, obj):
def _create_object(cls, v):
""" Create an customized object with class `cls`. """
+ # datetime.date would be deserialized as datetime.datetime
+ # from java type, so we need to set it back.
+ if cls is datetime.date and isinstance(v, datetime.datetime):
+ return v.date()
return cls(v) if v is not None else v
@@ -795,14 +808,16 @@ def _create_getter(dt, i):
return getter
-def _has_struct(dt):
- """Return whether `dt` is or has StructType in it"""
+def _has_struct_or_date(dt):
+ """Return whether `dt` is or has StructType/DateType in it"""
if isinstance(dt, StructType):
return True
elif isinstance(dt, ArrayType):
- return _has_struct(dt.elementType)
+ return _has_struct_or_date(dt.elementType)
elif isinstance(dt, MapType):
- return _has_struct(dt.valueType)
+ return _has_struct_or_date(dt.valueType)
+ elif isinstance(dt, DateType):
+ return True
return False
@@ -815,7 +830,7 @@ def _create_properties(fields):
or keyword.iskeyword(name)):
warnings.warn("field name %s can not be accessed in Python,"
"use position to access it instead" % name)
- if _has_struct(f.dataType):
+ if _has_struct_or_date(f.dataType):
# delay creating object until accessing it
getter = _create_getter(f.dataType, i)
else:
@@ -870,6 +885,9 @@ def _create_cls(dataType):
return Dict
+ elif isinstance(dataType, DateType):
+ return datetime.date
+
elif not isinstance(dataType, StructType):
raise Exception("unexpected data type: %s" % dataType)
@@ -1068,8 +1086,9 @@ class SQLContext(object):
>>> srdd2.collect()
[Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')]
- >>> from datetime import datetime
+ >>> from datetime import date, datetime
>>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0,
+ ... date(2010, 1, 1),
... datetime(2010, 1, 1, 1, 1, 1),
... {"a": 1}, (2,), [1, 2, 3], None)])
>>> schema = StructType([
@@ -1079,6 +1098,7 @@ class SQLContext(object):
... StructField("short2", ShortType(), False),
... StructField("int", IntegerType(), False),
... StructField("float", FloatType(), False),
+ ... StructField("date", DateType(), False),
... StructField("time", TimestampType(), False),
... StructField("map",
... MapType(StringType(), IntegerType(), False), False),
@@ -1088,10 +1108,11 @@ class SQLContext(object):
... StructField("null", DoubleType(), True)])
>>> srdd = sqlCtx.applySchema(rdd, schema)
>>> results = srdd.map(
- ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time,
- ... x.map["a"], x.struct.b, x.list, x.null))
- >>> results.collect()[0]
- (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
+ ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date,
+ ... x.time, x.map["a"], x.struct.b, x.list, x.null))
+ >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE
+ (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1),
+ datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None)
>>> srdd.registerTempTable("table2")
>>> sqlCtx.sql(