#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
import sys
import decimal
import time
import datetime
import calendar
import json
import re
from array import array
# Python 3 dropped the separate ``long`` and ``unicode`` builtins; alias them so
# the rest of this module can refer to both names on either major version.
# The numeric major version is compared because comparing the ``sys.version``
# string lexicographically would order a hypothetical "10.x" before "3".
if sys.version_info[0] >= 3:
    long = int
    unicode = str
from py4j.protocol import register_input_converter
from py4j.java_gateway import JavaClass
# Public names exported by a star-import of this module.
__all__ = [
    "DataType", "NullType", "StringType", "BinaryType", "BooleanType", "DateType",
    "TimestampType", "DecimalType", "DoubleType", "FloatType", "ByteType", "IntegerType",
    "LongType", "ShortType", "ArrayType", "MapType", "StructField", "StructType"]
class DataType(object):
    """Base class for data types.

    Supplies the default representation, equality, hashing, naming, JSON
    serialization and (identity) value conversion that subclasses override
    where needed.
    """

    def __repr__(self):
        klass = self.__class__
        return klass.__name__

    def __hash__(self):
        # Hash the textual form so that equal types hash alike.
        text = str(self)
        return hash(text)

    def __eq__(self, other):
        # Equal only to instances of the same class (or a subclass of it)
        # that carry exactly the same attributes.
        if not isinstance(other, self.__class__):
            return False
        return self.__dict__ == other.__dict__

    def __ne__(self, other):
        equal = self.__eq__(other)
        return not equal

    @classmethod
    def typeName(cls):
        """Return the class name without its last four characters, lower-cased."""
        stem = cls.__name__[:-4]
        return stem.lower()

    def simpleString(self):
        """Return the short textual name of this type (the type name by default)."""
        return self.typeName()

    def jsonValue(self):
        """Return the JSON-serializable value describing this type."""
        return self.typeName()

    def json(self):
        """Return :meth:`jsonValue` encoded as compact JSON with sorted keys."""
        value = self.jsonValue()
        return json.dumps(value, sort_keys=True, separators=(',', ':'))

    def needConversion(self):
        """
        Tell whether values of this type must be converted between the Python
        object and the internal SQL object.

        Used to skip unnecessary conversion for ArrayType/MapType/StructType.
        """
        return False

    def toInternal(self, obj):
        """
        Convert a Python object into an internal SQL object (identity here).
        """
        return obj

    def fromInternal(self, obj):
        """
        Convert an internal SQL object into a native Python object (identity here).
        """
        return obj
# Pickling defeats this singleton pattern: unpickling produces a fresh
# object rather than the cached one.
class DataTypeSingleton(type):
    """Metaclass for DataType that hands out a single cached instance per class."""

    # Maps each class to the one instance created for it.
    _instances = {}

    def __call__(cls):
        try:
            return cls._instances[cls]
        except KeyError:
            instance = super(DataTypeSingleton, cls).__call__()
            cls._instances[cls] = instance
            return instance
class NullType(DataType):
    """Null type.

    The data type representing None, used for the types that cannot be inferred.
    """

    # One shared instance per class; the ``__metaclass__`` attribute is a
    # Python 2 mechanism.
    __metaclass__ = DataTypeSingleton
class AtomicType(DataType):
    """An internal type used to represent everything that is not
    null, UDTs, arrays, structs, and maps.

    Adds no behaviour of its own; it only groups the atomic types.
    """
class NumericType(AtomicType):
    """Numeric data types.

    Common base of the integral and fractional types.
    """
class IntegralType(NumericType):
    """Integral data types.

    Base class of the byte, short, int and long types.
    """

    # One shared instance per class; the ``__metaclass__`` attribute is a
    # Python 2 mechanism.
    __metaclass__ = DataTypeSingleton
class FractionalType(NumericType):
    """Fractional data types.

    Base class of the decimal, double and float types.
    """
class StringType(AtomicType):
    """String data type.
    """

    # One shared instance per class; the ``__metaclass__`` attribute is a
    # Python 2 mechanism.
    __metaclass__ = DataTypeSingleton
class BinaryType(AtomicType):
    """Binary (byte array) data type.
    """

    # One shared instance per class; the ``__metaclass__`` attribute is a
    # Python 2 mechanism.
    __metaclass__ = DataTypeSingleton
class BooleanType(AtomicType):
    """Boolean data type.
    """

    # One shared instance per class; the ``__metaclass__`` attribute is a
    # Python 2 mechanism.
    __metaclass__ = DataTypeSingleton
class DateType(AtomicType):
    """Date (datetime.date) data type.

    The internal representation is the number of days between the date and
    1970-01-01 (difference of proleptic Gregorian ordinals).
    """

    __metaclass__ = DataTypeSingleton

    # Ordinal of 1970-01-01; internal values are offsets from this day.
    EPOCH_ORDINAL = datetime.datetime(1970, 1, 1).toordinal()

    def needConversion(self):
        """Dates always need conversion to/from their day-count form."""
        return True

    def toInternal(self, d):
        """Convert a date into days since 1970-01-01; None stays None."""
        # ``datetime.date`` objects are always truthy, so for dates only
        # None short-circuits here.
        return d and d.toordinal() - self.EPOCH_ORDINAL

    def fromInternal(self, v):
        """Convert days since 1970-01-01 into a ``datetime.date``; None stays None."""
        # Test against None explicitly: day 0 (1970-01-01 itself) is falsy
        # and must still be converted into a date.
        if v is not None:
            return datetime.date.fromordinal(v + self.EPOCH_ORDINAL)
