From d5ce61722f2aa4167a8344e1664b000c70d5a3f8 Mon Sep 17 00:00:00 2001 From: Wenchen Fan Date: Tue, 8 Mar 2016 13:46:17 -0800 Subject: [SPARK-13740][SQL] add null check for _verify_type in types.py ## What changes were proposed in this pull request? This PR adds a null check to `_verify_type` that follows the nullability information of the data type: struct fields use `nullable`, array elements use `containsNull`, map values use `valueContainsNull`, and map keys are never nullable. ## How was this patch tested? New doctests. Author: Wenchen Fan Closes #11574 from cloud-fan/py-null-check. --- python/pyspark/sql/types.py | 33 ++++++++++++++++++++++++++------- 1 file changed, 26 insertions(+), 7 deletions(-) (limited to 'python/pyspark/sql') diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index d1f5b47242..c71adfb58f 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -1091,7 +1091,7 @@ _acceptable_types = { } -def _verify_type(obj, dataType): +def _verify_type(obj, dataType, nullable=True): """ Verify the type of obj against dataType, raise a TypeError if they do not match. @@ -1120,10 +1120,29 @@ def _verify_type(obj, dataType): Traceback (most recent call last): ... ValueError:... + >>> _verify_type(None, ByteType(), False) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _verify_type([1, None], ArrayType(ShortType(), False)) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... + >>> _verify_type({None: 1}, MapType(StringType(), IntegerType())) + Traceback (most recent call last): + ... + ValueError:... + >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False) + >>> _verify_type((1, None), schema) # doctest: +IGNORE_EXCEPTION_DETAIL + Traceback (most recent call last): + ... + ValueError:... 
""" - # all objects are nullable if obj is None: - return + if nullable: + return + else: + raise ValueError("This field is not nullable, but got None") # StringType can work with any types if isinstance(dataType, StringType): @@ -1160,19 +1179,19 @@ def _verify_type(obj, dataType): elif isinstance(dataType, ArrayType): for i in obj: - _verify_type(i, dataType.elementType) + _verify_type(i, dataType.elementType, dataType.containsNull) elif isinstance(dataType, MapType): for k, v in obj.items(): - _verify_type(k, dataType.keyType) - _verify_type(v, dataType.valueType) + _verify_type(k, dataType.keyType, False) + _verify_type(v, dataType.valueType, dataType.valueContainsNull) elif isinstance(dataType, StructType): if len(obj) != len(dataType.fields): raise ValueError("Length of object (%d) does not match with " "length of fields (%d)" % (len(obj), len(dataType.fields))) for v, f in zip(obj, dataType.fields): - _verify_type(v, f.dataType) + _verify_type(v, f.dataType, f.nullable) # This is used to unpickle a Row from JVM -- cgit v1.2.3