aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/types.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-02-27 20:07:17 -0800
committerJosh Rosen <joshrosen@databricks.com>2015-02-27 20:07:17 -0800
commite0e64ba4b1b8eb72e856286f756c65fa22ab0a36 (patch)
treeca358052d6b572756ecbbf98133a093db3f4cc83 /python/pyspark/sql/types.py
parent8c468a6600e0deb5464990df60148212e64fdecd (diff)
downloadspark-e0e64ba4b1b8eb72e856286f756c65fa22ab0a36.tar.gz
spark-e0e64ba4b1b8eb72e856286f756c65fa22ab0a36.tar.bz2
spark-e0e64ba4b1b8eb72e856286f756c65fa22ab0a36.zip
[SPARK-6055] [PySpark] fix incorrect __eq__ of DataType
The _eq_ of DataType is not correct, class cache is not use correctly (created class can not be find by dataType), then it will create lots of classes (saved in _cached_cls), never released. Also, all same DataType have same hash code, there will be many object in a dict with the same hash code, end with hash attach, it's very slow to access this dict (depends on the implementation of CPython). This PR also improve the performance of inferSchema (avoid the unnecessary converter of object). cc pwendell JoshRosen Author: Davies Liu <davies@databricks.com> Closes #4808 from davies/leak and squashes the following commits: 6a322a4 [Davies Liu] tests refactor 3da44fc [Davies Liu] fix __eq__ of Singleton 534ac90 [Davies Liu] add more checks 46999dc [Davies Liu] fix tests d9ae973 [Davies Liu] fix memory leak in sql
Diffstat (limited to 'python/pyspark/sql/types.py')
-rw-r--r--python/pyspark/sql/types.py120
1 files changed, 73 insertions, 47 deletions
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 0f5dc2be6d..31a861e1fe 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -21,6 +21,7 @@ import keyword
import warnings
import json
import re
+import weakref
from array import array
from operator import itemgetter
@@ -42,8 +43,7 @@ class DataType(object):
return hash(str(self))
def __eq__(self, other):
- return (isinstance(other, self.__class__) and
- self.__dict__ == other.__dict__)
+ return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
def __ne__(self, other):
return not self.__eq__(other)
@@ -64,6 +64,8 @@ class DataType(object):
sort_keys=True)
+# This singleton pattern does not work with pickle, you will get
+# another object after pickle and unpickle
class PrimitiveTypeSingleton(type):
"""Metaclass for PrimitiveType"""
@@ -82,10 +84,6 @@ class PrimitiveType(DataType):
__metaclass__ = PrimitiveTypeSingleton
- def __eq__(self, other):
- # because they should be the same object
- return self is other
-
class NullType(PrimitiveType):
@@ -242,11 +240,12 @@ class ArrayType(DataType):
:param elementType: the data type of elements.
:param containsNull: indicates whether the list contains None values.
- >>> ArrayType(StringType) == ArrayType(StringType, True)
+ >>> ArrayType(StringType()) == ArrayType(StringType(), True)
True
- >>> ArrayType(StringType, False) == ArrayType(StringType)
+ >>> ArrayType(StringType(), False) == ArrayType(StringType())
False
"""
+ assert isinstance(elementType, DataType), "elementType should be DataType"
self.elementType = elementType
self.containsNull = containsNull
@@ -292,13 +291,15 @@ class MapType(DataType):
:param valueContainsNull: indicates whether values contains
null values.
- >>> (MapType(StringType, IntegerType)
- ... == MapType(StringType, IntegerType, True))
+ >>> (MapType(StringType(), IntegerType())
+ ... == MapType(StringType(), IntegerType(), True))
True
- >>> (MapType(StringType, IntegerType, False)
- ... == MapType(StringType, FloatType))
+ >>> (MapType(StringType(), IntegerType(), False)
+ ... == MapType(StringType(), FloatType()))
False
"""
+ assert isinstance(keyType, DataType), "keyType should be DataType"
+ assert isinstance(valueType, DataType), "valueType should be DataType"
self.keyType = keyType
self.valueType = valueType
self.valueContainsNull = valueContainsNull
@@ -348,13 +349,14 @@ class StructField(DataType):
to simple type that can be serialized to JSON
automatically
- >>> (StructField("f1", StringType, True)
- ... == StructField("f1", StringType, True))
+ >>> (StructField("f1", StringType(), True)
+ ... == StructField("f1", StringType(), True))
True
- >>> (StructField("f1", StringType, True)
- ... == StructField("f2", StringType, True))
+ >>> (StructField("f1", StringType(), True)
+ ... == StructField("f2", StringType(), True))
False
"""
+ assert isinstance(dataType, DataType), "dataType should be DataType"
self.name = name
self.dataType = dataType
self.nullable = nullable
@@ -393,16 +395,17 @@ class StructType(DataType):
def __init__(self, fields):
"""Creates a StructType
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True)])
+ >>> struct1 = StructType([StructField("f1", StringType(), True)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True),
- ... [StructField("f2", IntegerType, False)]])
+ >>> struct1 = StructType([StructField("f1", StringType(), True)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True),
+ ... StructField("f2", IntegerType(), False)])
>>> struct1 == struct2
False
"""
+ assert all(isinstance(f, DataType) for f in fields), "fields should be a list of DataType"
self.fields = fields
def simpleString(self):
@@ -505,20 +508,24 @@ _all_complex_types = dict((v.typeName(), v)
def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
+ >>> import pickle
>>> def check_datatype(datatype):
+ ... pickled = pickle.loads(pickle.dumps(datatype))
+ ... assert datatype == pickled
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
- ... return datatype == python_datatype
- >>> all(check_datatype(cls()) for cls in _all_primitive_types.values())
- True
+ ... assert datatype == python_datatype
+ >>> for cls in _all_primitive_types.values():
+ ... check_datatype(cls())
+
>>> # Simple ArrayType.
>>> simple_arraytype = ArrayType(StringType(), True)
>>> check_datatype(simple_arraytype)
- True
+
>>> # Simple MapType.
>>> simple_maptype = MapType(StringType(), LongType())
>>> check_datatype(simple_maptype)
- True
+
>>> # Simple StructType.
>>> simple_structtype = StructType([
... StructField("a", DecimalType(), False),
@@ -526,7 +533,7 @@ def _parse_datatype_json_string(json_string):
... StructField("c", LongType(), True),
... StructField("d", BinaryType(), False)])
>>> check_datatype(simple_structtype)
- True
+
>>> # Complex StructType.
>>> complex_structtype = StructType([
... StructField("simpleArray", simple_arraytype, True),
@@ -535,22 +542,20 @@ def _parse_datatype_json_string(json_string):
... StructField("boolean", BooleanType(), False),
... StructField("withMeta", DoubleType(), False, {"name": "age"})])
>>> check_datatype(complex_structtype)
- True
+
>>> # Complex ArrayType.
>>> complex_arraytype = ArrayType(complex_structtype, True)
>>> check_datatype(complex_arraytype)
- True
+
>>> # Complex MapType.
>>> complex_maptype = MapType(complex_structtype,
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
- True
+
>>> check_datatype(ExamplePointUDT())
- True
>>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
... StructField("point", ExamplePointUDT(), False)])
>>> check_datatype(structtype_with_udt)
- True
"""
return _parse_datatype_json_value(json.loads(json_string))
@@ -786,8 +791,24 @@ def _merge_type(a, b):
return a
+def _need_converter(dataType):
+ if isinstance(dataType, StructType):
+ return True
+ elif isinstance(dataType, ArrayType):
+ return _need_converter(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
+ elif isinstance(dataType, NullType):
+ return True
+ else:
+ return False
+
+
def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """
+ if not _need_converter(dataType):
+ return lambda x: x
+
if isinstance(dataType, ArrayType):
conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
@@ -806,13 +827,17 @@ def _create_converter(dataType):
# dataType must be StructType
names = [f.name for f in dataType.fields]
converters = [_create_converter(f.dataType) for f in dataType.fields]
+ convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
def convert_struct(obj):
if obj is None:
return
if isinstance(obj, (tuple, list)):
- return tuple(conv(v) for v, conv in zip(obj, converters))
+ if convert_fields:
+ return tuple(conv(v) for v, conv in zip(obj, converters))
+ else:
+ return tuple(obj)
if isinstance(obj, dict):
d = obj
@@ -821,7 +846,10 @@ def _create_converter(dataType):
else:
raise ValueError("Unexpected obj: %s" % obj)
- return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+ if convert_fields:
+ return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+ else:
+ return tuple([d.get(name) for name in names])
return convert_struct
@@ -871,20 +899,20 @@ def _parse_field_abstract(s):
Parse a field in schema abstract
>>> _parse_field_abstract("a")
- StructField(a,None,true)
+ StructField(a,NullType,true)
>>> _parse_field_abstract("b(c d)")
- StructField(b,StructType(...c,None,true),StructField(d...
+ StructField(b,StructType(...c,NullType,true),StructField(d...
>>> _parse_field_abstract("a[]")
- StructField(a,ArrayType(None,true),true)
+ StructField(a,ArrayType(NullType,true),true)
>>> _parse_field_abstract("a{[]}")
- StructField(a,MapType(None,ArrayType(None,true),true),true)
+ StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
"""
if set(_BRACKETS.keys()) & set(s):
idx = min((s.index(c) for c in _BRACKETS if c in s))
name = s[:idx]
return StructField(name, _parse_schema_abstract(s[idx:]), True)
else:
- return StructField(s, None, True)
+ return StructField(s, NullType(), True)
def _parse_schema_abstract(s):
@@ -898,11 +926,11 @@ def _parse_schema_abstract(s):
>>> _parse_schema_abstract("c{} d{a b}")
StructType...c,MapType...d,MapType...a...b...
>>> _parse_schema_abstract("a b(t)").fields[1]
- StructField(b,StructType(List(StructField(t,None,true))),true)
+ StructField(b,StructType(List(StructField(t,NullType,true))),true)
"""
s = s.strip()
if not s:
- return
+ return NullType()
elif s.startswith('('):
return _parse_schema_abstract(s[1:-1])
@@ -911,7 +939,7 @@ def _parse_schema_abstract(s):
return ArrayType(_parse_schema_abstract(s[1:-1]), True)
elif s.startswith('{'):
- return MapType(None, _parse_schema_abstract(s[1:-1]))
+ return MapType(NullType(), _parse_schema_abstract(s[1:-1]))
parts = _split_schema_abstract(s)
fields = [_parse_field_abstract(p) for p in parts]
@@ -931,7 +959,7 @@ def _infer_schema_type(obj, dataType):
>>> _infer_schema_type(row, schema)
StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
"""
- if dataType is None:
+ if dataType is NullType():
return _infer_type(obj)
if not obj:
@@ -1037,8 +1065,7 @@ def _verify_type(obj, dataType):
for v, f in zip(obj, dataType.fields):
_verify_type(v, f.dataType)
-
-_cached_cls = {}
+_cached_cls = weakref.WeakValueDictionary()
def _restore_object(dataType, obj):
@@ -1233,8 +1260,7 @@ class Row(tuple):
elif kwargs:
# create row objects
names = sorted(kwargs.keys())
- values = tuple(kwargs[n] for n in names)
- row = tuple.__new__(self, values)
+ row = tuple.__new__(self, [kwargs[n] for n in names])
row.__FIELDS__ = names
return row