class TimestampType(AtomicType):
    """Timestamp (datetime.datetime) data type.

    The internal representation is an integer count of microseconds since
    the Unix epoch.
    """

    __metaclass__ = DataTypeSingleton

    def needConversion(self):
        """Timestamps always need conversion to/from microseconds."""
        return True

    def toInternal(self, dt):
        """Convert a datetime into microseconds since the epoch; None stays None.

        Timezone-aware values are interpreted through their UTC time tuple,
        naive values through the local-time ``time.mktime``.
        """
        if dt is not None:
            seconds = (calendar.timegm(dt.utctimetuple()) if dt.tzinfo
                       else time.mktime(dt.timetuple()))
            # Both time tuples carry whole seconds only; scale in integer
            # arithmetic so the microseconds are added exactly.
            return int(seconds) * 1000000 + dt.microsecond

    def fromInternal(self, ts):
        """Convert microseconds since the epoch into a local naive datetime."""
        if ts is not None:
            # Split with integer arithmetic: dividing by 1e6 as a float can
            # round the microsecond field.
            whole = datetime.datetime.fromtimestamp(ts // 1000000)
            return whole.replace(microsecond=ts % 1000000)
class DecimalType(FractionalType):
    """Decimal (decimal.Decimal) data type.

    :param precision: optional number of digits; when omitted the type
        carries no precision information.
    :param scale: optional number of digits after the decimal point.
    """

    def __init__(self, precision=None, scale=None):
        self.hasPrecisionInfo = precision is not None
        self.precision = precision
        self.scale = scale

    def simpleString(self):
        # Without explicit precision the textual form falls back to (10,0).
        if not self.hasPrecisionInfo:
            return "decimal(10,0)"
        return "decimal(%d,%d)" % (self.precision, self.scale)

    def jsonValue(self):
        if not self.hasPrecisionInfo:
            return "decimal"
        return "decimal(%d,%d)" % (self.precision, self.scale)

    def __repr__(self):
        if not self.hasPrecisionInfo:
            return "DecimalType()"
        return "DecimalType(%d,%d)" % (self.precision, self.scale)
class DoubleType(FractionalType):
    """Double data type, representing double precision floats.
    """

    # One shared instance per class; the ``__metaclass__`` attribute is a
    # Python 2 mechanism.
    __metaclass__ = DataTypeSingleton
class FloatType(FractionalType):
    """Float data type, representing single precision floats.
    """

    # One shared instance per class; the ``__metaclass__`` attribute is a
    # Python 2 mechanism.
    __metaclass__ = DataTypeSingleton
class ByteType(IntegralType):
    """Byte data type, i.e. a signed integer in a single byte.
    """

    def simpleString(self):
        """Return the short name ``'tinyint'`` instead of the derived type name."""
        return 'tinyint'
class IntegerType(IntegralType):
    """Int data type, i.e. a signed 32-bit integer.
    """

    def simpleString(self):
        """Return the short name ``'int'`` instead of the derived type name."""
        return 'int'
class LongType(IntegralType):
    """Long data type, i.e. a signed 64-bit integer.

    If the values are beyond the range of [-9223372036854775808, 9223372036854775807],
    please use :class:`DecimalType`.
    """

    def simpleString(self):
        """Return the short name ``'bigint'`` instead of the derived type name."""
        return 'bigint'
class ShortType(IntegralType):
    """Short data type, i.e. a signed 16-bit integer.
    """

    def simpleString(self):
        """Return the short name ``'smallint'`` instead of the derived type name."""
        return 'smallint'
class ArrayType(DataType):
    """Array data type.

    :param elementType: :class:`DataType` of each element in the array.
    :param containsNull: boolean, whether the array can contain null (None) values.
    """

    def __init__(self, elementType, containsNull=True):
        """
        >>> ArrayType(StringType()) == ArrayType(StringType(), True)
        True
        >>> ArrayType(StringType(), False) == ArrayType(StringType())
        False
        """
        assert isinstance(elementType, DataType), "elementType should be DataType"
        self.elementType = elementType
        self.containsNull = containsNull

    def simpleString(self):
        inner = self.elementType.simpleString()
        return 'array<%s>' % inner

    def __repr__(self):
        nullable = str(self.containsNull).lower()
        return "ArrayType(%s,%s)" % (self.elementType, nullable)

    def jsonValue(self):
        return {"type": self.typeName(),
                "elementType": self.elementType.jsonValue(),
                "containsNull": self.containsNull}

    @classmethod
    def fromJson(cls, json):
        """Build an ArrayType from its parsed JSON dictionary."""
        element = _parse_datatype_json_value(json["elementType"])
        return ArrayType(element, json["containsNull"])

    def needConversion(self):
        # An array needs conversion exactly when its elements do.
        return self.elementType.needConversion()

    def toInternal(self, obj):
        """Convert every element to its internal form; falsy input is returned as is."""
        if self.needConversion() and obj:
            convert = self.elementType.toInternal
            return [convert(item) for item in obj]
        return obj

    def fromInternal(self, obj):
        """Convert every element back to Python form; falsy input is returned as is."""
        if self.needConversion() and obj:
            convert = self.elementType.fromInternal
            return [convert(item) for item in obj]
        return obj
class MapType(DataType):
    """Map data type.

    :param keyType: :class:`DataType` of the keys in the map.
    :param valueType: :class:`DataType` of the values in the map.
    :param valueContainsNull: indicates whether values can contain null (None) values.

    Keys in a map data type are not allowed to be null (None).
    """

    def __init__(self, keyType, valueType, valueContainsNull=True):
        """
        >>> (MapType(StringType(), IntegerType())
        ...        == MapType(StringType(), IntegerType(), True))
        True
        >>> (MapType(StringType(), IntegerType(), False)
        ...        == MapType(StringType(), FloatType()))
        False
        """
        assert isinstance(keyType, DataType), "keyType should be DataType"
        assert isinstance(valueType, DataType), "valueType should be DataType"
        self.keyType = keyType
        self.valueType = valueType
        self.valueContainsNull = valueContainsNull

    def simpleString(self):
        key_text = self.keyType.simpleString()
        value_text = self.valueType.simpleString()
        return 'map<%s,%s>' % (key_text, value_text)

    def __repr__(self):
        nullable = str(self.valueContainsNull).lower()
        return "MapType(%s,%s,%s)" % (self.keyType, self.valueType, nullable)

    def jsonValue(self):
        return {"type": self.typeName(),
                "keyType": self.keyType.jsonValue(),
                "valueType": self.valueType.jsonValue(),
                "valueContainsNull": self.valueContainsNull}

    @classmethod
    def fromJson(cls, json):
        """Build a MapType from its parsed JSON dictionary."""
        key_type = _parse_datatype_json_value(json["keyType"])
        value_type = _parse_datatype_json_value(json["valueType"])
        return MapType(key_type, value_type, json["valueContainsNull"])

    def needConversion(self):
        # A map needs conversion as soon as either its keys or its values do.
        return self.keyType.needConversion() or self.valueType.needConversion()

    def toInternal(self, obj):
        """Convert keys and values to internal form; falsy input is returned as is."""
        if not (self.needConversion() and obj):
            return obj
        return dict((self.keyType.toInternal(key), self.valueType.toInternal(value))
                    for key, value in obj.items())

    def fromInternal(self, obj):
        """Convert keys and values back to Python form; falsy input is returned as is."""
        if not (self.needConversion() and obj):
            return obj
        return dict((self.keyType.fromInternal(key), self.valueType.fromInternal(value))
                    for key, value in obj.items())
class StructField(DataType):
    """A field in :class:`StructType`.

    :param name: string, name of the field.
    :param dataType: :class:`DataType` of the field.
    :param nullable: boolean, whether the field can be null (None) or not.
    :param metadata: a dict from string to simple type that can be serialized
        to JSON automatically
    """

    def __init__(self, name, dataType, nullable=True, metadata=None):
        """
        >>> (StructField("f1", StringType(), True)
        ...      == StructField("f1", StringType(), True))
        True
        >>> (StructField("f1", StringType(), True)
        ...      == StructField("f2", StringType(), True))
        False
        """
        assert isinstance(dataType, DataType), "dataType should be DataType"
        # Names that are not native ``str`` are stored UTF-8 encoded.
        self.name = name if isinstance(name, str) else name.encode('utf-8')
        self.dataType = dataType
        self.nullable = nullable
        self.metadata = metadata if metadata else {}

    def simpleString(self):
        type_text = self.dataType.simpleString()
        return '%s:%s' % (self.name, type_text)

    def __repr__(self):
        nullable = str(self.nullable).lower()
        return "StructField(%s,%s,%s)" % (self.name, self.dataType, nullable)

    def jsonValue(self):
        return {"name": self.name,
                "type": self.dataType.jsonValue(),
                "nullable": self.nullable,
                "metadata": self.metadata}

    @classmethod
    def fromJson(cls, json):
        """Build a StructField from its parsed JSON dictionary."""
        field_type = _parse_datatype_json_value(json["type"])
        return StructField(json["name"], field_type, json["nullable"], json["metadata"])

    def needConversion(self):
        # Conversion is delegated entirely to the wrapped data type.
        return self.dataType.needConversion()

    def toInternal(self, obj):
        return self.dataType.toInternal(obj)

    def fromInternal(self, obj):
        return self.dataType.fromInternal(obj)
class StructType(DataType):
    """Struct type, consisting of a list of :class:`StructField`.

    This is the data type representing a :class:`Row`.
    """

    def __init__(self, fields=None):
        """
        :param fields: optional list of :class:`StructField`; an empty or
            missing list creates a struct without fields.

        >>> struct1 = StructType([StructField("f1", StringType(), True)])
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType([StructField("f1", StringType(), True)])
        >>> struct2 = StructType([StructField("f1", StringType(), True),
        ...     StructField("f2", IntegerType(), False)])
        >>> struct1 == struct2
        False
        """
        if not fields:
            self.fields = []
            self.names = []
        else:
            # Validate before touching ``f.name`` so a wrong entry produces
            # the intended message rather than an AttributeError.
            assert all(isinstance(f, StructField) for f in fields),\
                "fields should be a list of StructField"
            self.fields = fields
            self.names = [f.name for f in fields]
        # Lazily computed by toInternal(): whether any field needs conversion.
        self._needSerializeFields = None

    def add(self, field, data_type=None, nullable=True, metadata=None):
        """
        Add a field to this StructType in place to define the schema. The method accepts
        either:

            a) A single parameter which is a StructField object.
            b) Between 2 and 4 parameters as (name, data_type, nullable (optional),
               metadata(optional). The data_type parameter may be either a String or a
               DataType object

        >>> struct1 = StructType().add("f1", StringType(), True).add("f2", StringType(), True, None)
        >>> struct2 = StructType([StructField("f1", StringType(), True),
        ...     StructField("f2", StringType(), True, None)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType().add(StructField("f1", StringType(), True))
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True
        >>> struct1 = StructType().add("f1", "string", True)
        >>> struct2 = StructType([StructField("f1", StringType(), True)])
        >>> struct1 == struct2
        True

        :param field: Either the name of the field or a StructField object
        :param data_type: If present, the DataType of the StructField to create
        :param nullable: Whether the field to add should be nullable (default True)
        :param metadata: Any additional metadata (default None)
        :return: this StructType itself (updated), so calls can be chained
        :raises ValueError: if a field name is given without a data type
        """
        if isinstance(field, StructField):
            new_field = field
        else:
            if isinstance(field, str) and data_type is None:
                raise ValueError("Must specify DataType if passing name of struct_field to create.")

            if isinstance(data_type, str):
                data_type_f = _parse_datatype_json_value(data_type)
            else:
                data_type_f = data_type
            new_field = StructField(field, data_type_f, nullable, metadata)

        self.fields.append(new_field)
        # Record the field's own (normalized) name, as __init__ does.
        self.names.append(new_field.name)
        # The set of fields changed, so the cached conversion flag is stale.
        self._needSerializeFields = None
        return self

    def simpleString(self):
        return 'struct<%s>' % (','.join(f.simpleString() for f in self.fields))

    def __repr__(self):
        return ("StructType(List(%s))" %
                ",".join(str(field) for field in self.fields))

    def jsonValue(self):
        return {"type": self.typeName(),
                "fields": [f.jsonValue() for f in self.fields]}

    @classmethod
    def fromJson(cls, json):
        """Build a StructType from its parsed JSON dictionary."""
        return StructType([StructField.fromJson(f) for f in json["fields"]])

    def needConversion(self):
        # We need convert Row()/namedtuple into tuple()
        return True

    def toInternal(self, obj):
        """Convert a dict, tuple or list into a plain tuple in field order.

        :raises ValueError: if ``obj`` is neither None, a dict, a tuple nor a list
        """
        if obj is None:
            return

        if self._needSerializeFields is None:
            self._needSerializeFields = any(f.needConversion() for f in self.fields)

        if self._needSerializeFields:
            if isinstance(obj, dict):
                # Missing keys are converted as None.
                return tuple(f.toInternal(obj.get(n))
                             for n, f in zip(self.names, self.fields))
            elif isinstance(obj, (tuple, list)):
                return tuple(f.toInternal(v) for f, v in zip(self.fields, obj))
            else:
                raise ValueError("Unexpected tuple %r with StructType" % obj)
        else:
            if isinstance(obj, dict):
                return tuple(obj.get(n) for n in self.names)
            elif isinstance(obj, (list, tuple)):
                return tuple(obj)
            else:
                raise ValueError("Unexpected tuple %r with StructType" % obj)

    def fromInternal(self, obj):
        """Convert an internal tuple into a :class:`Row` named after the fields."""
        if obj is None:
            return
        if isinstance(obj, Row):
            # it's already converted by pickler
            return obj
        values = [f.dataType.fromInternal(v) for f, v in zip(self.fields, obj)]
        return _create_row(self.names, values)
class UserDefinedType(DataType):
    """User-defined type (UDT).

    Subclasses supply ``sqlType()``, ``module()``, ``scalaUDT()``,
    ``serialize()`` and ``deserialize()``.

    .. note:: WARN: Spark Internal Use Only
    """

    @classmethod
    def typeName(cls):
        return cls.__name__.lower()

    @classmethod
    def sqlType(cls):
        """
        Underlying SQL storage type for this UDT.
        """
        raise NotImplementedError("UDT must implement sqlType().")

    @classmethod
    def module(cls):
        """
        The Python module of the UDT.
        """
        raise NotImplementedError("UDT must implement module().")

    @classmethod
    def scalaUDT(cls):
        """
        The class name of the paired Scala UDT.
        """
        raise NotImplementedError("UDT must have a paired Scala UDT.")

    def needConversion(self):
        return True

    @classmethod
    def _cachedSqlType(cls):
        """
        Cache the sqlType() into class, because it's heavy used in `toInternal`.
        """
        if not hasattr(cls, "_cached_sql_type"):
            cls._cached_sql_type = cls.sqlType()
        return cls._cached_sql_type

    def toInternal(self, obj):
        """Serialize a user object and convert it to internal form; None stays None."""
        # None represents SQL null and must not be handed to serialize().
        if obj is not None:
            return self._cachedSqlType().toInternal(self.serialize(obj))

    def fromInternal(self, obj):
        """Convert an internal value and deserialize it; None stays None."""
        # None represents SQL null and must not be handed to deserialize().
        if obj is not None:
            return self.deserialize(self._cachedSqlType().fromInternal(obj))

    def serialize(self, obj):
        """
        Converts the a user-type object into a SQL datum.
        """
        raise NotImplementedError("UDT must implement toInternal().")

    def deserialize(self, datum):
        """
        Converts a SQL datum into a user-type object.
        """
        raise NotImplementedError("UDT must implement fromInternal().")

    def simpleString(self):
        return 'udt'

    def json(self):
        return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)

    def jsonValue(self):
        schema = {
            "type": "udt",
            "class": self.scalaUDT(),
            "pyClass": "%s.%s" % (self.module(), type(self).__name__),
            "sqlType": self.sqlType().jsonValue()
        }
        return schema

    @classmethod
    def fromJson(cls, json):
        """Import the class named by ``json["pyClass"]`` and return an instance of it."""
        pyUDT = json["pyClass"]
        split = pyUDT.rfind(".")
        pyModule = pyUDT[:split]
        pyClass = pyUDT[split+1:]
        m = __import__(pyModule, globals(), locals(), [pyClass])
        UDT = getattr(m, pyClass)
        return UDT()

    def __eq__(self, other):
        # Two UDTs are equal exactly when they are of the same class.
        return type(self) == type(other)

    # Overriding __eq__ alone makes a class unhashable on Python 3; keep the
    # base-class hash, which depends only on the class name and so agrees
    # with the type-based equality above.
    __hash__ = DataType.__hash__
