Diffstat (limited to 'python/pyspark/sql/types.py')
-rw-r--r--  python/pyspark/sql/types.py | 120
1 file changed, 73 insertions(+), 47 deletions(-)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 0f5dc2be6d..31a861e1fe 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -21,6 +21,7 @@ import keyword
import warnings
import json
import re
+import weakref
from array import array
from operator import itemgetter
@@ -42,8 +43,7 @@ class DataType(object):
return hash(str(self))
def __eq__(self, other):
- return (isinstance(other, self.__class__) and
- self.__dict__ == other.__dict__)
+ return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
def __ne__(self, other):
return not self.__eq__(other)
@@ -64,6 +64,8 @@ class DataType(object):
sort_keys=True)
+# This singleton pattern does not work with pickle; after a pickle
+# round trip you will get a different object.
class PrimitiveTypeSingleton(type):
"""Metaclass for PrimitiveType"""
@@ -82,10 +84,6 @@ class PrimitiveType(DataType):
__metaclass__ = PrimitiveTypeSingleton
- def __eq__(self, other):
- # because they should be the same object
- return self is other
-
class NullType(PrimitiveType):
@@ -242,11 +240,12 @@ class ArrayType(DataType):
:param elementType: the data type of elements.
:param containsNull: indicates whether the list contains None values.
- >>> ArrayType(StringType) == ArrayType(StringType, True)
+ >>> ArrayType(StringType()) == ArrayType(StringType(), True)
True
- >>> ArrayType(StringType, False) == ArrayType(StringType)
+ >>> ArrayType(StringType(), False) == ArrayType(StringType())
False
"""
+ assert isinstance(elementType, DataType), "elementType should be DataType"
self.elementType = elementType
self.containsNull = containsNull
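
A hedged illustration of what the new assert catches (MapType, StructField, and StructType gain the same guard below); passing the bare class was exactly the slip the old doctests made. Assuming a PySpark build with this patch applied:

    from pyspark.sql.types import ArrayType, StringType

    ArrayType(StringType())            # correct: a DataType instance
    try:
        ArrayType(StringType)          # the class object itself, a common typo
    except AssertionError as e:
        print(e)                       # elementType should be DataType
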
@@ -292,13 +291,15 @@ class MapType(DataType):
:param valueContainsNull: indicates whether values contain
null values.
- >>> (MapType(StringType, IntegerType)
- ... == MapType(StringType, IntegerType, True))
+ >>> (MapType(StringType(), IntegerType())
+ ... == MapType(StringType(), IntegerType(), True))
True
- >>> (MapType(StringType, IntegerType, False)
- ... == MapType(StringType, FloatType))
+ >>> (MapType(StringType(), IntegerType(), False)
+ ... == MapType(StringType(), FloatType()))
False
"""
+ assert isinstance(keyType, DataType), "keyType should be DataType"
+ assert isinstance(valueType, DataType), "valueType should be DataType"
self.keyType = keyType
self.valueType = valueType
self.valueContainsNull = valueContainsNull
@@ -348,13 +349,14 @@ class StructField(DataType):
to simple type that can be serialized to JSON
automatically
- >>> (StructField("f1", StringType, True)
- ... == StructField("f1", StringType, True))
+ >>> (StructField("f1", StringType(), True)
+ ... == StructField("f1", StringType(), True))
True
- >>> (StructField("f1", StringType, True)
- ... == StructField("f2", StringType, True))
+ >>> (StructField("f1", StringType(), True)
+ ... == StructField("f2", StringType(), True))
False
"""
+ assert isinstance(dataType, DataType), "dataType should be DataType"
self.name = name
self.dataType = dataType
self.nullable = nullable
@@ -393,16 +395,17 @@ class StructType(DataType):
def __init__(self, fields):
"""Creates a StructType
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True)])
+ >>> struct1 = StructType([StructField("f1", StringType(), True)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True),
- ... [StructField("f2", IntegerType, False)]])
+ >>> struct1 = StructType([StructField("f1", StringType(), True)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True),
+ ... StructField("f2", IntegerType(), False)])
>>> struct1 == struct2
False
"""
+ assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
self.fields = fields
def simpleString(self):
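
The rewritten doctest above also fixes a latent bug in the old one, which nested a list inside fields; under the new assert that mistake now fails fast instead of producing a malformed schema. A sketch, assuming a build with this patch:

    from pyspark.sql.types import StructType, StructField, StringType, IntegerType

    bad_fields = [StructField("f1", StringType(), True),
                  [StructField("f2", IntegerType(), False)]]   # stray inner list
    try:
        StructType(bad_fields)
    except AssertionError as e:
        print(e)   # fields should be a list of DataType
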
@@ -505,20 +508,24 @@ _all_complex_types = dict((v.typeName(), v)
def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
+ >>> import pickle
>>> def check_datatype(datatype):
+ ... pickled = pickle.loads(pickle.dumps(datatype))
+ ... assert datatype == pickled
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
- ... return datatype == python_datatype
- >>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
- True
+ ... assert datatype == python_datatype
+ >>> for cls in _all_primitive_types.values():
+ ... check_datatype(cls())
+
>>> # Simple ArrayType.
>>> simple_arraytype = ArrayType(StringType(), True)
>>> check_datatype(simple_arraytype)
- True
+
>>> # Simple MapType.
>>> simple_maptype = MapType(StringType(), LongType())
>>> check_datatype(simple_maptype)
- True
+
>>> # Simple StructType.
>>> simple_structtype = StructType([
... StructField("a", DecimalType(), False),
@@ -526,7 +533,7 @@ def _parse_datatype_json_string(json_string):
... StructField("c", LongType(), True),
... StructField("d", BinaryType(), False)])
>>> check_datatype(simple_structtype)
- True
+
>>> # Complex StructType.
>>> complex_structtype = StructType([
... StructField("simpleArray", simple_arraytype, True),
@@ -535,22 +542,20 @@ def _parse_datatype_json_string(json_string):
... StructField("boolean", BooleanType(), False),
... StructField("withMeta", DoubleType(), False, {"name": "age"})])
>>> check_datatype(complex_structtype)
- True
+
>>> # Complex ArrayType.
>>> complex_arraytype = ArrayType(complex_structtype, True)
>>> check_datatype(complex_arraytype)
- True
+
>>> # Complex MapType.
>>> complex_maptype = MapType(complex_structtype,
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
- True
+
>>> check_datatype(ExamplePointUDT())
- True
>>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
... StructField("point", ExamplePointUDT(), False)])
>>> check_datatype(structtype_with_udt)
- True
"""
return _parse_datatype_json_value(json.loads(json_string))
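
A minimal usage sketch of the two round trips the reworked doctest now exercises, assuming a PySpark build containing this file (the pickle assertion is what the removed PrimitiveType.__eq__ used to break):

    import pickle
    from pyspark.sql.types import (StructType, StructField, StringType,
                                   _parse_datatype_json_string)

    schema = StructType([StructField("name", StringType(), True)])
    assert _parse_datatype_json_string(schema.json()) == schema   # JSON round trip
    assert pickle.loads(pickle.dumps(schema)) == schema           # pickle round trip
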
@@ -786,8 +791,24 @@ def _merge_type(a, b):
return a
+def _need_converter(dataType):
+ if isinstance(dataType, StructType):
+ return True
+ elif isinstance(dataType, ArrayType):
+ return _need_converter(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
+ elif isinstance(dataType, NullType):
+ return True
+ else:
+ return False
+
+
def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """
+ if not _need_converter(dataType):
+ return lambda x: x
+
if isinstance(dataType, ArrayType):
conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
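
_need_converter lets _create_converter short-circuit for flat schemas: only structs and NullType (directly or nested) require rewriting. A small sketch of the fast path, assuming a build with this patch:

    from pyspark.sql.types import (ArrayType, LongType,
                                   _create_converter, _need_converter)

    flat = ArrayType(LongType())          # no struct, no NullType anywhere inside
    assert not _need_converter(flat)
    conv = _create_converter(flat)
    data = [1, 2, 3]
    assert conv(data) is data             # identity function: nothing to rewrite
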
@@ -806,13 +827,17 @@ def _create_converter(dataType):
# dataType must be StructType
names = [f.name for f in dataType.fields]
converters = [_create_converter(f.dataType) for f in dataType.fields]
+ convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
def convert_struct(obj):
if obj is None:
return
if isinstance(obj, (tuple, list)):
- return tuple(conv(v) for v, conv in zip(obj, converters))
+ if convert_fields:
+ return tuple(conv(v) for v, conv in zip(obj, converters))
+ else:
+ return tuple(obj)
if isinstance(obj, dict):
d = obj
@@ -821,7 +846,10 @@ def _create_converter(dataType):
else:
raise ValueError("Unexpected obj: %s" % obj)
- return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+ if convert_fields:
+ return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+ else:
+ return tuple([d.get(name) for name in names])
return convert_struct
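
A usage sketch of the struct converter itself: dicts are flattened to tuples in field order, and with this patch the per-field converters are skipped entirely when no field type needs one.

    from pyspark.sql.types import (StructType, StructField, StringType, LongType,
                                   _create_converter)

    schema = StructType([StructField("name", StringType(), True),
                         StructField("age", LongType(), True)])
    convert = _create_converter(schema)
    # Neither field type needs a converter, so the fast tuple path runs.
    assert convert({"age": 5, "name": "Alice"}) == ("Alice", 5)
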
@@ -871,20 +899,20 @@ def _parse_field_abstract(s):
Parse a field in schema abstract
>>> _parse_field_abstract("a")
- StructField(a,None,true)
+ StructField(a,NullType,true)
>>> _parse_field_abstract("b(c d)")
- StructField(b,StructType(...c,None,true),StructField(d...
+ StructField(b,StructType(...c,NullType,true),StructField(d...
>>> _parse_field_abstract("a[]")
- StructField(a,ArrayType(None,true),true)
+ StructField(a,ArrayType(NullType,true),true)
>>> _parse_field_abstract("a{[]}")
- StructField(a,MapType(None,ArrayType(None,true),true),true)
+ StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
"""
if set(_BRACKETS.keys()) & set(s):
idx = min((s.index(c) for c in _BRACKETS if c in s))
name = s[:idx]
return StructField(name, _parse_schema_abstract(s[idx:]), True)
else:
- return StructField(s, None, True)
+ return StructField(s, NullType(), True)
def _parse_schema_abstract(s):
@@ -898,11 +926,11 @@ def _parse_schema_abstract(s):
>>> _parse_schema_abstract("c{} d{a b}")
StructType...c,MapType...d,MapType...a...b...
>>> _parse_schema_abstract("a b(t)").fields[1]
- StructField(b,StructType(List(StructField(t,None,true))),true)
+ StructField(b,StructType(List(StructField(t,NullType,true))),true)
"""
s = s.strip()
if not s:
- return
+ return NullType()
elif s.startswith('('):
return _parse_schema_abstract(s[1:-1])
@@ -911,7 +939,7 @@ def _parse_schema_abstract(s):
return ArrayType(_parse_schema_abstract(s[1:-1]), True)
elif s.startswith('{'):
- return MapType(None, _parse_schema_abstract(s[1:-1]))
+ return MapType(NullType(), _parse_schema_abstract(s[1:-1]))
parts = _split_schema_abstract(s)
fields = [_parse_field_abstract(p) for p in parts]
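
The abstract-schema helpers are internal test utilities; a sketch of the mini-language with the new NullType() placeholders standing in for the previous bare None, assuming a build with this patch:

    from pyspark.sql.types import NullType, _parse_schema_abstract

    schema = _parse_schema_abstract("a b[] c{}")
    assert schema.fields[0].dataType == NullType()                # bare field "a"
    assert schema.fields[1].dataType.elementType == NullType()    # array "b[]"
    assert schema.fields[2].dataType.keyType == NullType()        # map "c{}"
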
@@ -931,7 +959,7 @@ def _infer_schema_type(obj, dataType):
>>> _infer_schema_type(row, schema)
StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
"""
- if dataType is None:
+ if dataType is NullType():
return _infer_type(obj)
if not obj:
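
The new check compares against NullType() with `is`, which is safe only because PrimitiveTypeSingleton hands out one cached instance per class. A quick sanity check of that invariant:

    from pyspark.sql.types import NullType

    # One cached instance per primitive type, so identity comparison works.
    assert NullType() is NullType()
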
@@ -1037,8 +1065,7 @@ def _verify_type(obj, dataType):
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType)
-
-_cached_cls = {}
+_cached_cls = weakref.WeakValueDictionary()
def _restore_object(dataType, obj):
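
A sketch of what the WeakValueDictionary buys here: cached entries (generated Row classes in the real code; a stand-in object below) are released once nothing else references them, where a plain dict would pin every schema's class for the process lifetime. The immediate disappearance relies on CPython's reference counting:

    import gc
    import weakref

    class Generated(object):          # stand-in for a generated Row class
        pass

    cache = weakref.WeakValueDictionary()
    value = Generated()
    cache["schema-key"] = value
    assert "schema-key" in cache      # alive while a strong reference exists

    del value
    gc.collect()                      # immediate under CPython refcounting anyway
    assert "schema-key" not in cache  # entry vanished with its last reference
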
@@ -1233,8 +1260,7 @@ class Row(tuple):
elif kwargs:
# create row objects
names = sorted(kwargs.keys())
- values = tuple(kwargs[n] for n in names)
- row = tuple.__new__(self, values)
+ row = tuple.__new__(self, [kwargs[n] for n in names])
row.__FIELDS__ = names
return row
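
A usage sketch of the keyword-argument path touched above: field names are sorted before the tuple is built, so the layout is deterministic regardless of call order. Assuming a build containing this file:

    from pyspark.sql.types import Row

    r = Row(age=5, name="Alice")
    assert r == (5, "Alice")     # values stored sorted by field name
    assert r.age == 5            # named access resolves through __FIELDS__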