aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <simonh@tw.ibm.com>2016-08-02 10:08:18 -0700
committerDavies Liu <davies.liu@gmail.com>2016-08-02 10:08:18 -0700
commit146001a9ffefc7aaedd3d888d68c7a9b80bca545 (patch)
tree194b8ce9975f93ea574541c7b643edd05fe70521 /python/pyspark/sql
parent1dab63d8d3c59a3d6b4ee8e777810c44849e58b8 (diff)
downloadspark-146001a9ffefc7aaedd3d888d68c7a9b80bca545.tar.gz
spark-146001a9ffefc7aaedd3d888d68c7a9b80bca545.tar.bz2
spark-146001a9ffefc7aaedd3d888d68c7a9b80bca545.zip
[SPARK-16062] [SPARK-15989] [SQL] Fix two bugs of Python-only UDTs
## What changes were proposed in this pull request? There are two related bugs of Python-only UDTs. Because the test case of second one needs the first fix too. I put them into one PR. If it is not appropriate, please let me know. ### First bug: When MapObjects works on Python-only UDTs `RowEncoder` will use `PythonUserDefinedType.sqlType` for its deserializer expression. If the sql type is `ArrayType`, we will have `MapObjects` working on it. But `MapObjects` doesn't consider `PythonUserDefinedType` as its input data type. It causes error like: import pyspark.sql.group from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT from pyspark.sql.types import * schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT()) df = spark.createDataFrame([(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)], schema=schema) df.show() File "/home/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 312, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o36.showString. : java.lang.RuntimeException: Error while decoding: scala.MatchError: org.apache.spark.sql.types.PythonUserDefinedTypef4ceede8 (of class org.apache.spark.sql.types.PythonUserDefinedType) ... ### Second bug: When Python-only UDTs is the element type of ArrayType import pyspark.sql.group from pyspark.sql.tests import PythonOnlyPoint, PythonOnlyUDT from pyspark.sql.types import * schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT())) df = spark.createDataFrame([(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)], schema=schema) df.show() ## How was this patch tested? PySpark's sql tests. Author: Liang-Chi Hsieh <simonh@tw.ibm.com> Closes #13778 from viirya/fix-pyudt.
Diffstat (limited to 'python/pyspark/sql')
-rw-r--r--python/pyspark/sql/tests.py35
1 files changed, 35 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index a8ca386e1c..87dbb50495 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -575,6 +575,41 @@ class SQLTests(ReusedPySparkTestCase):
_verify_type(PythonOnlyPoint(1.0, 2.0), PythonOnlyUDT())
self.assertRaises(ValueError, lambda: _verify_type([1.0, 2.0], PythonOnlyUDT()))
+ def test_simple_udt_in_df(self):
+ schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
+ df = self.spark.createDataFrame(
+ [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
+ schema=schema)
+ df.show()
+
+ def test_nested_udt_in_df(self):
+ schema = StructType().add("key", LongType()).add("val", ArrayType(PythonOnlyUDT()))
+ df = self.spark.createDataFrame(
+ [(i % 3, [PythonOnlyPoint(float(i), float(i))]) for i in range(10)],
+ schema=schema)
+ df.collect()
+
+ schema = StructType().add("key", LongType()).add("val",
+ MapType(LongType(), PythonOnlyUDT()))
+ df = self.spark.createDataFrame(
+ [(i % 3, {i % 3: PythonOnlyPoint(float(i + 1), float(i + 1))}) for i in range(10)],
+ schema=schema)
+ df.collect()
+
+ def test_complex_nested_udt_in_df(self):
+ from pyspark.sql.functions import udf
+
+ schema = StructType().add("key", LongType()).add("val", PythonOnlyUDT())
+ df = self.spark.createDataFrame(
+ [(i % 3, PythonOnlyPoint(float(i), float(i))) for i in range(10)],
+ schema=schema)
+ df.collect()
+
+ gd = df.groupby("key").agg({"val": "collect_list"})
+ gd.collect()
+ udf = udf(lambda k, v: [(k, v[0])], ArrayType(df.schema))
+ gd.select(udf(*gd)).collect()
+
def test_udt_with_none(self):
df = self.spark.range(0, 10, 1, 1)