Diffstat (limited to 'python/pyspark')
 python/pyspark/sql/types.py | 39 ++++++++++++++++++++-------------------
 1 file changed, 20 insertions(+), 19 deletions(-)
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 7e0124b136..ef76d84c00 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -567,8 +567,8 @@ def _infer_schema(row):
     elif isinstance(row, (tuple, list)):
         if hasattr(row, "_fields"):  # namedtuple
             items = zip(row._fields, tuple(row))
-        elif hasattr(row, "__FIELDS__"):  # Row
-            items = zip(row.__FIELDS__, tuple(row))
+        elif hasattr(row, "__fields__"):  # Row
+            items = zip(row.__fields__, tuple(row))
         else:
             names = ['_%d' % i for i in range(1, len(row) + 1)]
             items = zip(names, row)
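This branch is how _infer_schema pairs field names with values for the three tuple-like shapes it accepts: namedtuples carry _fields, pyspark's Row now carries __fields__, and plain tuples get synthetic names. A minimal standalone sketch of the same dispatch (using a namedtuple, since only pyspark's own Row has __fields__):

    from collections import namedtuple

    Point = namedtuple("Point", ["x", "y"])
    row = Point(1, 2)

    if hasattr(row, "_fields"):            # namedtuple carries _fields
        items = list(zip(row._fields, tuple(row)))
    elif hasattr(row, "__fields__"):       # pyspark Row carries __fields__
        items = list(zip(row.__fields__, tuple(row)))
    else:                                  # plain tuple/list: make up names
        names = ['_%d' % i for i in range(1, len(row) + 1)]
        items = list(zip(names, row))

    print(items)   # [('x', 1), ('y', 2)]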
@@ -647,7 +647,7 @@ def _python_to_sql_converter(dataType):
             if isinstance(obj, dict):
                 return tuple(c(obj.get(n)) for n, c in zip(names, converters))
             elif isinstance(obj, tuple):
-                if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
+                if hasattr(obj, "_fields") or hasattr(obj, "__fields__"):
                     return tuple(c(v) for c, v in zip(converters, obj))
                 elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):  # k-v pairs
                     d = dict(obj)
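Whenever the tuple advertises field names, the converter applies one per-field converter positionally. A toy version of that branch, with str and int standing in for the generated converters:

    converters = [str, int]          # stand-ins for the per-field converters
    obj = ("Alice", "1")             # imagine a namedtuple or Row here
    print(tuple(c(v) for c, v in zip(converters, obj)))   # ('Alice', 1)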
@@ -997,12 +997,13 @@ def _restore_object(dataType, obj):
     # same object in most cases.
     k = id(dataType)
     cls = _cached_cls.get(k)
-    if cls is None:
+    if cls is None or cls.__datatype is not dataType:
         # use dataType as key to avoid create multiple class
         cls = _cached_cls.get(dataType)
         if cls is None:
             cls = _create_cls(dataType)
             _cached_cls[dataType] = cls
+        cls.__datatype = dataType
         _cached_cls[k] = cls
     return cls(obj)
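The new guard matters because id() values are only unique among live objects: once a dataType is garbage-collected, CPython can hand its id to a new object, and an id-keyed cache would then return a class built for the dead schema. A minimal sketch of the hazard and the identity check (the names _cache, class_for, and schema are hypothetical, not Spark's):

    _cache = {}

    def class_for(schema):
        # schema is any object standing in for a DataType
        k = id(schema)
        cls = _cache.get(k)
        if cls is None or cls.schema is not schema:   # detect id reuse
            cls = type("Row_%x" % k, (tuple,), {"schema": schema})
            _cache[k] = cls
        return cls

    s = object()
    assert class_for(s).schema is s   # identity check keeps the cache honest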
@@ -1119,8 +1120,8 @@ def _create_cls(dataType):
     class Row(tuple):

         """ Row in DataFrame """
-        __DATATYPE__ = dataType
-        __FIELDS__ = tuple(f.name for f in dataType.fields)
+        __datatype = dataType
+        __fields__ = tuple(f.name for f in dataType.fields)
         __slots__ = ()

         # create property for fast access
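The renames lean on Python's name-mangling rules: __datatype (leading double underscore, no trailing one) is rewritten to _Row__datatype inside the class body, making it effectively private, while dunder names like __fields__ are never mangled. A quick standalone demonstration:

    class Row(tuple):
        __datatype = "schema"      # mangled: stored as _Row__datatype
        __fields__ = ("x", "y")    # dunder name: not mangled

    print(Row.__fields__)               # ('x', 'y')
    print(Row._Row__datatype)           # schema
    print(hasattr(Row, "__datatype"))   # False: only the mangled name exists
                                        # (_restore_object's module-level
                                        # cls.__datatype = ... sets the
                                        # unmangled name, with no mangling)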
@@ -1128,22 +1129,22 @@ def _create_cls(dataType):
         def asDict(self):
             """ Return as a dict """
-            return dict((n, getattr(self, n)) for n in self.__FIELDS__)
+            return dict((n, getattr(self, n)) for n in self.__fields__)

         def __repr__(self):
             # call collect __repr__ for nested objects
             return ("Row(%s)" % ", ".join("%s=%r" % (n, getattr(self, n))
-                                          for n in self.__FIELDS__))
+                                          for n in self.__fields__))

         def __reduce__(self):
-            return (_restore_object, (self.__DATATYPE__, tuple(self)))
+            return (_restore_object, (self.__datatype, tuple(self)))

     return Row


 def _create_row(fields, values):
     row = Row(*values)
-    row.__FIELDS__ = fields
+    row.__fields__ = fields
     return row
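A __reduce__ that returns a (callable, args) pair is what lets these generated classes survive pickling: the class itself is never serialized, only the dataType and the values, and _restore_object rebuilds or re-caches the class on the receiving side. A self-contained sketch of the same protocol (MiniRow and _make are illustrative names, not Spark's):

    import pickle

    def _make(fields, values):       # module-level, so pickle can find it
        r = MiniRow(values)
        r.fields = fields
        return r

    class MiniRow(tuple):
        def __reduce__(self):
            return (_make, (self.fields, tuple(self)))

    r = MiniRow((1, 2))
    r.fields = ("x", "y")
    r2 = pickle.loads(pickle.dumps(r))
    print(tuple(r2), r2.fields)      # (1, 2) ('x', 'y')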
@@ -1183,7 +1184,7 @@ class Row(tuple):
             # create row objects
             names = sorted(kwargs.keys())
             row = tuple.__new__(self, [kwargs[n] for n in names])
-            row.__FIELDS__ = names
+            row.__fields__ = names
             return row

         else:
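Because this branch sorts the keys, the positional layout of a keyword-built Row is deterministic (alphabetical) rather than following call order. The same logic in isolation:

    kwargs = {"name": "Alice", "age": 1}
    names = sorted(kwargs.keys())            # ['age', 'name']
    values = [kwargs[n] for n in names]      # [1, 'Alice']
    print(list(zip(names, values)))          # [('age', 1), ('name', 'Alice')]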
@@ -1193,11 +1194,11 @@ class Row(tuple):
         """
         Return as an dict
         """
-        if not hasattr(self, "__FIELDS__"):
+        if not hasattr(self, "__fields__"):
             raise TypeError("Cannot convert a Row class into dict")
-        return dict(zip(self.__FIELDS__, self))
+        return dict(zip(self.__fields__, self))

-    # let obect acs like class
+    # let object acts like class
     def __call__(self, *args):
         """create new Row object"""
         return _create_row(self, args)
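Taken together, __new__ and __call__ let the module-level Row act both as a value and as a lightweight class factory, which is the pyspark-facing API. Expected usage against this file (output follows the __repr__ in the next hunk):

    from pyspark.sql.types import Row

    row = Row(age=1, name="Alice")   # kwargs form: fields sorted by name
    print(row)                       # Row(age=1, name='Alice')

    Person = Row("name", "age")      # positional form acts like a class
    alice = Person("Alice", 1)       # __call__ -> _create_row
    print(alice.name, alice.age)     # Alice 1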
@@ -1208,21 +1209,21 @@ class Row(tuple):
         try:
             # it will be slow when it has many fields,
             # but this will not be used in normal cases
-            idx = self.__FIELDS__.index(item)
+            idx = self.__fields__.index(item)
             return self[idx]
         except IndexError:
             raise AttributeError(item)

     def __reduce__(self):
-        if hasattr(self, "__FIELDS__"):
-            return (_create_row, (self.__FIELDS__, tuple(self)))
+        if hasattr(self, "__fields__"):
+            return (_create_row, (self.__fields__, tuple(self)))
         else:
             return tuple.__reduce__(self)

     def __repr__(self):
-        if hasattr(self, "__FIELDS__"):
+        if hasattr(self, "__fields__"):
             return "Row(%s)" % ", ".join("%s=%r" % (k, v)
-                                         for k, v in zip(self.__FIELDS__, self))
+                                         for k, v in zip(self.__fields__, tuple(self)))
         else:
             return "<Row(%s)>" % ", ".join(self)