diff options
author | Davies Liu <davies@databricks.com> | 2015-02-27 20:04:16 -0800 |
---|---|---|
committer | Josh Rosen <joshrosen@databricks.com> | 2015-02-27 20:04:16 -0800 |
commit | 576fc54e5c154fc28af1a732a6bea452d0a5cabb (patch) | |
tree | 61431752e0f9bc90086f5c268b5bf040e3111b95 | |
parent | 17b7cc7332c4f89dcdf9ec457c3f825605bf59e9 (diff) | |
download | spark-576fc54e5c154fc28af1a732a6bea452d0a5cabb.tar.gz spark-576fc54e5c154fc28af1a732a6bea452d0a5cabb.tar.bz2 spark-576fc54e5c154fc28af1a732a6bea452d0a5cabb.zip |
[SPARK-6055] [PySpark] fix incorrect DataType.__eq__ (for 1.2)
The eq of DataType is not correct, class cache is not use correctly (created class can not be find by dataType), then it will create lots of classes (saved in _cached_cls), never released.
Also, all same DataType have same hash code, there will be many object in a dict with the same hash code, end with hash attach, it's very slow to access this dict (depends on the implementation of CPython).
This PR also improve the performance of inferSchema (avoid the unnecessary converter of object).
Author: Davies Liu <davies@databricks.com>
Closes #4809 from davies/leak2 and squashes the following commits:
65c222f [Davies Liu] Update sql.py
9b4dadc [Davies Liu] fix __eq__ of singleton
b576107 [Davies Liu] fix tests
6c2909a [Davies Liu] fix incorrect DataType.__eq__
-rw-r--r-- | python/pyspark/sql.py | 67 |
1 files changed, 44 insertions, 23 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py index aa5af1bd40..4410925ba0 100644 --- a/python/pyspark/sql.py +++ b/python/pyspark/sql.py @@ -36,6 +36,7 @@ import keyword import warnings import json import re +import weakref from array import array from operator import itemgetter from itertools import imap @@ -68,8 +69,7 @@ class DataType(object): return hash(str(self)) def __eq__(self, other): - return (isinstance(other, self.__class__) and - self.__dict__ == other.__dict__) + return isinstance(other, self.__class__) and self.__dict__ == other.__dict__ def __ne__(self, other): return not self.__eq__(other) @@ -105,10 +105,6 @@ class PrimitiveType(DataType): __metaclass__ = PrimitiveTypeSingleton - def __eq__(self, other): - # because they should be the same object - return self is other - class NullType(PrimitiveType): @@ -251,9 +247,9 @@ class ArrayType(DataType): :param elementType: the data type of elements. :param containsNull: indicates whether the list contains None values. - >>> ArrayType(StringType) == ArrayType(StringType, True) + >>> ArrayType(StringType()) == ArrayType(StringType(), True) True - >>> ArrayType(StringType, False) == ArrayType(StringType) + >>> ArrayType(StringType(), False) == ArrayType(StringType()) False """ self.elementType = elementType @@ -298,11 +294,11 @@ class MapType(DataType): :param valueContainsNull: indicates whether values contains null values. - >>> (MapType(StringType, IntegerType) - ... == MapType(StringType, IntegerType, True)) + >>> (MapType(StringType(), IntegerType()) + ... == MapType(StringType(), IntegerType(), True)) True - >>> (MapType(StringType, IntegerType, False) - ... == MapType(StringType, FloatType)) + >>> (MapType(StringType(), IntegerType(), False) + ... == MapType(StringType(), FloatType())) False """ self.keyType = keyType @@ -351,11 +347,11 @@ class StructField(DataType): to simple type that can be serialized to JSON automatically - >>> (StructField("f1", StringType, True) - ... == StructField("f1", StringType, True)) + >>> (StructField("f1", StringType(), True) + ... == StructField("f1", StringType(), True)) True - >>> (StructField("f1", StringType, True) - ... == StructField("f2", StringType, True)) + >>> (StructField("f1", StringType(), True) + ... == StructField("f2", StringType(), True)) False """ self.name = name @@ -393,13 +389,13 @@ class StructType(DataType): def __init__(self, fields): """Creates a StructType - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True)]) + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True)]) >>> struct1 == struct2 True - >>> struct1 = StructType([StructField("f1", StringType, True)]) - >>> struct2 = StructType([StructField("f1", StringType, True), - ... [StructField("f2", IntegerType, False)]]) + >>> struct1 = StructType([StructField("f1", StringType(), True)]) + >>> struct2 = StructType([StructField("f1", StringType(), True), + ... StructField("f2", IntegerType(), False)]) >>> struct1 == struct2 False """ @@ -499,6 +495,10 @@ _all_complex_types = dict((v.typeName(), v) def _parse_datatype_json_string(json_string): """Parses the given data type JSON string. + + >>> import pickle + >>> LongType() == pickle.loads(pickle.dumps(LongType())) + True >>> def check_datatype(datatype): ... scala_datatype = sqlCtx._ssql_ctx.parseDataType(datatype.json()) ... python_datatype = _parse_datatype_json_string(scala_datatype.json()) @@ -781,8 +781,25 @@ def _merge_type(a, b): return a +def _need_converter(dataType): + if isinstance(dataType, StructType): + return True + elif isinstance(dataType, ArrayType): + return _need_converter(dataType.elementType) + elif isinstance(dataType, MapType): + return _need_converter(dataType.keyType) or _need_converter(dataType.valueType) + elif isinstance(dataType, NullType): + return True + else: + return False + + def _create_converter(dataType): """Create an converter to drop the names of fields in obj """ + + if not _need_converter(dataType): + return lambda x: x + if isinstance(dataType, ArrayType): conv = _create_converter(dataType.elementType) return lambda row: map(conv, row) @@ -800,6 +817,7 @@ def _create_converter(dataType): # dataType must be StructType names = [f.name for f in dataType.fields] converters = [_create_converter(f.dataType) for f in dataType.fields] + convert_fields = any(_need_converter(f.dataType) for f in dataType.fields) def convert_struct(obj): if obj is None: @@ -822,7 +840,10 @@ def _create_converter(dataType): else: raise ValueError("Unexpected obj: %s" % obj) - return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + if convert_fields: + return tuple([conv(d.get(name)) for name, conv in zip(names, converters)]) + else: + return tuple([d.get(name) for name in names]) return convert_struct @@ -1039,7 +1060,7 @@ def _verify_type(obj, dataType): _verify_type(v, f.dataType) -_cached_cls = {} +_cached_cls = weakref.WeakValueDictionary() def _restore_object(dataType, obj): |