# Atomic type classes, and lookup tables from type name to class that are
# used when parsing JSON type descriptions.
_atomic_types = [
    StringType, BinaryType, BooleanType, DecimalType, FloatType, DoubleType,
    ByteType, ShortType, IntegerType, LongType, DateType, TimestampType,
]
_all_atomic_types = dict((atomic.typeName(), atomic) for atomic in _atomic_types)
_all_complex_types = dict((complex_type.typeName(), complex_type)
                          for complex_type in [ArrayType, MapType, StructType])
def _parse_datatype_json_string(json_string):
    """Decode a JSON string and build the data type it describes.

    >>> import pickle
    >>> def check_datatype(datatype):
    ...     pickled = pickle.loads(pickle.dumps(datatype))
    ...     assert datatype == pickled
    ...     scala_datatype = sqlContext._ssql_ctx.parseDataType(datatype.json())
    ...     python_datatype = _parse_datatype_json_string(scala_datatype.json())
    ...     assert datatype == python_datatype
    >>> for cls in _all_atomic_types.values():
    ...     check_datatype(cls())

    >>> # Simple ArrayType.
    >>> simple_arraytype = ArrayType(StringType(), True)
    >>> check_datatype(simple_arraytype)

    >>> # Simple MapType.
    >>> simple_maptype = MapType(StringType(), LongType())
    >>> check_datatype(simple_maptype)

    >>> # Simple StructType.
    >>> simple_structtype = StructType([
    ...     StructField("a", DecimalType(), False),
    ...     StructField("b", BooleanType(), True),
    ...     StructField("c", LongType(), True),
    ...     StructField("d", BinaryType(), False)])
    >>> check_datatype(simple_structtype)

    >>> # Complex StructType.
    >>> complex_structtype = StructType([
    ...     StructField("simpleArray", simple_arraytype, True),
    ...     StructField("simpleMap", simple_maptype, True),
    ...     StructField("simpleStruct", simple_structtype, True),
    ...     StructField("boolean", BooleanType(), False),
    ...     StructField("withMeta", DoubleType(), False, {"name": "age"})])
    >>> check_datatype(complex_structtype)

    >>> # Complex ArrayType.
    >>> complex_arraytype = ArrayType(complex_structtype, True)
    >>> check_datatype(complex_arraytype)

    >>> # Complex MapType.
    >>> complex_maptype = MapType(complex_structtype,
    ...                           complex_arraytype, False)
    >>> check_datatype(complex_maptype)

    >>> check_datatype(ExamplePointUDT())
    >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
    ...                                   StructField("point", ExamplePointUDT(), False)])
    >>> check_datatype(structtype_with_udt)
    """
    decoded = json.loads(json_string)
    return _parse_datatype_json_value(decoded)
