-rw-r--r--  python/pyspark/sql/tests.py                                                         10
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala          6
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala   2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala    70
4 files changed, 77 insertions(+), 11 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 1d5d691696..c631ad8a46 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -339,13 +339,21 @@ class SQLTests(ReusedPySparkTestCase):
     def test_udf_with_aggregate_function(self):
         df = self.spark.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
-        from pyspark.sql.functions import udf, col
-        from pyspark.sql.types import BooleanType
+        from pyspark.sql.functions import udf, col, sum
+        from pyspark.sql.types import BooleanType, IntegerType
         my_filter = udf(lambda a: a == 1, BooleanType())
         sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
         self.assertEqual(sel.collect(), [Row(key=1)])
+        my_copy = udf(lambda x: x, IntegerType())
+        my_add = udf(lambda a, b: int(a + b), IntegerType())
+        my_strlen = udf(lambda x: len(x), IntegerType())
+        sel = df.groupBy(my_copy(col("key")).alias("k"))\
+            .agg(sum(my_strlen(col("value"))).alias("s"))\
+            .select(my_add(col("k"), col("s")).alias("t"))
+        self.assertEqual(sel.collect(), [Row(t=4), Row(t=3)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
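
[Note] The behavior exercised by the new test can be reproduced outside the test
harness. The sketch below is a minimal standalone version, assuming a local
SparkSession named `spark`; none of these names come from the patch itself.

    # Minimal sketch, assuming a local SparkSession (not part of this patch).
    from pyspark.sql import SparkSession
    from pyspark.sql.functions import udf, col, sum
    from pyspark.sql.types import IntegerType

    spark = SparkSession.builder.master("local[2]").getOrCreate()
    df = spark.createDataFrame(
        [(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])

    my_copy = udf(lambda x: x, IntegerType())             # UDF as a grouping key
    my_strlen = udf(lambda x: len(x), IntegerType())      # UDF feeding an aggregate
    my_add = udf(lambda a, b: int(a + b), IntegerType())  # UDF over aggregate results

    # my_add can only run after the aggregate, since its inputs `k` and `s`
    # exist only once the groupBy/agg has produced them.
    result = (df.groupBy(my_copy(col("key")).alias("k"))
                .agg(sum(my_strlen(col("value"))).alias("s"))
                .select(my_add(col("k"), col("s")).alias("t")))
    assert sorted(r.t for r in result.collect()) == [3, 4]
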
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
index 08b2d7fcd4..12a10cba20 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkOptimizer.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
 import org.apache.spark.sql.ExperimentalMethods
 import org.apache.spark.sql.catalyst.catalog.SessionCatalog
 import org.apache.spark.sql.catalyst.optimizer.Optimizer
+import org.apache.spark.sql.execution.python.ExtractPythonUDFFromAggregate
 import org.apache.spark.sql.internal.SQLConf

 class SparkOptimizer(
@@ -28,6 +29,7 @@ class SparkOptimizer(
     experimentalMethods: ExperimentalMethods)
   extends Optimizer(catalog, conf) {

-  override def batches: Seq[Batch] = super.batches :+ Batch(
-    "User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
+  override def batches: Seq[Batch] = super.batches :+
+    Batch("Extract Python UDF from Aggregate", Once, ExtractPythonUDFFromAggregate) :+
+    Batch("User Provided Optimizers", fixedPoint, experimentalMethods.extraOptimizations: _*)
 }
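
[Note] The new batch runs Once, after the built-in optimizer batches and before
any user-provided rules. Continuing the sketch after the tests.py diff above,
its effect can be observed from PySpark by printing the query plans; with this
patch in place the optimized plan should evaluate my_add in a projection above
the Aggregate rather than inside it.

    # Continuing the sketch above (the exact plan text varies by version).
    result.explain(True)  # prints parsed, analyzed, optimized and physical plans
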
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
index 061d7c7f79..d9bf4d3ccf 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchEvalPythonExec.scala
@@ -46,6 +46,8 @@ case class BatchEvalPythonExec(udfs: Seq[PythonUDF], output: Seq[Attribute], child: SparkPlan)

   def children: Seq[SparkPlan] = child :: Nil

+  override def producedAttributes: AttributeSet = AttributeSet(output.drop(child.output.length))
+
   private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
     udf.children match {
       case Seq(u: PythonUDF) =>
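
[Note] The `producedAttributes` override tells the planner which output
attributes originate in this node itself: `BatchEvalPythonExec` appends one
result column per Python UDF after the child's columns, and exactly those
appended attributes are produced here rather than required from the child.
A plain-Python model of the expression (illustrative only, not Spark API):

    # output = child columns + UDF result columns; producedAttributes is the
    # appended tail, i.e. output.drop(child.output.length) in the Scala above.
    child_output = ["key", "value"]           # columns provided by the child plan
    output = child_output + ["pythonUDF0"]    # plus one result column per UDF
    produced = output[len(child_output):]
    assert produced == ["pythonUDF0"]
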
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index ab192360e1..668470ee6a 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -18,12 +18,68 @@
 package org.apache.spark.sql.execution.python

 import scala.collection.mutable
+import scala.collection.mutable.ArrayBuffer

-import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
+import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
+import org.apache.spark.sql.catalyst.plans.logical.{Aggregate, LogicalPlan, Project}
 import org.apache.spark.sql.catalyst.rules.Rule
 import org.apache.spark.sql.execution
 import org.apache.spark.sql.execution.SparkPlan
+
+/**
+ * Extracts all the Python UDFs in a logical Aggregate that depend on an aggregate expression
+ * or a grouping key, so that they can be evaluated after the aggregate.
+ */
+private[spark] object ExtractPythonUDFFromAggregate extends Rule[LogicalPlan] {
+
+  /**
+   * Returns whether the expression can only be evaluated within the aggregate.
+   */
+  private def belongAggregate(e: Expression, agg: Aggregate): Boolean = {
+    e.isInstanceOf[AggregateExpression] ||
+      agg.groupingExpressions.exists(_.semanticEquals(e))
+  }
+
+  private def hasPythonUdfOverAggregate(expr: Expression, agg: Aggregate): Boolean = {
+    expr.find {
+      e => e.isInstanceOf[PythonUDF] && e.find(belongAggregate(_, agg)).isDefined
+    }.isDefined
+  }
+
+  private def extract(agg: Aggregate): LogicalPlan = {
+    val projList = new ArrayBuffer[NamedExpression]()
+    val aggExpr = new ArrayBuffer[NamedExpression]()
+    agg.aggregateExpressions.foreach { expr =>
+      if (hasPythonUdfOverAggregate(expr, agg)) {
+        // A Python UDF over an aggregate can only be evaluated after the aggregate
+        val newE = expr transformDown {
+          case e: Expression if belongAggregate(e, agg) =>
+            val alias = e match {
+              case a: NamedExpression => a
+              case o => Alias(e, "agg")()
+            }
+            aggExpr += alias
+            alias.toAttribute
+        }
+        projList += newE.asInstanceOf[NamedExpression]
+      } else {
+        aggExpr += expr
+        projList += expr.toAttribute
+      }
+    }
+    // The rewritten Aggregate no longer contains any Python UDF over an aggregate expression
+    Project(projList, agg.copy(aggregateExpressions = aggExpr))
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan transformUp {
+    case agg: Aggregate if agg.aggregateExpressions.exists(hasPythonUdfOverAggregate(_, agg)) =>
+      extract(agg)
+  }
+}
+
+
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
  * alone in a batch.
@@ -59,10 +115,12 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
   }

   /**
-   * Extract all the PythonUDFs from the current operator.
+   * Extract all the PythonUDFs from the current operator and evaluate them before the operator.
    */
-  def extract(plan: SparkPlan): SparkPlan = {
+  private def extract(plan: SparkPlan): SparkPlan = {
     val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+      // Ignore PythonUDFs that come from a second or third Aggregate; they are not used here
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
     if (udfs.isEmpty) {
       // If there aren't any, we are done.
       plan
@@ -89,11 +147,7 @@ private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
       // Other cases are disallowed as they are ambiguous or would require a cartesian
       // product.
       udfs.filterNot(attributeMap.contains).foreach { udf =>
-        if (udf.references.subsetOf(plan.inputSet)) {
-          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-        } else {
-          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
-        }
+        sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
       }

       val rewritten = plan.transformExpressions {
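
[Note] The split performed by ExtractPythonUDFFromAggregate.extract() earlier
in this diff can be hard to follow in patch form. Below is a toy, Spark-free
Python model of the same idea; every name here is illustrative and none of
this is Spark API.

    # Toy model: sub-expressions the Aggregate must compute are pulled into new
    # aggregate expressions, while the Python UDF wrapping them stays behind in
    # a Project evaluated after the aggregate.
    from dataclasses import dataclass, field

    @dataclass
    class Expr:
        name: str
        is_agg: bool = False          # stands in for AggregateExpression / grouping key
        is_python_udf: bool = False
        children: list = field(default_factory=list)

    def contains_agg(e):
        return e.is_agg or any(contains_agg(c) for c in e.children)

    def has_udf_over_agg(e):
        # mirrors hasPythonUdfOverAggregate: a PythonUDF with an aggregate below it
        if e.is_python_udf and any(contains_agg(c) for c in e.children):
            return True
        return any(has_udf_over_agg(c) for c in e.children)

    def extract(agg_exprs):
        proj_list, new_agg_exprs = [], []
        def pull(e):
            # mirrors the transformDown: replace aggregate sub-expressions with
            # references to aliases that the Aggregate itself will compute
            if e.is_agg:
                new_agg_exprs.append(e)
                return Expr(f"agg_{len(new_agg_exprs) - 1}")
            return Expr(e.name, children=[pull(c) for c in e.children],
                        is_python_udf=e.is_python_udf)
        for expr in agg_exprs:
            if has_udf_over_agg(expr):
                proj_list.append(pull(expr))
            else:
                new_agg_exprs.append(expr)
                proj_list.append(expr)
        return proj_list, new_agg_exprs

    # t = my_add(k, s): a Python UDF over a grouping key and an aggregate result
    t = Expr("my_add", is_python_udf=True,
             children=[Expr("k", is_agg=True), Expr("sum_strlen", is_agg=True)])
    proj, aggs = extract([t])
    print([e.name for e in aggs])  # ['k', 'sum_strlen'] -- computed by the Aggregate
    print(proj[0].name)            # 'my_add' -- evaluated in the Project afterwards
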