about summary refs log tree commit diff
path: root/python/pyspark/sql.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql.py')
-rw-r--r-- python/pyspark/sql.py 35
1 files changed, 32 insertions, 3 deletions
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 93bfc25bca..98e41f8575 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -35,6 +35,7 @@ import datetime
import keyword
import warnings
import json
+import re
from array import array
from operator import itemgetter
from itertools import imap
@@ -148,13 +149,30 @@ class TimestampType(PrimitiveType):
"""
-class DecimalType(PrimitiveType):
+class DecimalType(DataType):
"""Spark SQL DecimalType
The data type representing decimal.Decimal values.
"""
+ def __init__(self, precision=None, scale=None):
+ self.precision = precision
+ self.scale = scale
+ self.hasPrecisionInfo = precision is not None
+
+ def jsonValue(self):
+ if self.hasPrecisionInfo:
+ return "decimal(%d,%d)" % (self.precision, self.scale)
+ else:
+ return "decimal"
+
+ def __repr__(self):
+ if self.hasPrecisionInfo:
+ return "DecimalType(%d,%d)" % (self.precision, self.scale)
+ else:
+ return "DecimalType()"
+
class DoubleType(PrimitiveType):
@@ -446,9 +464,20 @@ def _parse_datatype_json_string(json_string):
return _parse_datatype_json_value(json.loads(json_string))
+_FIXED_DECIMAL = re.compile("decimal\\((\\d+),(\\d+)\\)")
+
+
def _parse_datatype_json_value(json_value):
- if type(json_value) is unicode and json_value in _all_primitive_types.keys():
- return _all_primitive_types[json_value]()
+ if type(json_value) is unicode:
+ if json_value in _all_primitive_types.keys():
+ return _all_primitive_types[json_value]()
+ elif json_value == u'decimal':
+ return DecimalType()
+ elif _FIXED_DECIMAL.match(json_value):
+ m = _FIXED_DECIMAL.match(json_value)
+ return DecimalType(int(m.group(1)), int(m.group(2)))
+ else:
+ raise ValueError("Could not parse datatype: %s" % json_value)
else:
return _all_complex_types[json_value["type"]].fromJson(json_value)