# Matches the textual form of a decimal with explicit precision and scale.
_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")


def _parse_datatype_json_value(json_value):
    """Build a data type from an already decoded JSON value.

    A dictionary describes a complex type or a UDT, selected by its
    ``"type"`` entry; any other value is treated as the name of an atomic
    type or a ``decimal(precision,scale)`` string.

    :raises ValueError: if the value does not describe a known type
    """
    if isinstance(json_value, dict):
        tpe = json_value["type"]
        if tpe in _all_complex_types:
            return _all_complex_types[tpe].fromJson(json_value)
        if tpe == 'udt':
            return UserDefinedType.fromJson(json_value)
        raise ValueError("not supported type: %s" % tpe)

    if json_value in _all_atomic_types:
        return _all_atomic_types[json_value]()
    if json_value == 'decimal':
        return DecimalType()
    fixed = _FIXED_DECIMAL.match(json_value)
    if fixed:
        return DecimalType(int(fixed.group(1)), int(fixed.group(2)))
    raise ValueError("Could not parse datatype: %s" % json_value)
# Lookup table from a Python value's exact type to the Spark SQL DataType
# class inferred for it.
_type_mappings = {
    type(None): NullType,
    bool: BooleanType,
    int: LongType,
    float: DoubleType,
    str: StringType,
    bytearray: BinaryType,
    decimal.Decimal: DecimalType,
    datetime.date: DateType,
    datetime.datetime: TimestampType,
    datetime.time: TimestampType,
}

