diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-03-03 20:16:37 -0800 |
---|---|---|
committer | Davies Liu <davies.liu@gmail.com> | 2016-03-03 20:16:37 -0800 |
commit | 15d57f9c23145ace37d1631d8f9c19675c142214 (patch) | |
tree | c5c152bde4f55245a0ea0e1ff1cb8b82f9acac08 | |
parent | d062587dd2c4ed13998ee8bcc9d08f29734df228 (diff) | |
download | spark-15d57f9c23145ace37d1631d8f9c19675c142214.tar.gz spark-15d57f9c23145ace37d1631d8f9c19675c142214.tar.bz2 spark-15d57f9c23145ace37d1631d8f9c19675c142214.zip |
[SPARK-13647] [SQL] also check if numeric value is within allowed range in _verify_type
## What changes were proposed in this pull request?
This PR makes `_verify_type` in `types.py` stricter: it now also checks whether a numeric value is within the allowed range.
## How was this patch tested?
Tested with a newly added doctest.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #11492 from cloud-fan/py-verify.
-rw-r--r-- | python/pyspark/sql/types.py | 27 |
1 files changed, 24 insertions, 3 deletions
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index 5bc0773fa8..d1f5b47242 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1093,8 +1093,11 @@ _acceptable_types = { def _verify_type(obj, dataType): """ - Verify the type of obj against dataType, raise an exception if - they do not match. + Verify the type of obj against dataType, raise a TypeError if they do not match. + + Also verify the value of obj against datatype, raise a ValueError if it's not within the allowed + range, e.g. using 128 as ByteType will overflow. Note that, Python float is not checked, so it + will become infinity when cast to Java float if it overflows. >>> _verify_type(None, StructType([])) >>> _verify_type("", StringType()) @@ -1111,6 +1114,12 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... + >>> # Check if numeric values are within the allowed range. + >>> _verify_type(12, ByteType()) + >>> _verify_type(1234, ByteType()) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... """ # all objects are nullable if obj is None: @@ -1137,7 +1146,19 @@ def _verify_type(obj, dataType): if type(obj) not in _acceptable_types[_type]: raise TypeError("%s can not accept object %r in type %s" % (dataType, obj, type(obj))) - if isinstance(dataType, ArrayType): + if isinstance(dataType, ByteType): + if obj < -128 or obj > 127: + raise ValueError("object of ByteType out of range, got: %s" % obj) + + elif isinstance(dataType, ShortType): + if obj < -32768 or obj > 32767: + raise ValueError("object of ShortType out of range, got: %s" % obj) + + elif isinstance(dataType, IntegerType): + if obj < -2147483648 or obj > 2147483647: + raise ValueError("object of IntegerType out of range, got: %s" % obj) + + elif isinstance(dataType, ArrayType): for i in obj: _verify_type(i, dataType.elementType) |