diff options
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r-- | python/pyspark/sql/tests.py | 4 |
1 files changed, 3 insertions, 1 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 7a55d801e4..ea821f486f 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -417,12 +417,14 @@ class SQLTests(ReusedPySparkTestCase): self.assertEquals(point, ExamplePoint(1.0, 2.0)) def test_udf_with_udt(self): - from pyspark.sql.tests import ExamplePoint + from pyspark.sql.tests import ExamplePoint, ExamplePointUDT row = Row(label=1.0, point=ExamplePoint(1.0, 2.0)) df = self.sc.parallelize([row]).toDF() self.assertEqual(1.0, df.map(lambda r: r.point.x).first()) udf = UserDefinedFunction(lambda p: p.y, DoubleType()) self.assertEqual(2.0, df.select(udf(df.point)).first()[0]) + udf2 = UserDefinedFunction(lambda p: ExamplePoint(p.x + 1, p.y + 1), ExamplePointUDT()) + self.assertEqual(ExamplePoint(2.0, 3.0), df.select(udf2(df.point)).first()[0]) def test_parquet_with_udt(self): from pyspark.sql.tests import ExamplePoint |