From 4435de1bd36e2c30b764725fae05a08733f4aad4 Mon Sep 17 00:00:00 2001 From: Davies Liu Date: Fri, 24 Jun 2016 15:20:39 -0700 Subject: [SPARK-16179][PYSPARK] fix bugs for Python udf in generate ## What changes were proposed in this pull request? This PR fix the bug when Python UDF is used in explode (generator), GenerateExec requires that all the attributes in expressions should be resolvable from children when creating, we should replace the children first, then replace it's expressions. ``` >>> df.select(explode(f(*df))).show() Traceback (most recent call last): File "", line 1, in File "/home/vlad/dev/spark/python/pyspark/sql/dataframe.py", line 286, in show print(self._jdf.showString(n, truncate)) File "/home/vlad/dev/spark/python/lib/py4j-0.10.1-src.zip/py4j/java_gateway.py", line 933, in __call__ File "/home/vlad/dev/spark/python/pyspark/sql/utils.py", line 63, in deco return f(*a, **kw) File "/home/vlad/dev/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 312, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o52.showString. : org.apache.spark.sql.catalyst.errors.package$TreeNodeException: makeCopy, tree: Generate explode((_1#0L)), false, false, [col#15L] +- Scan ExistingRDD[_1#0L] at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:50) at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:387) at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:69) at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:45) at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsDown(QueryPlan.scala:177) at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressions(QueryPlan.scala:144) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract(ExtractPythonUDFs.scala:153) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:114) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:113) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:301) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:301) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:300) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:298) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:298) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:321) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:179) at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:319) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:298) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:113) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:93) at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:95) at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:95) at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124) at scala.collection.immutable.List.foldLeft(List.scala:84) at org.apache.spark.sql.execution.QueryExecution.prepareForExecution(QueryExecution.scala:95) at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:85) at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:85) at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2557) at org.apache.spark.sql.Dataset.head(Dataset.scala:1923) at org.apache.spark.sql.Dataset.take(Dataset.scala:2138) at org.apache.spark.sql.Dataset.showString(Dataset.scala:239) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:237) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357) at py4j.Gateway.invoke(Gateway.java:280) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:211) at java.lang.Thread.run(Thread.java:745) Caused by: java.lang.reflect.InvocationTargetException at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method) at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62) at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45) at java.lang.reflect.Constructor.newInstance(Constructor.java:423) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1$$anonfun$apply$13.apply(TreeNode.scala:413) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1$$anonfun$apply$13.apply(TreeNode.scala:413) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1.apply(TreeNode.scala:412) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1.apply(TreeNode.scala:387) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:49) ... 42 more Caused by: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF0#20 at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:50) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:88) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:87) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:279) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:279) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:278) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:284) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:284) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:321) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:179) at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:319) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:284) at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:268) at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:87) at org.apache.spark.sql.execution.GenerateExec.(GenerateExec.scala:63) ... 52 more Caused by: java.lang.RuntimeException: Couldn't find pythonUDF0#20 in [_1#0L] at scala.sys.package$.error(package.scala:27) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:94) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:88) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:49) ... 67 more ``` ## How was this patch tested? Added regression tests. Author: Davies Liu Closes #13883 from davies/udf_in_generate. --- python/pyspark/sql/tests.py | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'python/pyspark/sql/tests.py') diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 388ac91922..3dc4083704 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -363,6 +363,13 @@ class SQLTests(ReusedPySparkTestCase): .select(my_add(col("k"), col("s")).alias("t")) self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)]) + def test_udf_in_generate(self): + from pyspark.sql.functions import udf, explode + df = self.spark.range(5) + f = udf(lambda x: list(range(x)), ArrayType(LongType())) + row = df.select(explode(f(*df))).groupBy().sum().first() + self.assertEqual(row[0], 10) + def test_basic_functions(self): rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}']) df = self.spark.read.json(rdd) -- cgit v1.2.3