author     Davies Liu <davies@databricks.com>     2016-09-12 16:35:42 -0700
committer  Davies Liu <davies.liu@gmail.com>      2016-09-12 16:35:42 -0700
commit     a91ab705e8c124aa116c3e5b1f3ba88ce832dcde (patch)
tree       ec4c202b2f8ab704a2ce1058ade1aeb2d91af0f4 /python
parent     f9c580f11098d95f098936a0b90fa21d71021205 (diff)
[SPARK-17474] [SQL] fix python udf in TakeOrderedAndProjectExec
## What changes were proposed in this pull request?

When there is any Python UDF in the Project between Sort and Limit, it is collected into TakeOrderedAndProjectExec, but ExtractPythonUDFs fails to pull the Python UDFs out because QueryPlan.expressions does not include the expressions inside Option[Seq[Expression]]. Ideally we should fix `QueryPlan.expressions`, but that was tried with no luck (it always runs into an infinite loop). In this PR, I changed TakeOrderedAndProjectExec to not use Option[Seq[Expression]] as a workaround.

cc JoshRosen

## How was this patch tested?

Added regression test.

Author: Davies Liu <davies@databricks.com>

Closes #15030 from davies/all_expr.
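For illustration, a minimal sketch of the query shape described above (a Python UDF in the Project between Sort and Limit), mirroring the regression test added in the diff below. The SparkSession setup and app name here are illustrative assumptions, not part of the patch:

```python
from pyspark.sql import SparkSession
from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# Assumption: any local SparkSession works for demonstrating the plan shape.
spark = SparkSession.builder.master("local[1]").appName("spark-17474-sketch").getOrCreate()

# A trivial Python UDF projected between orderBy (Sort) and limit (Limit):
# this pattern is planned as TakeOrderedAndProjectExec.
my_copy = udf(lambda x: x, IntegerType())

df = spark.range(10).orderBy("id")
res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)

# Before the fix, ExtractPythonUDFs missed the UDF hidden inside the
# Option[Seq[Expression]] project list and the query failed; with the fix
# it evaluates normally.
print(res.collect())  # expected: [Row(id=0, copy=0)]
```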
Diffstat (limited to 'python')
-rw-r--r--  python/pyspark/sql/tests.py  8
1 file changed, 8 insertions, 0 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index fd8e9cec3e..769e454072 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -376,6 +376,14 @@ class SQLTests(ReusedPySparkTestCase):
row = df.select(explode(f(*df))).groupBy().sum().first()
self.assertEqual(row[0], 10)
+ def test_udf_with_order_by_and_limit(self):
+ from pyspark.sql.functions import udf
+ my_copy = udf(lambda x: x, IntegerType())
+ df = self.spark.range(10).orderBy("id")
+ res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
+ res.explain(True)
+ self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.spark.read.json(rdd)