aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-06-24 15:20:39 -0700
committerReynold Xin <rxin@databricks.com>2016-06-24 15:20:39 -0700
commit4435de1bd36e2c30b764725fae05a08733f4aad4 (patch)
treeb10bfa807f42fa06cad4e14f944f5fe5c5e92cd4
parent5f8de21606f63408058bda44701d49fbf5027820 (diff)
downloadspark-4435de1bd36e2c30b764725fae05a08733f4aad4.tar.gz
spark-4435de1bd36e2c30b764725fae05a08733f4aad4.tar.bz2
spark-4435de1bd36e2c30b764725fae05a08733f4aad4.zip
[SPARK-16179][PYSPARK] fix bugs for Python udf in generate
## What changes were proposed in this pull request? This PR fix the bug when Python UDF is used in explode (generator), GenerateExec requires that all the attributes in expressions should be resolvable from children when creating, we should replace the children first, then replace it's expressions. ``` >>> df.select(explode(f(*df))).show() Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/vlad/dev/spark/python/pyspark/sql/dataframe.py", line 286, in show print(self._jdf.showString(n, truncate)) File "/home/vlad/dev/spark/python/lib/py4j-0.10.1-src.zip/py4j/java_gateway.py", line 933, in __call__ File "/home/vlad/dev/spark/python/pyspark/sql/utils.py", line 63, in deco return f(*a, **kw) File "/home/vlad/dev/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 312, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o52.showString. : org.apache.spark.sql.catalyst.errors.package$TreeNodeException: makeCopy, tree: Generate explode(<lambda>(_1#0L)), false, false, [col#15L] +- Scan ExistingRDD[_1#0L] at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:50) at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:387) at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:69) at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:45) at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsDown(QueryPlan.scala:177) at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressions(QueryPlan.scala:144) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract(ExtractPythonUDFs.scala:153) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:114) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:113) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:301) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:301) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:300) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:298) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:298) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:321) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:179) at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:319) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:298) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:113) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:93) at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:95) at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:95) at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124) at scala.collection.immutable.List.foldLeft(List.scala:84) at org.apache.spark.sql.execution.QueryExecution.prepareForExecution(QueryExecution.scala:95) at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:85) at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:85) at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2557) at org.apache.spark.sql.Dataset.head(Dataset.scala:1923) at org.apache.spark.sql.Dataset.take(Dataset.scala:2138) at org.apache.spark.sql.Dataset.showString(Dataset.scala:239) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:237) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357) at py4j.Gateway.invoke(Gateway.java:280) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:211) at java.lang.Thread.run(Thread.java:745) Caused by: java.lang.reflect.InvocationTargetException at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method) at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62) at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45) at java.lang.reflect.Constructor.newInstance(Constructor.java:423) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1$$anonfun$apply$13.apply(TreeNode.scala:413) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1$$anonfun$apply$13.apply(TreeNode.scala:413) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1.apply(TreeNode.scala:412) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1.apply(TreeNode.scala:387) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:49) ... 42 more Caused by: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF0#20 at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:50) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:88) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:87) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:279) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:279) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:278) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:284) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:284) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:321) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:179) at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:319) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:284) at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:268) at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:87) at org.apache.spark.sql.execution.GenerateExec.<init>(GenerateExec.scala:63) ... 52 more Caused by: java.lang.RuntimeException: Couldn't find pythonUDF0#20 in [_1#0L] at scala.sys.package$.error(package.scala:27) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:94) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:88) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:49) ... 67 more ``` ## How was this patch tested? Added regression tests. Author: Davies Liu <davies@databricks.com> Closes #13883 from davies/udf_in_generate.
-rw-r--r--python/pyspark/sql/tests.py7
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala4
2 files changed, 9 insertions, 2 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 388ac91922..3dc4083704 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -363,6 +363,13 @@ class SQLTests(ReusedPySparkTestCase):
.select(my_add(col("k"), col("s")).alias("t"))
self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+ def test_udf_in_generate(self):
+ from pyspark.sql.functions import udf, explode
+ df = self.spark.range(5)
+ f = udf(lambda x: list(range(x)), ArrayType(LongType()))
+ row = df.select(explode(f(*df))).groupBy().sum().first()
+ self.assertEqual(row[0], 10)
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 87583c8234..829bcae6f9 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -150,10 +150,10 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
}
- val rewritten = plan.transformExpressions {
+ val rewritten = plan.withNewChildren(newChildren).transformExpressions {
case p: PythonUDF if attributeMap.contains(p) =>
attributeMap(p)
- }.withNewChildren(newChildren)
+ }
// extract remaining python UDFs recursively
val newPlan = extract(rewritten)