aboutsummaryrefslogtreecommitdiff
path: root/sql
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-04-04 10:56:26 -0700
committerDavies Liu <davies.liu@gmail.com>2016-04-04 10:56:26 -0700
commit5743c6476dbef50852b7f9873112a2d299966ebd (patch)
tree63c8a3682266a0eb2fbcf83ab298d2082bf4bdca /sql
parent855ed44ed31210d2001d7ce67c8fa99f8416edd3 (diff)
downloadspark-5743c6476dbef50852b7f9873112a2d299966ebd.tar.gz
spark-5743c6476dbef50852b7f9873112a2d299966ebd.tar.bz2
spark-5743c6476dbef50852b7f9873112a2d299966ebd.zip
[SPARK-12981] [SQL] extract Pyhton UDF in physical plan
## What changes were proposed in this pull request? Currently we extract Python UDFs into a special logical plan EvaluatePython in analyzer, But EvaluatePython is not part of catalyst, many rules have no knowledge of it , which will break many things (for example, filter push down or column pruning). We should treat Python UDFs as normal expressions, until we want to evaluate in physical plan, we could extract them in end of optimizer, or physical plan. This PR extract Python UDFs in physical plan. Closes #10935 ## How was this patch tested? Added regression tests. Author: Davies Liu <davies@databricks.com> Closes #12127 from davies/py_udf.
Diffstat (limited to 'sql')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala1
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala2
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala23
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala94
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala3
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala1
-rw-r--r--sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala1
7 files changed, 55 insertions, 70 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 63eb1aa24e..f5e1e77263 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -74,6 +74,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
/** A sequence of rules that will be applied in order to the physical plan before execution. */
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+ python.ExtractPythonUDFs,
PlanSubqueries(sqlContext),
EnsureRequirements(sqlContext.conf),
CollapseCodegenStages(sqlContext.conf),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e1fabf519a..e52f05a5f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -392,8 +392,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
- case e @ python.EvaluatePython(udfs, child, _) =>
- python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index f3d1c44b25..3b05e29e52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -35,30 +35,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-/**
- * Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple.
- */
-case class EvaluatePython(
- udfs: Seq[PythonUDF],
- child: LogicalPlan,
- resultAttribute: Seq[AttributeReference])
- extends logical.UnaryNode {
-
- def output: Seq[Attribute] = child.output ++ resultAttribute
-
- // References should not include the produced attribute.
- override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
-}
-
-
object EvaluatePython {
- def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
- val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
- AttributeReference(s"pythonUDF$i", u.dataType)()
- }
- new EvaluatePython(udfs, child, resultAttrs)
- }
-
def takeAndServe(df: DataFrame, n: Int): Int = {
registerPicklers()
df.withNewExecutionId {
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 0934cd135d..d72b3d347d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.python
import scala.collection.mutable
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.SparkPlan
/**
* Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
* This has the limitation that the input to the Python UDF is not allowed include attributes from
* multiple child operators.
*/
-private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
private def hasPythonUDF(e: Expression): Boolean = {
e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -54,49 +54,61 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
case e => e.children.flatMap(collectEvaluatableUDF)
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- // Skip EvaluatePython nodes.
- case plan: EvaluatePython => plan
+ def apply(plan: SparkPlan): SparkPlan = plan transformUp {
+ case plan: SparkPlan => extract(plan)
+ }
- case plan: LogicalPlan if plan.resolved =>
- // Extract any PythonUDFs from the current operator.
- val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved)
- if (udfs.isEmpty) {
- // If there aren't any, we are done.
- plan
- } else {
- val attributeMap = mutable.HashMap[PythonUDF, Expression]()
- // Rewrite the child that has the input required for the UDF
- val newChildren = plan.children.map { child =>
- // Pick the UDF we are going to evaluate
- val validUdfs = udfs.filter { case udf =>
- // Check to make sure that the UDF can be evaluated with only the input of this child.
- udf.references.subsetOf(child.outputSet)
- }
- if (validUdfs.nonEmpty) {
- val evaluation = EvaluatePython(validUdfs, child)
- attributeMap ++= validUdfs.zip(evaluation.resultAttribute)
- evaluation
- } else {
- child
- }
+ /**
+ * Extract all the PythonUDFs from the current operator.
+ */
+ def extract(plan: SparkPlan): SparkPlan = {
+ val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+ if (udfs.isEmpty) {
+ // If there aren't any, we are done.
+ plan
+ } else {
+ val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+ // Rewrite the child that has the input required for the UDF
+ val newChildren = plan.children.map { child =>
+ // Pick the UDF we are going to evaluate
+ val validUdfs = udfs.filter { case udf =>
+ // Check to make sure that the UDF can be evaluated with only the input of this child.
+ udf.references.subsetOf(child.outputSet)
}
- // Other cases are disallowed as they are ambiguous or would require a cartesian
- // product.
- udfs.filterNot(attributeMap.contains).foreach { udf =>
- if (udf.references.subsetOf(plan.inputSet)) {
- sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
- } else {
- sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
+ if (validUdfs.nonEmpty) {
+ val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+ AttributeReference(s"pythonUDF$i", u.dataType)()
}
+ val evaluation = BatchPythonEvaluation(validUdfs, child.output ++ resultAttrs, child)
+ attributeMap ++= validUdfs.zip(resultAttrs)
+ evaluation
+ } else {
+ child
+ }
+ }
+ // Other cases are disallowed as they are ambiguous or would require a cartesian
+ // product.
+ udfs.filterNot(attributeMap.contains).foreach { udf =>
+ if (udf.references.subsetOf(plan.inputSet)) {
+ sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+ } else {
+ sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
}
+ }
+
+ val rewritten = plan.transformExpressions {
+ case p: PythonUDF if attributeMap.contains(p) =>
+ attributeMap(p)
+ }.withNewChildren(newChildren)
+ // extract remaining python UDFs recursively
+ val newPlan = extract(rewritten)
+ if (newPlan.output != plan.output) {
// Trim away the new UDF value if it was only used for filtering or something.
- logical.Project(
- plan.output,
- plan.transformExpressions {
- case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
- }.withNewChildren(newChildren))
+ execution.Project(plan.output, newPlan)
+ } else {
+ newPlan
}
+ }
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 4f1b837158..59d7e8dd6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
import org.apache.spark.sql.types.DataType
@@ -30,7 +29,7 @@ case class PythonUDF(
func: PythonFunction,
dataType: DataType,
children: Seq[Expression])
- extends Expression with Unevaluable with NonSQLExpression with Logging {
+ extends Expression with Unevaluable with NonSQLExpression {
override def toString: String = s"$name(${children.mkString(", ")})"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index cd3d254d1e..cd29def3be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -64,7 +64,6 @@ private[sql] class SessionState(ctx: SQLContext) {
lazy val analyzer: Analyzer = {
new Analyzer(catalog, functionRegistry, conf) {
override val extendedResolutionRules =
- python.ExtractPythonUDFs ::
PreInsertCastAndRename ::
DataSourceAnalysis ::
(if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index ff40c366c8..829afa8432 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -60,7 +60,6 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
catalog.OrcConversions ::
catalog.CreateTables ::
catalog.PreInsertionCasts ::
- python.ExtractPythonUDFs ::
PreInsertCastAndRename ::
DataSourceAnalysis ::
(if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)