aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2015-07-20 12:14:47 -0700
committerDavies Liu <davies.liu@gmail.com>2015-07-20 12:14:47 -0700
commit9f913c4fd6f0f223fd378e453d5b9a87beda1ac4 (patch)
treeb0e89b1b7f4c0617c7bab6b314b161fc5240c95b /python/pyspark/sql/tests.py
parent02181fb6d14833448fb5c501045655213d3cf340 (diff)
downloadspark-9f913c4fd6f0f223fd378e453d5b9a87beda1ac4.tar.gz
spark-9f913c4fd6f0f223fd378e453d5b9a87beda1ac4.tar.bz2
spark-9f913c4fd6f0f223fd378e453d5b9a87beda1ac4.zip
[SPARK-9114] [SQL] [PySpark] convert returned object from UDF into internal type
This PR also remove the duplicated code between registerFunction and UserDefinedFunction. cc JoshRosen Author: Davies Liu <davies@databricks.com> Closes #7450 from davies/fix_return_type and squashes the following commits: e80bf9f [Davies Liu] remove debugging code f94b1f6 [Davies Liu] fix mima 8f9c58b [Davies Liu] convert returned object from UDF into internal type
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--python/pyspark/sql/tests.py4
1 files changed, 3 insertions, 1 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 7a55d801e4..ea821f486f 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -417,12 +417,14 @@ class SQLTests(ReusedPySparkTestCase):
self.assertEquals(point, ExamplePoint(1.0, 2.0))
def test_udf_with_udt(self):
- from pyspark.sql.tests import ExamplePoint
+ from pyspark.sql.tests import ExamplePoint, ExamplePointUDT
row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
df = self.sc.parallelize([row]).toDF()
self.assertEqual(1.0, df.map(lambda r: r.point.x).first())
udf = UserDefinedFunction(lambda p: p.y, DoubleType())
self.assertEqual(2.0, df.select(udf(df.point)).first()[0])
+ udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT())
+ self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0])
def test_parquet_with_udt(self):
from pyspark.sql.tests import ExamplePoint