aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark
diff options
context:
space:
mode:
authorLiang-Chi Hsieh <viirya@gmail.com>2016-12-05 17:50:43 -0800
committerHerman van Hovell <hvanhovell@databricks.com>2016-12-05 17:50:43 -0800
commit3ba69b64852ccbf6d4ec05a021bc20616a09f574 (patch)
treef30aac795f0f8e1d140525c994bf7db6d2a7ea19 /python/pyspark
parent18eaabb71eeee6e6502aa0633b6d46fdb67d3c3b (diff)
downloadspark-3ba69b64852ccbf6d4ec05a021bc20616a09f574.tar.gz
spark-3ba69b64852ccbf6d4ec05a021bc20616a09f574.tar.bz2
spark-3ba69b64852ccbf6d4ec05a021bc20616a09f574.zip
[SPARK-18634][PYSPARK][SQL] Corruption and Correctness issues with exploding Python UDFs
## What changes were proposed in this pull request? As reported in the Jira, there are some weird issues with exploding Python UDFs in SparkSQL. The following test code can reproduce it. Notice: the following test code is reported to return wrong results in the Jira. However, as I tested on master branch, it causes exception and so can't return any result. >>> from pyspark.sql.functions import * >>> from pyspark.sql.types import * >>> >>> df = spark.range(10) >>> >>> def return_range(value): ... return [(i, str(i)) for i in range(value - 1, value + 1)] ... >>> range_udf = udf(return_range, ArrayType(StructType([StructField("integer_val", IntegerType()), ... StructField("string_val", StringType())]))) >>> >>> df.select("id", explode(range_udf(df.id))).show() Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/spark/python/pyspark/sql/dataframe.py", line 318, in show print(self._jdf.showString(n, 20)) File "/spark/python/lib/py4j-0.10.4-src.zip/py4j/java_gateway.py", line 1133, in __call__ File "/spark/python/pyspark/sql/utils.py", line 63, in deco return f(*a, **kw) File "/spark/python/lib/py4j-0.10.4-src.zip/py4j/protocol.py", line 319, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o126.showString.: java.lang.AssertionError: assertion failed at scala.Predef$.assert(Predef.scala:156) at org.apache.spark.sql.execution.CodegenSupport$class.consume(WholeStageCodegenExec.scala:120) at org.apache.spark.sql.execution.GenerateExec.consume(GenerateExec.scala:57) The cause of this issue is, in `ExtractPythonUDFs` we insert `BatchEvalPythonExec` to run PythonUDFs in batch. `BatchEvalPythonExec` will add extra outputs (e.g., `pythonUDF0`) to original plan. In above case, the original `Range` only has one output `id`. After `ExtractPythonUDFs`, the added `BatchEvalPythonExec` has two outputs `id` and `pythonUDF0`. Because the output of `GenerateExec` is given after analysis phase, in above case, it is the combination of `id`, i.e., the output of `Range`, and `col`. But in planning phase, we change `GenerateExec`'s child plan to `BatchEvalPythonExec` with additional output attributes. It will cause no problem in non wholestage codegen. Because when evaluating the additional attributes are projected out the final output of `GenerateExec`. However, as `GenerateExec` now supports wholestage codegen, the framework will input all the outputs of the child plan to `GenerateExec`. Then when consuming `GenerateExec`'s output data (i.e., calling `consume`), the number of output attributes is different to the output variables in wholestage codegen. To solve this issue, this patch only gives the generator's output to `GenerateExec` after analysis phase. `GenerateExec`'s output is the combination of its child plan's output and the generator's output. So when we change `GenerateExec`'s child, its output is still correct. ## How was this patch tested? Added test cases to PySpark. Please review http://spark.apache.org/contributing.html before opening a pull request. Author: Liang-Chi Hsieh <viirya@gmail.com> Closes #16120 from viirya/fix-py-udf-with-generator.
Diffstat (limited to 'python/pyspark')
-rw-r--r--python/pyspark/sql/tests.py20
1 files changed, 20 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 9f34414f64..66a3490a64 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -384,6 +384,26 @@ class SQLTests(ReusedPySparkTestCase):
row = df.select(explode(f(*df))).groupBy().sum().first()
self.assertEqual(row[0], 10)
+ df = self.spark.range(3)
+ res = df.select("id", explode(f(df.id))).collect()
+ self.assertEqual(res[0][0], 1)
+ self.assertEqual(res[0][1], 0)
+ self.assertEqual(res[1][0], 2)
+ self.assertEqual(res[1][1], 0)
+ self.assertEqual(res[2][0], 2)
+ self.assertEqual(res[2][1], 1)
+
+ range_udf = udf(lambda value: list(range(value - 1, value + 1)), ArrayType(IntegerType()))
+ res = df.select("id", explode(range_udf(df.id))).collect()
+ self.assertEqual(res[0][0], 0)
+ self.assertEqual(res[0][1], -1)
+ self.assertEqual(res[1][0], 0)
+ self.assertEqual(res[1][1], 0)
+ self.assertEqual(res[2][0], 1)
+ self.assertEqual(res[2][1], 0)
+ self.assertEqual(res[3][0], 1)
+ self.assertEqual(res[3][1], 1)
+
def test_udf_with_order_by_and_limit(self):
from pyspark.sql.functions import udf
my_copy = udf(lambda x: x, IntegerType())