aboutsummaryrefslogtreecommitdiff
path: root/python
diff options
context:
space:
mode:
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/types.py36
1 files changed, 21 insertions, 15 deletions
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index 10ad89ea14..b97d50c945 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -194,30 +194,33 @@ class TimestampType(AtomicType):
class DecimalType(FractionalType):
"""Decimal (decimal.Decimal) data type.
+
+ The DecimalType must have fixed precision (the maximum total number of digits)
+ and scale (the number of digits on the right of dot). For example, (5, 2) can
+ support the value from [-999.99 to 999.99].
+
+ The precision can be up to 38, the scale must less or equal to precision.
+
+ When create a DecimalType, the default precision and scale is (10, 0). When infer
+ schema from decimal.Decimal objects, it will be DecimalType(38, 18).
+
+ :param precision: the maximum total number of digits (default: 10)
+ :param scale: the number of digits on right side of dot. (default: 0)
"""
- def __init__(self, precision=None, scale=None):
+ def __init__(self, precision=10, scale=0):
self.precision = precision
self.scale = scale
- self.hasPrecisionInfo = precision is not None
+ self.hasPrecisionInfo = True # this is public API
def simpleString(self):
- if self.hasPrecisionInfo:
- return "decimal(%d,%d)" % (self.precision, self.scale)
- else:
- return "decimal(10,0)"
+ return "decimal(%d,%d)" % (self.precision, self.scale)
def jsonValue(self):
- if self.hasPrecisionInfo:
- return "decimal(%d,%d)" % (self.precision, self.scale)
- else:
- return "decimal"
+ return "decimal(%d,%d)" % (self.precision, self.scale)
def __repr__(self):
- if self.hasPrecisionInfo:
- return "DecimalType(%d,%d)" % (self.precision, self.scale)
- else:
- return "DecimalType()"
+ return "DecimalType(%d,%d)" % (self.precision, self.scale)
class DoubleType(FractionalType):
@@ -761,7 +764,10 @@ def _infer_type(obj):
return obj.__UDT__
dataType = _type_mappings.get(type(obj))
- if dataType is not None:
+ if dataType is DecimalType:
+ # the precision and scale of `obj` may be different from row to row.
+ return DecimalType(38, 18)
+ elif dataType is not None:
return dataType()
if isinstance(obj, dict):