From adfd366814499c0540a15dd6017091ba8c0f05da Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 15 May 2015 20:05:26 -0700 Subject: [SPARK-7073] [SQL] [PySpark] Clean up SQL data type hierarchy in Python Author: Davies Liu Closes #6206 from davies/sql_type and squashes the following commits: 33d6860 [Davies Liu] [SPARK-7073] [SQL] [PySpark] Clean up SQL data type hierarchy in Python --- python/pyspark/sql/_types.py | 76 +++++++++++++++++++++++++++----------------- 1 file changed, 46 insertions(+), 30 deletions(-) (limited to 'python/pyspark') diff --git a/python/pyspark/sql/_types.py b/python/pyspark/sql/_types.py index 629c3a9451..9e7e9f04bc 100644 --- a/python/pyspark/sql/_types.py +++ b/python/pyspark/sql/_types.py @@ -73,56 +73,74 @@ class DataType(object): # This singleton pattern does not work with pickle, you will get # another object after pickle and unpickle -class PrimitiveTypeSingleton(type): - """Metaclass for PrimitiveType""" +class DataTypeSingleton(type): + """Metaclass for DataType""" _instances = {} def __call__(cls): if cls not in cls._instances: - cls._instances[cls] = super(PrimitiveTypeSingleton, cls).__call__() + cls._instances[cls] = super(DataTypeSingleton, cls).__call__() return cls._instances[cls] -class PrimitiveType(DataType): - """Spark SQL PrimitiveType""" +class NullType(DataType): + """Null type. - __metaclass__ = PrimitiveTypeSingleton + The data type representing None, used for the types that cannot be inferred. + """ + __metaclass__ = DataTypeSingleton -class NullType(PrimitiveType): - """Null type. - The data type representing None, used for the types that cannot be inferred. +class AtomicType(DataType): + """An internal type used to represent everything that is not + null, UDTs, arrays, structs, and maps.""" + + __metaclass__ = DataTypeSingleton + + +class NumericType(AtomicType): + """Numeric data types. """ -class StringType(PrimitiveType): +class IntegralType(NumericType): + """Integral data types. + """ + + +class FractionalType(NumericType): + """Fractional data types. + """ + + +class StringType(AtomicType): """String data type. """ -class BinaryType(PrimitiveType): +class BinaryType(AtomicType): """Binary (byte array) data type. """ -class BooleanType(PrimitiveType): +class BooleanType(AtomicType): """Boolean data type. """ -class DateType(PrimitiveType): +class DateType(AtomicType): """Date (datetime.date) data type. """ -class TimestampType(PrimitiveType): +class TimestampType(AtomicType): """Timestamp (datetime.datetime) data type. """ -class DecimalType(DataType): +class DecimalType(FractionalType): """Decimal (decimal.Decimal) data type. """ @@ -150,31 +168,31 @@ class DecimalType(DataType): return "DecimalType()" -class DoubleType(PrimitiveType): +class DoubleType(FractionalType): """Double data type, representing double precision floats. """ -class FloatType(PrimitiveType): +class FloatType(FractionalType): """Float data type, representing single precision floats. """ -class ByteType(PrimitiveType): +class ByteType(IntegralType): """Byte data type, i.e. a signed integer in a single byte. """ def simpleString(self): return 'tinyint' -class IntegerType(PrimitiveType): +class IntegerType(IntegralType): """Int data type, i.e. a signed 32-bit integer. """ def simpleString(self): return 'int' -class LongType(PrimitiveType): +class LongType(IntegralType): """Long data type, i.e. a signed 64-bit integer. If the values are beyond the range of [-9223372036854775808, 9223372036854775807], @@ -184,7 +202,7 @@ class LongType(PrimitiveType): return 'bigint' -class ShortType(PrimitiveType): +class ShortType(IntegralType): """Short data type, i.e. a signed 16-bit integer. """ def simpleString(self): @@ -426,11 +444,9 @@ class UserDefinedType(DataType): return type(self) == type(other) -_all_primitive_types = dict((v.typeName(), v) - for v in list(globals().values()) - if (type(v) is type or type(v) is PrimitiveTypeSingleton) - and v.__base__ == PrimitiveType) - +_atomic_types = [StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType, + ByteType, ShortType, IntegerType, LongType, DateType, TimestampType] +_all_atomic_types = dict((t.typeName(), t) for t in _atomic_types) _all_complex_types = dict((v.typeName(), v) for v in [ArrayType, MapType, StructType]) @@ -444,7 +460,7 @@ def _parse_datatype_json_string(json_string): ... scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) ... assert datatype == python_datatype - >>> for cls in _all_primitive_types.values(): + >>> for cls in _all_atomic_types.values(): ... check_datatype(cls()) >>> # Simple ArrayType. @@ -494,8 +510,8 @@ _FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)") def _parse_datatype_json_value(json_value): if not isinstance(json_value, dict): - if json_value in _all_primitive_types.keys(): - return _all_primitive_types[json_value]() + if json_value in _all_atomic_types.keys(): + return _all_atomic_types[json_value]() elif json_value == 'decimal': return DecimalType() elif _FIXED_DECIMAL.match(json_value): @@ -1125,7 +1141,7 @@ def _create_cls(dataType): return lambda datum: dataType.deserialize(datum) elif not isinstance(dataType, StructType): - # no wrapper for primitive types + # no wrapper for atomic types return lambda x: x class Row(tuple): -- cgit v1.2.3