# Before Python 3, ``unicode`` and ``long`` are types of their own.
if sys.version < "3":
    _type_mappings[unicode] = StringType
    _type_mappings[long] = LongType
def _infer_type(obj):
    """Infer the DataType from obj

    None maps to NullType, objects carrying ``__UDT__`` to that UDT, and
    values whose exact type is in ``_type_mappings`` to the mapped type.
    Dicts and lists/arrays are inferred from their first entry that is not
    None (for dicts: neither key nor value is None); anything else is
    treated as a row-like object.

    :raises TypeError: if no type can be inferred for ``obj``

    >>> p = ExamplePoint(1.0, 2.0)
    >>> _infer_type(p)
    ExamplePointUDT
    """
    if obj is None:
        return NullType()

    if hasattr(obj, '__UDT__'):
        return obj.__UDT__

    dataType = _type_mappings.get(type(obj))
    if dataType is not None:
        return dataType()

    if isinstance(obj, dict):
        for key, value in obj.items():
            if key is not None and value is not None:
                return MapType(_infer_type(key), _infer_type(value), True)
        # Empty dict, or every entry has a None key or value.
        return MapType(NullType(), NullType(), True)
    elif isinstance(obj, (list, array)):
        for v in obj:
            if v is not None:
                # Infer from the element just found, not from obj[0],
                # which may be None.
                return ArrayType(_infer_type(v), True)
        # Empty sequence, or every element is None.
        return ArrayType(NullType(), True)
    else:
        try:
            return _infer_schema(obj)
        except TypeError:
            raise TypeError("not supported type: %s" % type(obj))
def _infer_schema(row):
    """Infer a StructType from a dict, a tuple/list (including Row and
    namedtuple) or any object with a ``__dict__``.

    :raises TypeError: if ``row`` is none of the supported shapes
    """
    if isinstance(row, dict):
        pairs = sorted(row.items())

    elif isinstance(row, (tuple, list)):
        if hasattr(row, "__fields__"):  # Row
            pairs = zip(row.__fields__, tuple(row))
        elif hasattr(row, "_fields"):  # namedtuple
            pairs = zip(row._fields, tuple(row))
        else:
            # Plain sequences get positional names _1, _2, ...
            positional = ['_%d' % i for i in range(1, len(row) + 1)]
            pairs = zip(positional, row)

    elif hasattr(row, "__dict__"):  # object
        pairs = sorted(row.__dict__.items())

    else:
        raise TypeError("Can not infer schema for type: %s" % type(row))

    return StructType([StructField(name, _infer_type(value), True)
                       for name, value in pairs])
