diff options
3 files changed, 143 insertions, 5 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py index 66320bd050..af7d52cdac 100644 --- a/python/pyspark/sql/tests.py +++ b/python/pyspark/sql/tests.py @@ -360,6 +360,15 @@ class SQLTests(ReusedPySparkTestCase): [res] = self.spark.sql("SELECT MYUDF('')").collect() self.assertEqual("", res[0]) + def test_udf_with_filter_function(self): + df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) + from pyspark.sql.functions import udf, col + from pyspark.sql.types import BooleanType + + my_filter = udf(lambda a: a < 2, BooleanType()) + sel = df.select(col("key"), col("value")).filter((my_filter(col("key"))) & (df.value < "2")) + self.assertEqual(sel.collect(), [Row(key=1, value='1')]) + def test_udf_with_aggregate_function(self): df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"]) from pyspark.sql.functions import udf, col, sum diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala index 16e44845d5..69b4b7bb07 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala @@ -25,7 +25,7 @@ import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project} import org.apache.spark.sql.catalyst.rules.Rule import org.apache.spark.sql.execution -import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.{FilterExec, SparkPlan} /** @@ -90,7 +90,7 @@ object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] { * This has the limitation that the input to the Python UDF is not allowed include attributes from * multiple child operators. */ -object ExtractPythonUDFs extends Rule[SparkPlan] { +object ExtractPythonUDFs extends Rule[SparkPlan] with PredicateHelper { private def hasPythonUDF(e: Expression): Boolean = { e.find(_.isInstanceOf[PythonUDF]).isDefined @@ -126,10 +126,11 @@ object ExtractPythonUDFs extends Rule[SparkPlan] { plan } else { val attributeMap = mutable.HashMap[PythonUDF, Expression]() + val splitFilter = trySplitFilter(plan) // Rewrite the child that has the input required for the UDF - val newChildren = plan.children.map { child => + val newChildren = splitFilter.children.map { child => // Pick the UDF we are going to evaluate - val validUdfs = udfs.filter { case udf => + val validUdfs = udfs.filter { udf => // Check to make sure that the UDF can be evaluated with only the input of this child. udf.references.subsetOf(child.outputSet) }.toArray // Turn it into an array since iterators cannot be serialized in Scala 2.10 @@ -150,7 +151,7 @@ object ExtractPythonUDFs extends Rule[SparkPlan] { sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.") } - val rewritten = plan.withNewChildren(newChildren).transformExpressions { + val rewritten = splitFilter.withNewChildren(newChildren).transformExpressions { case p: PythonUDF if attributeMap.contains(p) => attributeMap(p) } @@ -165,4 +166,22 @@ object ExtractPythonUDFs extends Rule[SparkPlan] { } } } + + // Split the original FilterExec to two FilterExecs. Only push down the first few predicates + // that are all deterministic. + private def trySplitFilter(plan: SparkPlan): SparkPlan = { + plan match { + case filter: FilterExec => + val (candidates, containingNonDeterministic) = + splitConjunctivePredicates(filter.condition).span(_.deterministic) + val (pushDown, rest) = candidates.partition(!hasPythonUDF(_)) + if (pushDown.nonEmpty) { + val newChild = FilterExec(pushDown.reduceLeft(And), filter.child) + FilterExec((rest ++ containingNonDeterministic).reduceLeft(And), newChild) + } else { + filter + } + case o => o + } + } } diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala new file mode 100644 index 0000000000..81bea2fef8 --- /dev/null +++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExecSuite.scala @@ -0,0 +1,110 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import scala.collection.JavaConverters._ +import scala.collection.mutable.ArrayBuffer + +import org.apache.spark.api.python.PythonFunction +import org.apache.spark.sql.catalyst.expressions.{And, AttributeReference, GreaterThan, In} +import org.apache.spark.sql.execution.{FilterExec, InputAdapter, SparkPlanTest, WholeStageCodegenExec} +import org.apache.spark.sql.test.SharedSQLContext +import org.apache.spark.sql.types.BooleanType + +class BatchEvalPythonExecSuite extends SparkPlanTest with SharedSQLContext { + import testImplicits.newProductEncoder + import testImplicits.localSeqToDatasetHolder + + override def beforeAll(): Unit = { + super.beforeAll() + spark.udf.registerPython("dummyPythonUDF", new MyDummyPythonUDF) + } + + override def afterAll(): Unit = { + spark.sessionState.functionRegistry.dropFunction("dummyPythonUDF") + super.afterAll() + } + + test("Python UDF: push down deterministic FilterExec predicates") { + val df = Seq(("Hello", 4)).toDF("a", "b") + .where("dummyPythonUDF(b) and dummyPythonUDF(a) and a in (3, 4)") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { + case f @ FilterExec( + And(_: AttributeReference, _: AttributeReference), + InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b + } + assert(qualifiedPlanNodes.size == 2) + } + + test("Nested Python UDF: push down deterministic FilterExec predicates") { + val df = Seq(("Hello", 4)).toDF("a", "b") + .where("dummyPythonUDF(a, dummyPythonUDF(a, b)) and a in (3, 4)") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { + case f @ FilterExec(_: AttributeReference, InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(FilterExec(_: In, _))) => b + } + assert(qualifiedPlanNodes.size == 2) + } + + test("Python UDF: no push down on non-deterministic") { + val df = Seq(("Hello", 4)).toDF("a", "b") + .where("b > 4 and dummyPythonUDF(a) and rand() > 3") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { + case f @ FilterExec( + And(_: AttributeReference, _: GreaterThan), + InputAdapter(_: BatchEvalPythonExec)) => f + case b @ BatchEvalPythonExec(_, _, WholeStageCodegenExec(_: FilterExec)) => b + } + assert(qualifiedPlanNodes.size == 2) + } + + test("Python UDF: no push down on predicates starting from the first non-deterministic") { + val df = Seq(("Hello", 4)).toDF("a", "b") + .where("dummyPythonUDF(a) and rand() > 3 and b > 4") + val qualifiedPlanNodes = df.queryExecution.executedPlan.collect { + case f @ FilterExec(And(_: And, _: GreaterThan), InputAdapter(_: BatchEvalPythonExec)) => f + } + assert(qualifiedPlanNodes.size == 1) + } + + test("Python UDF refers to the attributes from more than one child") { + val df = Seq(("Hello", 4)).toDF("a", "b") + val df2 = Seq(("Hello", 4)).toDF("c", "d") + val joinDF = df.join(df2).where("dummyPythonUDF(a, c) == dummyPythonUDF(d, c)") + + val e = intercept[RuntimeException] { + joinDF.queryExecution.executedPlan + }.getMessage + assert(Seq("Invalid PythonUDF dummyUDF", "requires attributes from more than one child") + .forall(e.contains)) + } +} + +// This Python UDF is dummy and just for testing. Unable to execute. +class DummyUDF extends PythonFunction( + command = Array[Byte](), + envVars = Map("" -> "").asJava, + pythonIncludes = ArrayBuffer("").asJava, + pythonExec = "", + pythonVer = "", + broadcastVars = null, + accumulator = null) + +class MyDummyPythonUDF + extends UserDefinedPythonFunction(name = "dummyUDF", func = new DummyUDF, dataType = BooleanType) |