diff options
Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r-- | python/pyspark/sql.py | 35 |
1 file changed, 32 insertions, 3 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 93bfc25bca..98e41f8575 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -35,6 +35,7 @@ import datetime
 import keyword
 import warnings
 import json
+import re
 from array import array
 from operator import itemgetter
 from itertools import imap
@@ -148,13 +149,30 @@ class TimestampType(PrimitiveType):
     """
 
 
-class DecimalType(PrimitiveType):
+class DecimalType(DataType):
 
     """Spark SQL DecimalType
 
     The data type representing decimal.Decimal values.
     """
 
+    def __init__(self, precision=None, scale=None):
+        self.precision = precision
+        self.scale = scale
+        self.hasPrecisionInfo = precision is not None
+
+    def jsonValue(self):
+        if self.hasPrecisionInfo:
+            return "decimal(%d,%d)" % (self.precision, self.scale)
+        else:
+            return "decimal"
+
+    def __repr__(self):
+        if self.hasPrecisionInfo:
+            return "DecimalType(%d,%d)" % (self.precision, self.scale)
+        else:
+            return "DecimalType()"
+
 
 class DoubleType(PrimitiveType):
 
@@ -446,9 +464,20 @@ def _parse_datatype_json_string(json_string):
     return _parse_datatype_json_value(json.loads(json_string))
 
 
+_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
+
+
 def _parse_datatype_json_value(json_value):
-    if type(json_value) is unicode and json_value in _all_primitive_types.keys():
-        return _all_primitive_types[json_value]()
+    if type(json_value) is unicode:
+        if json_value in _all_primitive_types.keys():
+            return _all_primitive_types[json_value]()
+        elif json_value == u'decimal':
+            return DecimalType()
+        elif _FIXED_DECIMAL.match(json_value):
+            m = _FIXED_DECIMAL.match(json_value)
+            return DecimalType(int(m.group(1)), int(m.group(2)))
+        else:
+            raise ValueError("Could not parse datatype: %s" % json_value)
     else:
         return _all_complex_types[json_value["type"]].fromJson(json_value)
 