def _has_nulltype(dt):
    """ Tell whether a NullType occurs anywhere inside `dt` """
    if isinstance(dt, StructType):
        return any(_has_nulltype(field.dataType) for field in dt.fields)
    if isinstance(dt, ArrayType):
        return _has_nulltype(dt.elementType)
    if isinstance(dt, MapType):
        return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
    return isinstance(dt, NullType)
def _merge_type(a, b):
    """Merge two inferred types into one.

    A NullType on either side yields the other side. Otherwise both must
    be of the same class; structs, arrays and maps are merged recursively
    and any other type is returned from ``a``.

    :raises TypeError: if the two types are of different classes
    """
    if isinstance(a, NullType):
        return b
    if isinstance(b, NullType):
        return a
    if type(a) is not type(b):
        # TODO: type cast (such as int -> long)
        raise TypeError("Can not merge type %s and %s" % (type(a), type(b)))

    # same type
    if isinstance(a, StructType):
        other_types = dict((f.name, f.dataType) for f in b.fields)
        merged = [StructField(f.name,
                              _merge_type(f.dataType, other_types.get(f.name, NullType())))
                  for f in a.fields]
        seen = set([f.name for f in merged])
        # Fields present only in ``b`` are appended after those of ``a``.
        merged.extend(StructField(name, other_types[name])
                      for name in other_types if name not in seen)
        return StructType(merged)

    if isinstance(a, ArrayType):
        return ArrayType(_merge_type(a.elementType, b.elementType), True)

    if isinstance(a, MapType):
        merged_key = _merge_type(a.keyType, b.keyType)
        merged_value = _merge_type(a.valueType, b.valueType)
        return MapType(merged_key, merged_value, True)

    return a
def _need_converter(dataType):
    """Tell whether values of `dataType` need a converter: true for structs
    and NullType, and for arrays/maps that contain such a type."""
    if isinstance(dataType, (StructType, NullType)):
        return True
    if isinstance(dataType, ArrayType):
        return _need_converter(dataType.elementType)
    if isinstance(dataType, MapType):
        return _need_converter(dataType.keyType) or _need_converter(dataType.valueType)
    return False
def _create_converter(dataType):
    """Build a function that drops the field names of objects of `dataType`,
    turning struct values into plain tuples."""
    if not _need_converter(dataType):
        return lambda x: x

    if isinstance(dataType, ArrayType):
        element_conv = _create_converter(dataType.elementType)
        return lambda row: [element_conv(v) for v in row]

    if isinstance(dataType, MapType):
        key_conv = _create_converter(dataType.keyType)
        value_conv = _create_converter(dataType.valueType)
        return lambda row: dict((key_conv(k), value_conv(v)) for k, v in row.items())

    if isinstance(dataType, NullType):
        return lambda x: None

    if not isinstance(dataType, StructType):
        return lambda x: x

    # dataType must be StructType
    struct_fields = dataType.fields
    names = [f.name for f in struct_fields]
    converters = [_create_converter(f.dataType) for f in struct_fields]
    convert_fields = any(_need_converter(f.dataType) for f in struct_fields)

    def convert_struct(obj):
        if obj is None:
            return

        if isinstance(obj, (tuple, list)):
            if not convert_fields:
                return tuple(obj)
            return tuple(conv(v) for v, conv in zip(obj, converters))

        # Anything else is read by field name from a mapping.
        if isinstance(obj, dict):
            mapping = obj
        elif hasattr(obj, "__dict__"):  # object
            mapping = obj.__dict__
        else:
            raise TypeError("Unexpected obj type: %s" % type(obj))

        if not convert_fields:
            return tuple([mapping.get(name) for name in names])
        return tuple([conv(mapping.get(name)) for name, conv in zip(names, converters)])

    return convert_struct
# Opening bracket -> matching closing bracket.
_BRACKETS = {'(': ')', '[': ']', '{': '}'}


def _split_schema_abstract(s):
    """
    Split a schema abstract into its top-level fields.

    Fields are separated by spaces; spaces inside brackets do not split.

    :raises ValueError: on a closing bracket that does not match the last
        opened one, or when brackets are left open

    >>> _split_schema_abstract("a b  c")
    ['a', 'b', 'c']
    >>> _split_schema_abstract("a(a b)")
    ['a(a b)']
    >>> _split_schema_abstract("a b[] c{a b}")
    ['a', 'b[]', 'c{a b}']
    >>> _split_schema_abstract(" ")
    []
    """
    pieces = []
    current = []
    open_stack = []
    closers = _BRACKETS.values()

    for ch in s:
        if ch == ' ' and not open_stack:
            # A top-level space ends the field collected so far, if any.
            if current:
                pieces.append(''.join(current))
                current = []
            continue

        current.append(ch)
        if ch in _BRACKETS:
            open_stack.append(ch)
        elif ch in closers:
            if not open_stack or ch != _BRACKETS[open_stack.pop()]:
                raise ValueError("unexpected " + ch)

    if open_stack:
        raise ValueError("brackets not closed: %s" % open_stack)

    if current:
        pieces.append(''.join(current))
    return pieces
def _parse_field_abstract(s):
    """
    Parse one field of a schema abstract.

    The text before the first opening bracket is the field name and the
    rest describes its type; a field without brackets gets NullType.

    >>> _parse_field_abstract("a")
    StructField(a,NullType,true)
    >>> _parse_field_abstract("b(c d)")
    StructField(b,StructType(...c,NullType,true),StructField(d...
    >>> _parse_field_abstract("a[]")
    StructField(a,ArrayType(NullType,true),true)
    >>> _parse_field_abstract("a{[]}")
    StructField(a,MapType(NullType,ArrayType(NullType,true),true),true)
    """
    opener_positions = [s.index(opener) for opener in _BRACKETS if opener in s]
    if not opener_positions:
        return StructField(s, NullType(), True)

    first = min(opener_positions)
    return StructField(s[:first], _parse_schema_abstract(s[first:]), True)
