8 files changed, 64 insertions, 70 deletions
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 536ef55251..e4f79c911c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -343,6 +343,15 @@ class SQLTests(ReusedPySparkTestCase):
         [res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
         self.assertEqual("", res[0])
 
+    def test_udf_with_aggregate_function(self):
+        df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+        from pyspark.sql.functions import udf, col
+        from pyspark.sql.types import BooleanType
+
+        my_filter = udf(lambda a: a == 1, BooleanType())
+        sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
+        self.assertEqual(sel.collect(), [Row(key=1)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.sqlCtx.read.json(rdd)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 63eb1aa24e..f5e1e77263 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -74,6 +74,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
 
   /** A sequence of rules that will be applied in order to the physical plan before execution. */
   protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+    python.ExtractPythonUDFs,
     PlanSubqueries(sqlContext),
     EnsureRequirements(sqlContext.conf),
     CollapseCodegenStages(sqlContext.conf),
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e1fabf519a..e52f05a5f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -392,8 +392,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
       case logical.RepartitionByExpression(expressions, child, nPartitions) =>
         exchange.ShuffleExchange(HashPartitioning(
           expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
-      case e @ python.EvaluatePython(udfs, child, _) =>
-        python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
       case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
       case BroadcastHint(child) => planLater(child) :: Nil
       case _ => Nil
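Note: the two Scala hunks above carry the core of the change: ExtractPythonUDFs now runs as one of QueryExecution's physical-plan preparation rules, so SparkStrategies no longer needs a case translating a logical EvaluatePython node into BatchPythonEvaluation. Below is a toy, self-contained sketch of the bottom-up rewrite pattern such a Rule[SparkPlan] relies on; Plan, Scan, Filter, PythonEval and the "pyudf:" tag are hypothetical stand-ins, not Spark's classes:

// Toy model of SparkPlan.transformUp: rewrite children first, then the node.
sealed trait Plan {
  def children: Seq[Plan]
  def withChildren(newChildren: Seq[Plan]): Plan
  def transformUp(rule: Plan => Plan): Plan =
    rule(withChildren(children.map(_.transformUp(rule))))
}
case class Scan(table: String) extends Plan {
  def children: Seq[Plan] = Nil
  def withChildren(c: Seq[Plan]): Plan = this
}
case class Filter(pred: String, child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  def withChildren(c: Seq[Plan]): Plan = copy(child = c.head)
}
case class PythonEval(udf: String, child: Plan) extends Plan {
  def children: Seq[Plan] = Seq(child)
  def withChildren(c: Seq[Plan]): Plan = copy(child = c.head)
}

object ExtractUdfsSketch {
  // Pull a Python-UDF predicate out of a Filter into an explicit eval node,
  // the same shape of rewrite ExtractPythonUDFs now performs on SparkPlan.
  def apply(plan: Plan): Plan = plan.transformUp {
    case Filter(pred, child) if pred.startsWith("pyudf:") =>
      Filter("pythonUDF0", PythonEval(pred, child))
    case p => p
  }

  def main(args: Array[String]): Unit = {
    println(apply(Filter("pyudf:key==1", Scan("t"))))
    // Filter(pythonUDF0,PythonEval(pyudf:key==1,Scan(t)))
  }
}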
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index f3d1c44b25..3b05e29e52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -35,30 +35,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic
 import org.apache.spark.sql.types._
 import org.apache.spark.unsafe.types.UTF8String
 
-/**
- * Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple.
- */
-case class EvaluatePython(
-    udfs: Seq[PythonUDF],
-    child: LogicalPlan,
-    resultAttribute: Seq[AttributeReference])
-  extends logical.UnaryNode {
-
-  def output: Seq[Attribute] = child.output ++ resultAttribute
-
-  // References should not include the produced attribute.
-  override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
-}
-
-
 object EvaluatePython {
-  def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
-    val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
-      AttributeReference(s"pythonUDF$i", u.dataType)()
-    }
-    new EvaluatePython(udfs, child, resultAttrs)
-  }
-
   def takeAndServe(df: DataFrame, n: Int): Int = {
     registerPicklers()
     df.withNewExecutionId {
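Note: the deleted EvaluatePython logical node did one piece of bookkeeping that survives the move: appending a fresh pythonUDF<i> attribute per UDF to its child's output (see the removed apply above). A minimal sketch of that idea, with plain strings standing in for AttributeReference; the types here are hypothetical:

object ResultAttrSketch {
  case class Udf(name: String, dataType: String)

  // Each extracted UDF gets a fresh output column named "pythonUDF<i>",
  // appended after the child's own columns.
  def appendResults(childOutput: Seq[String], udfs: Seq[Udf]): Seq[String] =
    childOutput ++ udfs.zipWithIndex.map { case (_, i) => s"pythonUDF$i" }

  def main(args: Array[String]): Unit = {
    println(appendResults(Seq("key", "value"), Seq(Udf("my_filter", "boolean"))))
    // List(key, value, pythonUDF0)
  }
}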
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 0934cd135d..d72b3d347d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.python
 
 import scala.collection.mutable
 
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
 import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.SparkPlan
 
 /**
  * Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
  * This has the limitation that the input to the Python UDF is not allowed include attributes from
  * multiple child operators.
  */
-private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
 
   private def hasPythonUDF(e: Expression): Boolean = {
     e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -54,49 +54,61 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
     case e => e.children.flatMap(collectEvaluatableUDF)
   }
 
-  def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
-    // Skip EvaluatePython nodes.
-    case plan: EvaluatePython => plan
+  def apply(plan: SparkPlan): SparkPlan = plan transformUp {
+    case plan: SparkPlan => extract(plan)
+  }
 
-    case plan: LogicalPlan if plan.resolved =>
-      // Extract any PythonUDFs from the current operator.
-      val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved)
-      if (udfs.isEmpty) {
-        // If there aren't any, we are done.
-        plan
-      } else {
-        val attributeMap = mutable.HashMap[PythonUDF, Expression]()
-        // Rewrite the child that has the input required for the UDF
-        val newChildren = plan.children.map { child =>
-          // Pick the UDF we are going to evaluate
-          val validUdfs = udfs.filter { case udf =>
-            // Check to make sure that the UDF can be evaluated with only the input of this child.
-            udf.references.subsetOf(child.outputSet)
-          }
-          if (validUdfs.nonEmpty) {
-            val evaluation = EvaluatePython(validUdfs, child)
-            attributeMap ++= validUdfs.zip(evaluation.resultAttribute)
-            evaluation
-          } else {
-            child
-          }
+  /**
+   * Extract all the PythonUDFs from the current operator.
+   */
+  def extract(plan: SparkPlan): SparkPlan = {
+    val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+    if (udfs.isEmpty) {
+      // If there aren't any, we are done.
+      plan
+    } else {
+      val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+      // Rewrite the child that has the input required for the UDF
+      val newChildren = plan.children.map { child =>
+        // Pick the UDF we are going to evaluate
+        val validUdfs = udfs.filter { case udf =>
+          // Check to make sure that the UDF can be evaluated with only the input of this child.
+          udf.references.subsetOf(child.outputSet)
         }
-        // Other cases are disallowed as they are ambiguous or would require a cartesian
-        // product.
-        udfs.filterNot(attributeMap.contains).foreach { udf =>
-          if (udf.references.subsetOf(plan.inputSet)) {
-            sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
-          } else {
-            sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
+        if (validUdfs.nonEmpty) {
+          val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+            AttributeReference(s"pythonUDF$i", u.dataType)()
           }
+          val evaluation = BatchPythonEvaluation(validUdfs, child.output ++ resultAttrs, child)
+          attributeMap ++= validUdfs.zip(resultAttrs)
+          evaluation
+        } else {
+          child
+        }
+      }
+      // Other cases are disallowed as they are ambiguous or would require a cartesian
+      // product.
+      udfs.filterNot(attributeMap.contains).foreach { udf =>
+        if (udf.references.subsetOf(plan.inputSet)) {
+          sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+        } else {
+          sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
        }
+      }
 
+      val rewritten = plan.transformExpressions {
+        case p: PythonUDF if attributeMap.contains(p) =>
+          attributeMap(p)
+      }.withNewChildren(newChildren)
+      // extract remaining python UDFs recursively
+      val newPlan = extract(rewritten)
+      if (newPlan.output != plan.output) {
         // Trim away the new UDF value if it was only used for filtering or something.
-        logical.Project(
-          plan.output,
-          plan.transformExpressions {
-            case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
-          }.withNewChildren(newChildren))
+        execution.Project(plan.output, newPlan)
+      } else {
+        newPlan
       }
+    }
   }
 }
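Note: extract() keeps the old rule's two failure modes: a UDF that cannot be evaluated against any single child's output either spans multiple children (which would force a cartesian product) or references attributes no child produces. A toy model of that placement decision, mirroring the two sys.error branches; Child and place are hypothetical names:

object UdfPlacementSketch {
  case class Child(outputSet: Set[String])

  // Decide whether a UDF's references fit inside exactly one child.
  def place(udfRefs: Set[String], children: Seq[Child]): String = {
    val inputSet = children.flatMap(_.outputSet).toSet
    children.find(c => udfRefs.subsetOf(c.outputSet)) match {
      case Some(_) => "evaluable against a single child"
      case None if udfRefs.subsetOf(inputSet) =>
        "invalid: requires attributes from more than one child"
      case None => "unable to evaluate: missing input attributes"
    }
  }

  def main(args: Array[String]): Unit = {
    val children = Seq(Child(Set("a", "b")), Child(Set("c")))
    println(place(Set("a"), children))      // fits the first child
    println(place(Set("a", "c"), children)) // spans both children
    println(place(Set("z"), children))      // produced by no child
  }
}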
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 4f1b837158..59d7e8dd6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -18,7 +18,6 @@
 package org.apache.spark.sql.execution.python
 
 import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.internal.Logging
 import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
 import org.apache.spark.sql.types.DataType
 
@@ -30,7 +29,7 @@ case class PythonUDF(
     func: PythonFunction,
     dataType: DataType,
     children: Seq[Expression])
-  extends Expression with Unevaluable with NonSQLExpression with Logging {
+  extends Expression with Unevaluable with NonSQLExpression {
 
   override def toString: String = s"$name(${children.mkString(", ")})"
 
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index cd3d254d1e..cd29def3be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -64,7 +64,6 @@ private[sql] class SessionState(ctx: SQLContext) {
   lazy val analyzer: Analyzer = {
     new Analyzer(catalog, functionRegistry, conf) {
       override val extendedResolutionRules =
-        python.ExtractPythonUDFs ::
         PreInsertCastAndRename ::
         DataSourceAnalysis ::
         (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index ff40c366c8..829afa8432 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -60,7 +60,6 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
       catalog.OrcConversions ::
       catalog.CreateTables ::
       catalog.PreInsertionCasts ::
-      python.ExtractPythonUDFs ::
       PreInsertCastAndRename ::
       DataSourceAnalysis ::
       (if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
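Note: with the analyzer hooks removed here, UDF extraction now happens exactly once, after planning. What makes that late rewrite safe is the trimming step at the end of extract(): if pulling UDFs out widened an operator's output with pythonUDF columns, an execution.Project restores the original schema. A sketch with toy types, where a bare Node stands in for Project:

object TrimOutputSketch {
  case class Node(output: Seq[String], children: Seq[Node] = Nil)

  // Wrap the rewritten plan in a projection only if extraction changed
  // its output, so downstream operators see the schema they planned for.
  def trimTo(originalOutput: Seq[String], rewritten: Node): Node =
    if (rewritten.output != originalOutput) Node(originalOutput, Seq(rewritten))
    else rewritten

  def main(args: Array[String]): Unit = {
    val widened = Node(Seq("key", "pythonUDF0")) // eval node appended a column
    println(trimTo(Seq("key"), widened))
    // Node(List(key),List(Node(List(key, pythonUDF0),List())))
  }
}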