aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--python/pyspark/sql/tests.py11
-rw-r--r--python/pyspark/sql/types.py7
2 files changed, 16 insertions, 2 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index f863485e6c..a8ca386e1c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -575,6 +575,17 @@ class SQLTests(ReusedPySparkTestCase):
_verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
+ def test_udt_with_none(self):
+ df = self.spark.range(0, 10, 1, 1)
+
+ def myudf(x):
+ if x > 0:
+ return PythonOnlyPoint(float(x), float(x))
+
+ self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
+ rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
+ self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
+
def test_infer_schema_with_udt(self):
from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index f0b56be8da..a3679873e1 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -648,10 +648,13 @@ class UserDefinedType(DataType):
return cls._cached_sql_type
def toInternal(self, obj):
- return self._cachedSqlType().toInternal(self.serialize(obj))
+ if obj is not None:
+ return self._cachedSqlType().toInternal(self.serialize(obj))
def fromInternal(self, obj):
- return self.deserialize(self._cachedSqlType().fromInternal(obj))
+ v = self._cachedSqlType().fromInternal(obj)
+ if v is not None:
+ return self.deserialize(v)
def serialize(self, obj):
"""