def _parse_schema_abstract(s):
    """
    Turn a schema abstract into a data type.

    Blank text gives NullType; text starting with ``(``, ``[`` or ``{``
    gives respectively the type of the bracket content, an array of it, or a
    map with NullType keys and that value type; anything else is a struct
    of space-separated fields.

    >>> _parse_schema_abstract("a b  c")
    StructType...a...b...c...
    >>> _parse_schema_abstract("a[b c] b{}")
    StructType...a,ArrayType...b...c...b,MapType...
    >>> _parse_schema_abstract("c{} d{a b}")
    StructType...c,MapType...d,MapType...a...b...
    >>> _parse_schema_abstract("a b(t)").fields[1]
    StructField(b,StructType(List(StructField(t,NullType,true))),true)
    """
    s = s.strip()
    if not s:
        return NullType()

    # The first and last characters are taken to be the bracket pair.
    leading = s[0]
    if leading == '(':
        return _parse_schema_abstract(s[1:-1])
    if leading == '[':
        return ArrayType(_parse_schema_abstract(s[1:-1]), True)
    if leading == '{':
        return MapType(NullType(), _parse_schema_abstract(s[1:-1]))

    return StructType([_parse_field_abstract(part)
                       for part in _split_schema_abstract(s)])
def _infer_schema_type(obj, dataType):
    """
    Replace the NullType placeholders of `dataType` by types inferred from obj.

    :raises TypeError: if `dataType` is not a NullType, array, map or struct

    >>> schema = _parse_schema_abstract("a b c d")
    >>> row = (1, 1.0, "str", datetime.date(2014, 10, 10))
    >>> _infer_schema_type(row, schema)
    StructType...LongType...DoubleType...StringType...DateType...
    >>> row = [[1], {"key": (1, 2.0)}]
    >>> schema = _parse_schema_abstract("a[] b{c d}")
    >>> _infer_schema_type(row, schema)
    StructType...a,ArrayType...b,MapType(StringType,...c,LongType...
    """
    if isinstance(dataType, NullType):
        return _infer_type(obj)

    # Nothing to look at: falsy values give NullType.
    if not obj:
        return NullType()

    if isinstance(dataType, ArrayType):
        # Only the first element is inspected.
        element = _infer_schema_type(obj[0], dataType.elementType)
        return ArrayType(element, True)

    if isinstance(dataType, MapType):
        # Only the first entry is inspected.
        key, value = next(iter(obj.items()))
        return MapType(_infer_schema_type(key, dataType.keyType),
                       _infer_schema_type(value, dataType.valueType))

    if isinstance(dataType, StructType):
        fs = dataType.fields
        assert len(fs) == len(obj), \
            "Obj(%s) have different length with fields(%s)" % (obj, fs)
        inferred = [StructField(field.name, _infer_schema_type(value, field.dataType), True)
                    for value, field in zip(obj, fs)]
        return StructType(inferred)

    raise TypeError("Unexpected dataType: %s" % type(dataType))
# For each data type class, the Python types whose values _verify_type()
# accepts for it.
_acceptable_types = {
    BooleanType: (bool,),
    ByteType: (int, long),
    ShortType: (int, long),
    IntegerType: (int, long),
    LongType: (int, long),
    FloatType: (float,),
    DoubleType: (float,),
    DecimalType: (decimal.Decimal,),
    StringType: (str, unicode),
    BinaryType: (bytearray,),
    DateType: (datetime.date, datetime.datetime),
    TimestampType: (datetime.datetime,),
    ArrayType: (list, tuple, array),
    MapType: (dict,),
    StructType: (tuple, list),
}
def _verify_type(obj, dataType):
    """
    Check obj against dataType and raise when they do not match.

    None is accepted for every type and StringType accepts any object.
    Containers are checked recursively.

    :raises TypeError: if obj is not of a Python type accepted by dataType
    :raises ValueError: if obj is not an instance of the given UDT, or a
        struct value has the wrong number of items

    >>> _verify_type(None, StructType([]))
    >>> _verify_type("", StringType())
    >>> _verify_type(0, LongType())
    >>> _verify_type(list(range(3)), ArrayType(ShortType()))
    >>> _verify_type(set(), ArrayType(StringType())) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    TypeError:...
    >>> _verify_type({}, MapType(StringType(), IntegerType()))
    >>> _verify_type((), StructType([]))
    >>> _verify_type([], StructType([]))
    >>> _verify_type([1], StructType([])) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
    >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
        ...
    ValueError:...
    """
    # all objects are nullable, and StringType can work with any types
    if obj is None or isinstance(dataType, StringType):
        return

    if isinstance(dataType, UserDefinedType):
        is_instance = hasattr(obj, '__UDT__') and obj.__UDT__ == dataType
        if not is_instance:
            raise ValueError("%r is not an instance of type %r" % (obj, dataType))
        # The serialized form must itself fit the UDT's SQL type.
        _verify_type(dataType.toInternal(obj), dataType.sqlType())
        return

    _type = type(dataType)
    assert _type in _acceptable_types, "unknown datatype: %s" % dataType

    if _type is StructType:
        # Structs accept tuples and lists including their subclasses.
        if not isinstance(obj, (tuple, list)):
            raise TypeError("StructType can not accept object in type %s" % type(obj))
    elif type(obj) not in _acceptable_types[_type]:
        # For other types the exact Python type is required; subclasses
        # are rejected.
        raise TypeError("%s can not accept object in type %s" % (dataType, type(obj)))

    if isinstance(dataType, ArrayType):
        for element in obj:
            _verify_type(element, dataType.elementType)

    elif isinstance(dataType, MapType):
        for key, value in obj.items():
            _verify_type(key, dataType.keyType)
            _verify_type(value, dataType.valueType)

    elif isinstance(dataType, StructType):
        expected = len(dataType.fields)
        if len(obj) != expected:
            raise ValueError("Length of object (%d) does not match with "
                             "length of fields (%d)" % (len(obj), expected))
        for value, field in zip(obj, dataType.fields):
            _verify_type(value, field.dataType)
