path: root/python
authorDavies Liu <davies@databricks.com>2016-04-04 10:56:26 -0700
committerDavies Liu <davies.liu@gmail.com>2016-04-04 10:56:26 -0700
commit5743c6476dbef50852b7f9873112a2d299966ebd (patch)
tree63c8a3682266a0eb2fbcf83ab298d2082bf4bdca /python
parent855ed44ed31210d2001d7ce67c8fa99f8416edd3 (diff)
[SPARK-12981] [SQL] extract Python UDF in physical plan
## What changes were proposed in this pull request?

Currently we extract Python UDFs into a special logical plan, EvaluatePython, in the analyzer. But EvaluatePython is not part of Catalyst, so many rules have no knowledge of it, which breaks many things (for example, filter pushdown or column pruning). We should treat Python UDFs as normal expressions until we want to evaluate them in the physical plan; we could extract them at the end of the optimizer, or in the physical plan. This PR extracts Python UDFs in the physical plan.

Closes #10935

## How was this patch tested?

Added regression tests.

Author: Davies Liu <davies@databricks.com>

Closes #12127 from davies/py_udf.
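As a rough, hypothetical sketch (not Spark's actual Catalyst code), the approach described above can be illustrated on a toy expression tree: Python UDFs stay in the tree as ordinary expression nodes, and only at physical planning are they pulled out and replaced by references to pre-computed columns. All names here (`Expr`, `extract_python_udfs`, the `pythonUDF<n>` column naming) are illustrative assumptions.

```python
# Toy illustration only -- not Spark's Catalyst implementation.
# Python UDFs are ordinary expression nodes, so earlier optimizer rules
# (e.g. filter pushdown, column pruning) can traverse the tree normally;
# extraction happens only when building the physical plan.
from dataclasses import dataclass, field

@dataclass
class Expr:
    name: str
    children: list = field(default_factory=list)
    is_python_udf: bool = False

def extract_python_udfs(expr, extracted):
    """Replace each Python UDF node with a reference to a column that
    will hold its result, collecting the UDFs so the physical plan can
    evaluate them in a separate, batched step."""
    if expr.is_python_udf:
        extracted.append(expr)
        # Reference to the attribute that will carry the UDF's output.
        return Expr(f"pythonUDF{len(extracted) - 1}")
    expr.children = [extract_python_udfs(c, extracted) for c in expr.children]
    return expr

# A filter condition mixing a Python UDF with a native predicate:
# my_filter(key) AND key > 0
cond = Expr("And", [
    Expr("my_filter", [Expr("key")], is_python_udf=True),
    Expr(">", [Expr("key"), Expr("0")]),
])
udfs = []
new_cond = extract_python_udfs(cond, udfs)
print([u.name for u in udfs])     # ['my_filter']
print(new_cond.children[0].name)  # 'pythonUDF0'
```

Because the UDF is replaced only at this late stage, rules that run earlier never see an opaque EvaluatePython node and can keep optimizing the rest of the condition.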
Diffstat (limited to 'python')
-rw-r--r--python/pyspark/sql/tests.py9
1 file changed, 9 insertions(+), 0 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 536ef55251..e4f79c911c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -343,6 +343,15 @@ class SQLTests(ReusedPySparkTestCase):
         [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
         self.assertEqual("", res[0])
 
+    def test_udf_with_aggregate_function(self):
+        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        from pyspark.sql.functions import udf, col
+        from pyspark.sql.types import BooleanType
+
+        my_filter = udf(lambda a: a == 1, BooleanType())
+        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
+        self.assertEqual(sel.collect(), [Row(key=1)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.sqlCtx.read.json(rdd)