diff options
author | Daoyuan Wang <daoyuan.wang@intel.com> | 2014-10-28 13:43:25 -0700 |
---|---|---|
committer | Michael Armbrust <michael@databricks.com> | 2014-10-28 13:43:25 -0700 |
commit | 47a40f60d62ea69b659959994918d4c640f39d5b (patch) | |
tree | 67582dfaec3140d5e247a9170278e729b6af41c4 /python | |
parent | 5807cb40ae178f0395c71b967f02aee853ef8bc9 (diff) | |
download | spark-47a40f60d62ea69b659959994918d4c640f39d5b.tar.gz spark-47a40f60d62ea69b659959994918d4c640f39d5b.tar.bz2 spark-47a40f60d62ea69b659959994918d4c640f39d5b.zip |
[SPARK-3988][SQL] add public API for date type
Add json and python api for date type.
By using Pickle, `java.sql.Date` was serialized as calendar, and recognized in python as `datetime.datetime`.
Author: Daoyuan Wang <daoyuan.wang@intel.com>
Closes #2901 from adrian-wang/spark3988 and squashes the following commits:
c51a24d [Daoyuan Wang] convert datetime to date
5670626 [Daoyuan Wang] minor line combine
f760d8e [Daoyuan Wang] fix indent
444f100 [Daoyuan Wang] fix a typo
1d74448 [Daoyuan Wang] fix scala style
8d7dd22 [Daoyuan Wang] add json and python api for date type
Diffstat (limited to 'python')
-rw-r--r-- | python/pyspark/sql.py | 57 |
1 files changed, 39 insertions, 18 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index 7daf306f68..93fd9d4909 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -49,7 +49,7 @@ from pyspark.traceback_utils import SCCallSiteSync __all__ = [ - "StringType", "BinaryType", "BooleanType", "TimestampType", "DecimalType", + "StringType", "BinaryType", "BooleanType", "DateType", "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType", "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType", "SQLContext", "HiveContext", "SchemaRDD", "Row"] @@ -132,6 +132,14 @@ class BooleanType(PrimitiveType): """ +class DateType(PrimitiveType): + + """Spark SQL DateType + + The data type representing datetime.date values. + """ + + class TimestampType(PrimitiveType): """Spark SQL TimestampType @@ -438,7 +446,7 @@ def _parse_datatype_json_value(json_value): return _all_complex_types[json_value["type"]].fromJson(json_value) -# Mapping Python types to Spark SQL DateType +# Mapping Python types to Spark SQL DataType _type_mappings = { bool: BooleanType, int: IntegerType, @@ -448,8 +456,8 @@ _type_mappings = { unicode: StringType, bytearray: BinaryType, decimal.Decimal: DecimalType, + datetime.date: DateType, datetime.datetime: TimestampType, - datetime.date: TimestampType, datetime.time: TimestampType, } @@ -656,10 +664,10 @@ def _infer_schema_type(obj, dataType): """ Fill the dataType with types infered from obj - >>> schema = _parse_schema_abstract("a b c") - >>> row = (1, 1.0, "str") + >>> schema = _parse_schema_abstract("a b c d") + >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10)) >>> _infer_schema_type(row, schema) - StructType...IntegerType...DoubleType...StringType... + StructType...IntegerType...DoubleType...StringType...DateType... >>> row = [[1], {"key": (1, 2.0)}] >>> schema = _parse_schema_abstract("a[] b{c d}") >>> _infer_schema_type(row, schema) @@ -703,6 +711,7 @@ _acceptable_types = { DecimalType: (decimal.Decimal,), StringType: (str, unicode), BinaryType: (bytearray,), + DateType: (datetime.date,), TimestampType: (datetime.datetime,), ArrayType: (list, tuple, array), MapType: (dict,), @@ -740,7 +749,7 @@ def _verify_type(obj, dataType): # subclass of them can not be deserialized in JVM if type(obj) not in _acceptable_types[_type]: - raise TypeError("%s can not accept abject in type %s" + raise TypeError("%s can not accept object in type %s" % (dataType, type(obj))) if isinstance(dataType, ArrayType): @@ -767,7 +776,7 @@ def _restore_object(dataType, obj): """ Restore object during unpickling. """ # use id(dataType) as key to speed up lookup in dict # Because of batched pickling, dataType will be the - # same object in mose cases. + # same object in most cases. k = id(dataType) cls = _cached_cls.get(k) if cls is None: @@ -782,6 +791,10 @@ def _restore_object(dataType, obj): def _create_object(cls, v): """ Create an customized object with class `cls`. """ + # datetime.date would be deserialized as datetime.datetime + # from java type, so we need to set it back. + if cls is datetime.date and isinstance(v, datetime.datetime): + return v.date() return cls(v) if v is not None else v @@ -795,14 +808,16 @@ def _create_getter(dt, i): return getter -def _has_struct(dt): - """Return whether `dt` is or has StructType in it""" +def _has_struct_or_date(dt): + """Return whether `dt` is or has StructType/DateType in it""" if isinstance(dt, StructType): return True elif isinstance(dt, ArrayType): - return _has_struct(dt.elementType) + return _has_struct_or_date(dt.elementType) elif isinstance(dt, MapType): - return _has_struct(dt.valueType) + return _has_struct_or_date(dt.valueType) + elif isinstance(dt, DateType): + return True return False @@ -815,7 +830,7 @@ def _create_properties(fields): or keyword.iskeyword(name)): warnings.warn("field name %s can not be accessed in Python," "use position to access it instead" % name) - if _has_struct(f.dataType): + if _has_struct_or_date(f.dataType): # delay creating object until accessing it getter = _create_getter(f.dataType, i) else: @@ -870,6 +885,9 @@ def _create_cls(dataType): return Dict + elif isinstance(dataType, DateType): + return datetime.date + elif not isinstance(dataType, StructType): raise Exception("unexpected data type: %s" % dataType) @@ -1068,8 +1086,9 @@ class SQLContext(object): >>> srdd2.collect() [Row(field1=1, field2=u'row1'),..., Row(field1=3, field2=u'row3')] - >>> from datetime import datetime + >>> from datetime import date, datetime >>> rdd = sc.parallelize([(127, -128L, -32768, 32767, 2147483647L, 1.0, + ... date(2010, 1, 1), ... datetime(2010, 1, 1, 1, 1, 1), ... {"a": 1}, (2,), [1, 2, 3], None)]) >>> schema = StructType([ @@ -1079,6 +1098,7 @@ class SQLContext(object): ... StructField("short2", ShortType(), False), ... StructField("int", IntegerType(), False), ... StructField("float", FloatType(), False), + ... StructField("date", DateType(), False), ... StructField("time", TimestampType(), False), ... StructField("map", ... MapType(StringType(), IntegerType(), False), False), @@ -1088,10 +1108,11 @@ class SQLContext(object): ... StructField("null", DoubleType(), True)]) >>> srdd = sqlCtx.applySchema(rdd, schema) >>> results = srdd.map( - ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.time, - ... x.map["a"], x.struct.b, x.list, x.null)) - >>> results.collect()[0] - (127, -128, -32768, 32767, 2147483647, 1.0, ...(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) + ... lambda x: (x.byte1, x.byte2, x.short1, x.short2, x.int, x.float, x.date, + ... x.time, x.map["a"], x.struct.b, x.list, x.null)) + >>> results.collect()[0] # doctest: +NORMALIZE_WHITESPACE + (127, -128, -32768, 32767, 2147483647, 1.0, datetime.date(2010, 1, 1), + datetime.datetime(2010, 1, 1, 1, 1, 1), 1, 2, [1, 2, 3], None) >>> srdd.registerTempTable("table2") >>> sqlCtx.sql( |