# Used when unpickling a Row that comes from the JVM.
def _create_row_inbound_converter(dataType):
    """Return a function that packs its positional arguments into a tuple
    and passes that tuple to ``dataType.fromInternal``."""
    def convert(*values):
        return dataType.fromInternal(values)
    return convert
def _create_row(fields, values):
    """Build a Row holding `values` whose field names are `fields`."""
    created = Row(*values)
    created.__fields__ = fields
    return created
class Row(tuple):

    """
    A row in L{DataFrame}. The fields in it can be accessed like attributes.

    A row object is created from named arguments; its fields are sorted
    by name.

    >>> row = Row(name="Alice", age=11)
    >>> row
    Row(age=11, name='Alice')
    >>> row.name, row.age
    ('Alice', 11)

    Created from positional arguments, Row acts as a Row-like class that
    can in turn be called to create Row objects, such as

    >>> Person = Row("name", "age")
    >>> Person
    <Row(name, age)>
    >>> Person("Alice", 11)
    Row(name='Alice', age=11)
    """

    def __new__(self, *args, **kwargs):
        if args and kwargs:
            raise ValueError("Can not use both args "
                             "and kwargs to create Row")
        if not args and not kwargs:
            raise ValueError("No args or kwargs")

        if args:
            # create row class or objects
            return tuple.__new__(self, args)

        # create row objects
        ordered_names = sorted(kwargs.keys())
        created = tuple.__new__(self, [kwargs[name] for name in ordered_names])
        created.__fields__ = ordered_names
        return created

    def asDict(self):
        """
        Return the row as a dict from field name to value.

        Raises TypeError for a Row class, which has no field names.
        """
        if hasattr(self, "__fields__"):
            return dict(zip(self.__fields__, self))
        raise TypeError("Cannot convert a Row class into dict")

    # let object acts like class
    def __call__(self, *args):
        """Create a new Row object whose field names are this row's items."""
        return _create_row(self, args)

    def __getattr__(self, item):
        # Dunder lookups (including ``__fields__`` itself on a Row class)
        # are never resolved against the fields.
        if item.startswith("__"):
            raise AttributeError(item)
        try:
            # it will be slow when it has many fields,
            # but this will not be used in normal cases
            position = self.__fields__.index(item)
            return self[position]
        except (IndexError, ValueError):
            raise AttributeError(item)

    def __reduce__(self):
        """Return the tuple that tells pickle how to rebuild this Row."""
        if not hasattr(self, "__fields__"):
            return tuple.__reduce__(self)
        return (_create_row, (self.__fields__, tuple(self)))

    def __repr__(self):
        """Printable representation of Row used in Python REPL."""
        if not hasattr(self, "__fields__"):
            return "<Row(%s)>" % ", ".join(self)
        pairs = zip(self.__fields__, tuple(self))
        return "Row(%s)" % ", ".join("%s=%r" % (name, value) for name, value in pairs)
class DateConverter(object):
    """Input converter that turns ``datetime.date`` values into ``java.sql.Date``."""

    def can_convert(self, obj):
        return isinstance(obj, datetime.date)

    def convert(self, obj, gateway_client):
        java_date = JavaClass("java.sql.Date", gateway_client)
        # The date is passed to Java as a yyyy-mm-dd string.
        day_text = obj.strftime("%Y-%m-%d")
        return java_date.valueOf(day_text)
class DatetimeConverter(object):
    """Input converter that turns ``datetime.datetime`` values into
    ``java.sql.Timestamp``."""

    def can_convert(self, obj):
        return isinstance(obj, datetime.datetime)

    def convert(self, obj, gateway_client):
        java_timestamp = JavaClass("java.sql.Timestamp", gateway_client)
        # Whole seconds (local time via mktime) scaled to milliseconds,
        # plus the millisecond part of the microsecond field.
        whole_seconds = int(time.mktime(obj.timetuple()))
        millis = whole_seconds * 1000 + obj.microsecond // 1000
        return java_timestamp(millis)
# datetime is a subclass of date, so DateConverter.can_convert() also accepts
# datetime objects; we should register DatetimeConverter first
register_input_converter(DatetimeConverter())
register_input_converter(DateConverter())
def _test():
    """Run this module's doctests against a local SparkContext.

    Exits the interpreter with status -1 if any doctest fails.
    """
    import doctest
    from pyspark.context import SparkContext
    # let doctest run in pyspark.sql.types, so DataTypes can be picklable
    import pyspark.sql.types
    from pyspark.sql import Row, SQLContext
    from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
    globs = pyspark.sql.types.__dict__.copy()
    sc = SparkContext('local[4]', 'PythonTest')
    try:
        globs['sc'] = sc
        globs['sqlContext'] = SQLContext(sc)
        globs['ExamplePoint'] = ExamplePoint
        globs['ExamplePointUDT'] = ExamplePointUDT
        (failure_count, test_count) = doctest.testmod(
            pyspark.sql.types, globs=globs, optionflags=doctest.ELLIPSIS)
    finally:
        # Always release the context, even when setup or the run raises.
        sc.stop()
    if failure_count:
        # sys.exit is used because the bare ``exit`` name is installed by the
        # ``site`` module and is not guaranteed to exist.
        sys.exit(-1)
# Run the doctests when this file is executed as a script.
if __name__ == "__main__":
    _test()