diff options
Diffstat (limited to 'python/pyspark')
-rw-r--r-- | python/pyspark/sql/tests.py | 11 | ||||
-rw-r--r-- | python/pyspark/sql/types.py | 7 |
2 files changed, 16 insertions, 2 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index f863485e6c..a8ca386e1c 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -575,6 +575,17 @@ class SQLTests(ReusedPySparkTestCase): _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT()) self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT())) + def test_udt_with_none(self): + df = self.spark.range(0, 10, 1, 1) + + def myudf(x): + if x > 0: + return PythonOnlyPoint(float(x), float(x)) + + self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT()) + rows = [r[0] for r in df.selectExpr("udf(id)").take(2)] + self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)]) + def test_infer_schema_with_udt(self): from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py index f0b56be8da..a3679873e1 100644 --- a/python/pyspark/sql/types.py +++ b/python/pyspark/sql/types.py @@ -648,10 +648,13 @@ class UserDefinedType(DataType): return cls._cached_sql_type def toInternal(self, obj): - return self._cachedSqlType().toInternal(self.serialize(obj)) + if obj is not None: + return self._cachedSqlType().toInternal(self.serialize(obj)) def fromInternal(self, obj): - return self.deserialize(self._cachedSqlType().fromInternal(obj)) + v = self._cachedSqlType().fromInternal(obj) + if v is not None: + return self.deserialize(v) def serialize(self, obj): """ |