about | summary | refs | log | tree | commit | diff
path: root/python/pyspark/sql
diff options
context:
space:
mode:
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r-- python/pyspark/sql/types.py 33
1 files changed, 26 insertions, 7 deletions
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index d1f5b47242..c71adfb58f 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -1091,7 +1091,7 @@ _acceptable_types = {
}
-def _verify_type(obj, dataType):
+def _verify_type(obj, dataType, nullable=True):
"""
Verify the type of obj against dataType, raise a TypeError if they do not match.
@@ -1120,10 +1120,29 @@ def _verify_type(obj, dataType):
Traceback (most recent call last):
...
ValueError:...
+ >>> _verify_type(None, ByteType(), False) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _verify_type([1, None], ArrayType(ShortType(), False)) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> _verify_type({None: 1}, MapType(StringType(), IntegerType()))
+ Traceback (most recent call last):
+ ...
+ ValueError:...
+ >>> schema = StructType().add("a", IntegerType()).add("b", StringType(), False)
+ >>> _verify_type((1, None), schema) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
"""
- # all objects are nullable
if obj is None:
- return
+ if nullable:
+ return
+ else:
+ raise ValueError("This field is not nullable, but got None")
# StringType can work with any types
if isinstance(dataType, StringType):
@@ -1160,19 +1179,19 @@ def _verify_type(obj, dataType):
elif isinstance(dataType, ArrayType):
for i in obj:
- _verify_type(i, dataType.elementType)
+ _verify_type(i, dataType.elementType, dataType.containsNull)
elif isinstance(dataType, MapType):
for k, v in obj.items():
- _verify_type(k, dataType.keyType)
- _verify_type(v, dataType.valueType)
+ _verify_type(k, dataType.keyType, False)
+ _verify_type(v, dataType.valueType, dataType.valueContainsNull)
elif isinstance(dataType, StructType):
if len(obj) != len(dataType.fields):
raise ValueError("Length of object (%d) does not match with "
"length of fields (%d)" % (len(obj), len(dataType.fields)))
for v, f in zip(obj, dataType.fields):
- _verify_type(v, f.dataType)
+ _verify_type(v, f.dataType, f.nullable)
# This is used to unpickle a Row from JVM