 python/pyspark/sql/tests.py                                                                  |   9 +
 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala        |  29 ++++-
 sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala | 110 +++++++++++
 3 files changed, 143 insertions(+), 5 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 66320bd050..af7d52cdac 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -360,6 +360,15 @@ class SQLTests(ReusedPySparkTestCase):
[res] = self.spark.sql("SELECT MYUDF('')").collect()
self.assertEqual("", res[0])
+ def test_udf_with_filter_function(self):
+ df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ from pyspark.sql.functions import udf, col
+ from pyspark.sql.types import BooleanType
+
+ my_filter = udf(lambda a: a < 2, BooleanType())
+ sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2"))
+ self.assertEqual(sel.collect(), [Row(key=1, value='1')])
+
def test_udf_with_aggregate_function(self):
df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
from pyspark.sql.functions import udf, col, sum
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 16e44845d5..69b4b7bb07 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.execution
-import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.{FilterExec, SparkPlan}
/**
@@ -90,7 +90,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
 * This has the limitation that the input to the Python UDF is not allowed to include
 * attributes from multiple child operators.
*/
-object ExtractPythonUDFs extends Rule[SparkPlan] {
+object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper {
private def hasPythonUDF(e: Expression): Boolean = {
e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -126,10 +126,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
plan
} else {
val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+ val splitFilter = trySplitFilter(plan)
// Rewrite the child that has the input required for the UDF
- val newChildren = plan.children.map { child =>
+ val newChildren = splitFilter.children.map { child =>
// Pick the UDF we are going to evaluate
- val validUdfs = udfs.filter { case udf =>
+ val validUdfs = udfs.filter { udf =>
// Check to make sure that the UDF can be evaluated with only the input of this child.
udf.references.subsetOf(child.outputSet)
}.toArray // Turn it into an array since iterators cannot be serialized in Scala 2.10
@@ -150,7 +151,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
}
- val rewritten = plan.withNewChildren(newChildren).transformExpressions {
+ val rewritten = splitFilter.withNewChildren(newChildren).transformExpressions {
case p: PythonUDF if attributeMap.contains(p) =>
attributeMap(p)
}
@@ -165,4 +166,22 @@ object ExtractPythonUDFs extends Rule[SparkPlan] {
}
}
}
+
+  // Split the original FilterExec into two FilterExecs: from the leading run of
+  // deterministic predicates, push down those that contain no Python UDF.
+ private def trySplitFilter(plan: SparkPlan): SparkPlan = {
+ plan match {
+ case filter: FilterExec =>
+ val (candidates, containingNonDeterministic) =
+ splitConjunctivePredicates(filter.condition).span(_.deterministic)
+ val (pushDown, rest) = candidates.partition(!hasPythonUDF(_))
+ if (pushDown.nonEmpty) {
+ val newChild = FilterExec(pushDown.reduceLeft(And), filter.child)
+ FilterExec((rest ++ containingNonDeterministic).reduceLeft(And), newChild)
+ } else {
+ filter
+ }
+ case o => o
+ }
+ }
}
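
The subtle point in trySplitFilter is the pairing of span and partition: span(_.deterministic) cuts the conjunct list at the first non-deterministic predicate (reordering across it could change observable results), and partition then keeps only the UDF-free predicates within that deterministic prefix. Below is a minimal, self-contained sketch of just this selection logic; Pred is a hypothetical stand-in for a Catalyst expression, reduced to the two properties the split cares about.

// A minimal sketch (not Spark code): models the predicate selection in
// trySplitFilter without depending on Catalyst.
case class Pred(name: String, deterministic: Boolean, hasPythonUDF: Boolean)

object SplitSketch {
  // Stop at the first non-deterministic conjunct, then keep only the
  // UDF-free predicates from the deterministic prefix.
  def split(conjuncts: Seq[Pred]): (Seq[Pred], Seq[Pred]) = {
    val (candidates, containingNonDeterministic) = conjuncts.span(_.deterministic)
    val (pushDown, rest) = candidates.partition(!_.hasPythonUDF)
    (pushDown, rest ++ containingNonDeterministic)
  }

  def main(args: Array[String]): Unit = {
    val conjuncts = Seq(
      Pred("dummyPythonUDF(a)", deterministic = true, hasPythonUDF = true),
      Pred("rand() > 3", deterministic = false, hasPythonUDF = false),
      Pred("b > 4", deterministic = true, hasPythonUDF = false))
    val (pushDown, kept) = split(conjuncts)
    // Prints: pushDown = List(), kept = List(dummyPythonUDF(a), rand() > 3, b > 4)
    // `b > 4` is NOT pushed down because it follows the non-deterministic rand(),
    // which is exactly what the "no push down on predicates starting from the
    // first non-deterministic" test below asserts.
    println(s"pushDown = ${pushDown.map(_.name)}, kept = ${kept.map(_.name)}")
  }
}
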
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
new file mode 100644
index 0000000000..81bea2fef8
--- /dev/null
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala
@@ -0,0 +1,110 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import scala.collection.JavaConverters._
+import scala.collection.mutable.ArrayBuffer
+
+import org.apache.spark.api.python.PythonFunction
+import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In}
+import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec}
+import org.apache.spark.sql.test.SharedSQLContext
+import org.apache.spark.sql.types.BooleanType
+
+class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext {
+ import testImplicits.newProductEncoder
+ import testImplicits.localSeqToDatasetHolder
+
+ override def beforeAll(): Unit = {
+ super.beforeAll()
+ spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF)
+ }
+
+ override def afterAll(): Unit = {
+ spark.sessionState.functionRegistry.dropFunction("dummyPythonUDF")
+ super.afterAll()
+ }
+
+ test("Python UDF: push down deterministic FilterExec predicates") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)")
+ val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+ case f @ FilterExec(
+ And(_: AttributeReference, _: AttributeReference),
+ InputAdapter(_: BatchEvalPythonExec)) => f
+ case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
+ }
+ assert(qualifiedPlanNodes.size == 2)
+ }
+
+ test("Nested Python UDF: push down deterministic FilterExec predicates") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)")
+ val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+ case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f
+ case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b
+ }
+ assert(qualifiedPlanNodes.size == 2)
+ }
+
+ test("Python UDF: no push down on non-deterministic") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ .where("b > 4 and dummyPythonUDF(a) and rand() > 3")
+ val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+ case f @ FilterExec(
+ And(_: AttributeReference, _: GreaterThan),
+ InputAdapter(_: BatchEvalPythonExec)) => f
+ case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b
+ }
+ assert(qualifiedPlanNodes.size == 2)
+ }
+
+ test("Python UDF: no push down on predicates starting from the first non-deterministic") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ .where("dummyPythonUDF(a) and rand() > 3 and b > 4")
+ val qualifiedPlanNodes = df.queryExecution.executedPlan.collect {
+ case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f
+ }
+ assert(qualifiedPlanNodes.size == 1)
+ }
+
+ test("Python UDF refers to the attributes from more than one child") {
+ val df = Seq(("Hello", 4)).toDF("a", "b")
+ val df2 = Seq(("Hello", 4)).toDF("c", "d")
+ val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)")
+
+ val e = intercept[RuntimeException] {
+ joinDF.queryExecution.executedPlan
+ }.getMessage
+ assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child")
+ .forall(e.contains))
+ }
+}
+
+// A dummy PythonFunction used only for testing; it carries no executable
+// command, so any attempt to actually evaluate it would fail.
+class DummyUDF extends PythonFunction(
+ command = Array[Byte](),
+ envVars = Map("" -> "").asJava,
+ pythonIncludes = ArrayBuffer("").asJava,
+ pythonExec = "",
+ pythonVer = "",
+ broadcastVars = null,
+ accumulator = null)
+
+class MyDummyPythonUDF
+ extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, dataType = BooleanType)