aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-02-27 20:04:16 -0800
committerJosh Rosen <joshrosen@databricks.com>2015-02-27 20:04:16 -0800
commit576fc54e5c154fc28af1a732a6bea452d0a5cabb (patch)
tree61431752e0f9bc90086f5c268b5bf040e3111b95
parent17b7cc7332c4f89dcdf9ec457c3f825605bf59e9 (diff)
downloadspark-576fc54e5c154fc28af1a732a6bea452d0a5cabb.tar.gz
spark-576fc54e5c154fc28af1a732a6bea452d0a5cabb.tar.bz2
spark-576fc54e5c154fc28af1a732a6bea452d0a5cabb.zip
[SPARK-6055] [PySpark] fix incorrect DataType.__eq__ (for 1.2)
The eq of DataType is not correct, class cache is not use correctly (created class can not be find by dataType), then it will create lots of classes (saved in _cached_cls), never released. Also, all same DataType have same hash code, there will be many object in a dict with the same hash code, end with hash attach, it's very slow to access this dict (depends on the implementation of CPython). This PR also improve the performance of inferSchema (avoid the unnecessary converter of object). Author: Davies Liu <davies@databricks.com> Closes #4809 from davies/leak2 and squashes the following commits: 65c222f [Davies Liu] Update sql.py 9b4dadc [Davies Liu] fix __eq__ of singleton b576107 [Davies Liu] fix tests 6c2909a [Davies Liu] fix incorrect DataType.__eq__
-rw-r--r--python/pyspark/sql.py67
1 files changed, 44 insertions, 23 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index aa5af1bd40..4410925ba0 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -36,6 +36,7 @@ import keyword
import warnings
import json
import re
+import weakref
from array import array
from operator import itemgetter
from itertools import imap
@@ -68,8 +69,7 @@ class DataType(object):
return hash(str(self))
def __eq__(self, other):
- return (isinstance(other, self.__class__) and
- self.__dict__ == other.__dict__)
+ return isinstance(other, self.__class__) and self.__dict__ == other.__dict__
def __ne__(self, other):
return not self.__eq__(other)
@@ -105,10 +105,6 @@ class PrimitiveType(DataType):
__metaclass__ = PrimitiveTypeSingleton
- def __eq__(self, other):
- # because they should be the same object
- return self is other
-
class NullType(PrimitiveType):
@@ -251,9 +247,9 @@ class ArrayType(DataType):
:param elementType: the data type of elements.
:param containsNull: indicates whether the list contains None values.
- >>> ArrayType(StringType) == ArrayType(StringType, True)
+ >>> ArrayType(StringType()) == ArrayType(StringType(), True)
True
- >>> ArrayType(StringType, False) == ArrayType(StringType)
+ >>> ArrayType(StringType(), False) == ArrayType(StringType())
False
"""
self.elementType = elementType
@@ -298,11 +294,11 @@ class MapType(DataType):
:param valueContainsNull: indicates whether values contains
null values.
- >>> (MapType(StringType, IntegerType)
- ... == MapType(StringType, IntegerType, True))
+ >>> (MapType(StringType(), IntegerType())
+ ... == MapType(StringType(), IntegerType(), True))
True
- >>> (MapType(StringType, IntegerType, False)
- ... == MapType(StringType, FloatType))
+ >>> (MapType(StringType(), IntegerType(), False)
+ ... == MapType(StringType(), FloatType()))
False
"""
self.keyType = keyType
@@ -351,11 +347,11 @@ class StructField(DataType):
to simple type that can be serialized to JSON
automatically
- >>> (StructField("f1", StringType, True)
- ... == StructField("f1", StringType, True))
+ >>> (StructField("f1", StringType(), True)
+ ... == StructField("f1", StringType(), True))
True
- >>> (StructField("f1", StringType, True)
- ... == StructField("f2", StringType, True))
+ >>> (StructField("f1", StringType(), True)
+ ... == StructField("f2", StringType(), True))
False
"""
self.name = name
@@ -393,13 +389,13 @@ class StructType(DataType):
def __init__(self, fields):
"""Creates a StructType
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True)])
+ >>> struct1 = StructType([StructField("f1", StringType(), True)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True)])
>>> struct1 == struct2
True
- >>> struct1 = StructType([StructField("f1", StringType, True)])
- >>> struct2 = StructType([StructField("f1", StringType, True),
- ... [StructField("f2", IntegerType, False)]])
+ >>> struct1 = StructType([StructField("f1", StringType(), True)])
+ >>> struct2 = StructType([StructField("f1", StringType(), True),
+ ... StructField("f2", IntegerType(), False)])
>>> struct1 == struct2
False
"""
@@ -499,6 +495,10 @@ _all_complex_types = dict((v.typeName(), v)
def _parse_datatype_json_string(json_string):
"""Parses the given data type JSON string.
+
+ >>> import pickle
+ >>> LongType() == pickle.loads(pickle.dumps(LongType()))
+ True
>>> def check_datatype(datatype):
... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json())
... python_datatype = _parse_datatype_json_string(scala_datatype.json())
@@ -781,8 +781,25 @@ def _merge_type(a, b):
return a
+def _need_converter(dataType):
+ if isinstance(dataType, StructType):
+ return True
+ elif isinstance(dataType, ArrayType):
+ return _need_converter(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
+ elif isinstance(dataType, NullType):
+ return True
+ else:
+ return False
+
+
def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """
+
+ if not _need_converter(dataType):
+ return lambda x: x
+
if isinstance(dataType, ArrayType):
conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
@@ -800,6 +817,7 @@ def _create_converter(dataType):
# dataType must be StructType
names = [f.name for f in dataType.fields]
converters = [_create_converter(f.dataType) for f in dataType.fields]
+ convert_fields = any(_need_converter(f.dataType) for f in dataType.fields)
def convert_struct(obj):
if obj is None:
@@ -822,7 +840,10 @@ def _create_converter(dataType):
else:
raise ValueError("Unexpected obj: %s" % obj)
- return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+ if convert_fields:
+ return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
+ else:
+ return tuple([d.get(name) for name in names])
return convert_struct
@@ -1039,7 +1060,7 @@ def _verify_type(obj, dataType):
_verify_type(v, f.dataType)
-_cached_cls = {}
+_cached_cls = weakref.WeakValueDictionary()
def _restore_object(dataType, obj):