aboutsummaryrefslogtreecommitdiff
path: root/python/pyspark/sql/tests.py
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-06-24 15:20:39 -0700
committerReynold Xin <rxin@databricks.com>2016-06-24 15:20:39 -0700
commit4435de1bd36e2c30b764725fae05a08733f4aad4 (patch)
treeb10bfa807f42fa06cad4e14f944f5fe5c5e92cd4 /python/pyspark/sql/tests.py
parent5f8de21606f63408058bda44701d49fbf5027820 (diff)
downloadspark-4435de1bd36e2c30b764725fae05a08733f4aad4.tar.gz
spark-4435de1bd36e2c30b764725fae05a08733f4aad4.tar.bz2
spark-4435de1bd36e2c30b764725fae05a08733f4aad4.zip
[SPARK-16179][PYSPARK] fix bugs for Python udf in generate
## What changes were proposed in this pull request? This PR fix the bug when Python UDF is used in explode (generator), GenerateExec requires that all the attributes in expressions should be resolvable from children when creating, we should replace the children first, then replace it's expressions. ``` >>> df.select(explode(f(*df))).show() Traceback (most recent call last): File "<stdin>", line 1, in <module> File "/home/vlad/dev/spark/python/pyspark/sql/dataframe.py", line 286, in show print(self._jdf.showString(n, truncate)) File "/home/vlad/dev/spark/python/lib/py4j-0.10.1-src.zip/py4j/java_gateway.py", line 933, in __call__ File "/home/vlad/dev/spark/python/pyspark/sql/utils.py", line 63, in deco return f(*a, **kw) File "/home/vlad/dev/spark/python/lib/py4j-0.10.1-src.zip/py4j/protocol.py", line 312, in get_return_value py4j.protocol.Py4JJavaError: An error occurred while calling o52.showString. : org.apache.spark.sql.catalyst.errors.package$TreeNodeException: makeCopy, tree: Generate explode(<lambda>(_1#0L)), false, false, [col#15L] +- Scan ExistingRDD[_1#0L] at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:50) at org.apache.spark.sql.catalyst.trees.TreeNode.makeCopy(TreeNode.scala:387) at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:69) at org.apache.spark.sql.execution.SparkPlan.makeCopy(SparkPlan.scala:45) at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressionsDown(QueryPlan.scala:177) at org.apache.spark.sql.catalyst.plans.QueryPlan.transformExpressions(QueryPlan.scala:144) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.org$apache$spark$sql$execution$python$ExtractPythonUDFs$$extract(ExtractPythonUDFs.scala:153) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:114) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$$anonfun$apply$2.applyOrElse(ExtractPythonUDFs.scala:113) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:301) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformUp$1.apply(TreeNode.scala:301) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:300) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:298) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$4.apply(TreeNode.scala:298) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:321) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:179) at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:319) at org.apache.spark.sql.catalyst.trees.TreeNode.transformUp(TreeNode.scala:298) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:113) at org.apache.spark.sql.execution.python.ExtractPythonUDFs$.apply(ExtractPythonUDFs.scala:93) at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:95) at org.apache.spark.sql.execution.QueryExecution$$anonfun$prepareForExecution$1.apply(QueryExecution.scala:95) at scala.collection.LinearSeqOptimized$class.foldLeft(LinearSeqOptimized.scala:124) at scala.collection.immutable.List.foldLeft(List.scala:84) at org.apache.spark.sql.execution.QueryExecution.prepareForExecution(QueryExecution.scala:95) at org.apache.spark.sql.execution.QueryExecution.executedPlan$lzycompute(QueryExecution.scala:85) at org.apache.spark.sql.execution.QueryExecution.executedPlan(QueryExecution.scala:85) at org.apache.spark.sql.Dataset.withTypedCallback(Dataset.scala:2557) at org.apache.spark.sql.Dataset.head(Dataset.scala:1923) at org.apache.spark.sql.Dataset.take(Dataset.scala:2138) at org.apache.spark.sql.Dataset.showString(Dataset.scala:239) at sun.reflect.NativeMethodAccessorImpl.invoke0(Native Method) at sun.reflect.NativeMethodAccessorImpl.invoke(NativeMethodAccessorImpl.java:62) at sun.reflect.DelegatingMethodAccessorImpl.invoke(DelegatingMethodAccessorImpl.java:43) at java.lang.reflect.Method.invoke(Method.java:498) at py4j.reflection.MethodInvoker.invoke(MethodInvoker.java:237) at py4j.reflection.ReflectionEngine.invoke(ReflectionEngine.java:357) at py4j.Gateway.invoke(Gateway.java:280) at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:128) at py4j.commands.CallCommand.execute(CallCommand.java:79) at py4j.GatewayConnection.run(GatewayConnection.java:211) at java.lang.Thread.run(Thread.java:745) Caused by: java.lang.reflect.InvocationTargetException at sun.reflect.NativeConstructorAccessorImpl.newInstance0(Native Method) at sun.reflect.NativeConstructorAccessorImpl.newInstance(NativeConstructorAccessorImpl.java:62) at sun.reflect.DelegatingConstructorAccessorImpl.newInstance(DelegatingConstructorAccessorImpl.java:45) at java.lang.reflect.Constructor.newInstance(Constructor.java:423) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1$$anonfun$apply$13.apply(TreeNode.scala:413) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1$$anonfun$apply$13.apply(TreeNode.scala:413) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1.apply(TreeNode.scala:412) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$makeCopy$1.apply(TreeNode.scala:387) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:49) ... 42 more Caused by: org.apache.spark.sql.catalyst.errors.package$TreeNodeException: Binding attribute, tree: pythonUDF0#20 at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:50) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:88) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1.applyOrElse(BoundAttribute.scala:87) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:279) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$3.apply(TreeNode.scala:279) at org.apache.spark.sql.catalyst.trees.CurrentOrigin$.withOrigin(TreeNode.scala:69) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:278) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:284) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$transformDown$1.apply(TreeNode.scala:284) at org.apache.spark.sql.catalyst.trees.TreeNode$$anonfun$5.apply(TreeNode.scala:321) at org.apache.spark.sql.catalyst.trees.TreeNode.mapProductIterator(TreeNode.scala:179) at org.apache.spark.sql.catalyst.trees.TreeNode.transformChildren(TreeNode.scala:319) at org.apache.spark.sql.catalyst.trees.TreeNode.transformDown(TreeNode.scala:284) at org.apache.spark.sql.catalyst.trees.TreeNode.transform(TreeNode.scala:268) at org.apache.spark.sql.catalyst.expressions.BindReferences$.bindReference(BoundAttribute.scala:87) at org.apache.spark.sql.execution.GenerateExec.<init>(GenerateExec.scala:63) ... 52 more Caused by: java.lang.RuntimeException: Couldn't find pythonUDF0#20 in [_1#0L] at scala.sys.package$.error(package.scala:27) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:94) at org.apache.spark.sql.catalyst.expressions.BindReferences$$anonfun$bindReference$1$$anonfun$applyOrElse$1.apply(BoundAttribute.scala:88) at org.apache.spark.sql.catalyst.errors.package$.attachTree(package.scala:49) ... 67 more ``` ## How was this patch tested? Added regression tests. Author: Davies Liu <davies@databricks.com> Closes #13883 from davies/udf_in_generate.
Diffstat (limited to 'python/pyspark/sql/tests.py')
-rw-r--r--python/pyspark/sql/tests.py7
1 files changed, 7 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 388ac91922..3dc4083704 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -363,6 +363,13 @@ class SQLTests(ReusedPySparkTestCase):
.select(my_add(col("k"), col("s")).alias("t"))
self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+ def test_udf_in_generate(self):
+ from pyspark.sql.functions import udf, explode
+ df = self.spark.range(5)
+ f = udf(lambda x: list(range(x)), ArrayType(LongType()))
+ row = df.select(explode(f(*df))).groupBy().sum().first()
+ self.assertEqual(row[0], 10)
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)