author     Davies Liu <davies@databricks.com>    2016-06-28 14:09:38 -0700
committer  Davies Liu <davies.liu@gmail.com>     2016-06-28 14:09:38 -0700
commit  35438fb0ad3bcda5c5a3a0ccde1a620699d012db (patch)
tree    c79785d17d712fa6c69626d779f50b904bb290ed /python/pyspark
parent  1aad8c6e59c1e8b18a3eaa8ded93ff6ad05d83df (diff)
[SPARK-16175] [PYSPARK] handle None for UDT
## What changes were proposed in this pull request?

Scala UDTs bypass nulls: a null value is never passed into a UDT's serialize() or deserialize(). This PR updates the Python UDT to behave the same way, so that None short-circuits toInternal() and fromInternal() instead of reaching the user-defined hooks.

## How was this patch tested?

Added tests.

Author: Davies Liu <davies@databricks.com>

Closes #13878 from davies/udt_null.
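Before this change, a None reaching UserDefinedType.toInternal() was handed straight to serialize(); a typical serialize() dereferences its argument (for example, ExamplePointUDT returns [obj.x, obj.y]), so null rows failed with AttributeError. A minimal sketch of the patched control flow (the method bodies match the types.py diff below):

    # Sketch of the patched UserDefinedType methods (pyspark/sql/types.py).
    # The None guards mirror what the Scala side already does for nulls.

    def toInternal(self, obj):
        # Only serialize non-null values; for None the method falls
        # through and implicitly returns None, bypassing serialize().
        if obj is not None:
            return self._cachedSqlType().toInternal(self.serialize(obj))

    def fromInternal(self, obj):
        # Convert from the internal SQL representation first; only call
        # deserialize() if the result is non-null.
        v = self._cachedSqlType().fromInternal(obj)
        if v is not None:
            return self.deserialize(v)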
Diffstat (limited to 'python/pyspark')
-rw-r--r--  python/pyspark/sql/tests.py  11
-rw-r--r--  python/pyspark/sql/types.py   7
2 files changed, 16 insertions, 2 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index f863485e6c..a8ca386e1c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -575,6 +575,17 @@ class SQLTests(ReusedPySparkTestCase):
         _verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
         self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))

+    def test_udt_with_none(self):
+        df = self.spark.range(0, 10, 1, 1)
+
+        def myudf(x):
+            if x > 0:
+                return PythonOnlyPoint(float(x), float(x))
+
+        self.spark.catalog.registerFunction("udf", myudf, PythonOnlyUDT())
+        rows = [r[0] for r in df.selectExpr("udf(id)").take(2)]
+        self.assertEqual(rows, [None, PythonOnlyPoint(1, 1)])
+
     def test_infer_schema_with_udt(self):
         from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
         row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
diff --git a/python/pyspark/sql/types.py b/python/pyspark/sql/types.py
index f0b56be8da..a3679873e1 100644
--- a/python/pyspark/sql/types.py
+++ b/python/pyspark/sql/types.py
@@ -648,10 +648,13 @@ class UserDefinedType(DataType):
         return cls._cached_sql_type

     def toInternal(self, obj):
-        return self._cachedSqlType().toInternal(self.serialize(obj))
+        if obj is not None:
+            return self._cachedSqlType().toInternal(self.serialize(obj))

     def fromInternal(self, obj):
-        return self.deserialize(self._cachedSqlType().fromInternal(obj))
+        v = self._cachedSqlType().fromInternal(obj)
+        if v is not None:
+            return self.deserialize(v)

     def serialize(self, obj):
         """