 python/pyspark/sql/tests.py                                                            |  9
 sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala            |  1
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala           |  2
 sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala     | 23
 sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala  | 94
 sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala          |  3
 sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala               |  1
 sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala               |  1
 8 files changed, 64 insertions(+), 70 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index 536ef55251..e4f79c911c 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -343,6 +343,15 @@ class SQLTests(ReusedPySparkTestCase):
[res] = self.sqlCtx.sql("SELECT MYUDF('')").collect()
self.assertEqual("", res[0])
+ def test_udf_with_aggregate_function(self):
+ df = self.sqlCtx.createDataFrame([(1, "1"), (2, "2"), (1, "2"), (1, "2")], ["key", "value"])
+ from pyspark.sql.functions import udf, col
+ from pyspark.sql.types import BooleanType
+
+ my_filter = udf(lambda a: a == 1, BooleanType())
+ sel = df.select(col("key")).distinct().filter(my_filter(col("key")))
+ self.assertEqual(sel.collect(), [Row(key=1)])
+
def test_basic_functions(self):
rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
df = self.sqlCtx.read.json(rdd)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
index 63eb1aa24e..f5e1e77263 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/QueryExecution.scala
@@ -74,6 +74,7 @@ class QueryExecution(val sqlContext: SQLContext, val logical: LogicalPlan) {
/** A sequence of rules that will be applied in order to the physical plan before execution. */
protected def preparations: Seq[Rule[SparkPlan]] = Seq(
+ python.ExtractPythonUDFs,
PlanSubqueries(sqlContext),
EnsureRequirements(sqlContext.conf),
CollapseCodegenStages(sqlContext.conf),
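
With this change, ExtractPythonUDFs is no longer an analyzer rule: it is registered as the first physical-plan preparation rule, so it runs over the planned SparkPlan just before execution. A toy model of that preparation step, with stand-in names for the real rules (none of this is Spark's actual API), looks roughly like this:

    // Toy model of QueryExecution.preparations: each Rule[SparkPlan] is applied
    // to the physical plan in order, so ExtractPythonUDFs now sees the final
    // planned tree, including operators introduced after analysis/optimization.
    trait Rule[P] { def apply(plan: P): P }

    final case class ToyPlan(name: String, children: Seq[ToyPlan] = Nil)

    object PreparationsSketch {
      // Hypothetical stand-ins for ExtractPythonUDFs, EnsureRequirements, etc.
      val extractPythonUdfs: Rule[ToyPlan] = plan => plan  // would rewrite UDF-bearing operators
      val ensureRequirements: Rule[ToyPlan] = plan => plan // would insert exchanges where needed

      val preparations: Seq[Rule[ToyPlan]] = Seq(extractPythonUdfs, ensureRequirements)

      def prepareForExecution(plan: ToyPlan): ToyPlan =
        preparations.foldLeft(plan) { case (p, rule) => rule(p) }
    }
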
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index e1fabf519a..e52f05a5f4 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -392,8 +392,6 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
case logical.RepartitionByExpression(expressions, child, nPartitions) =>
exchange.ShuffleExchange(HashPartitioning(
expressions, nPartitions.getOrElse(numPartitions)), planLater(child)) :: Nil
- case e @ python.EvaluatePython(udfs, child, _) =>
- python.BatchPythonEvaluation(udfs, e.output, planLater(child)) :: Nil
case LogicalRDD(output, rdd) => PhysicalRDD(output, rdd, "ExistingRDD") :: Nil
case BroadcastHint(child) => planLater(child) :: Nil
case _ => Nil
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
index f3d1c44b25..3b05e29e52 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/EvaluatePython.scala
@@ -35,30 +35,7 @@ import org.apache.spark.sql.catalyst.util.{ArrayBasedMapData, ArrayData, Generic
import org.apache.spark.sql.types._
import org.apache.spark.unsafe.types.UTF8String
-/**
- * Evaluates a list of [[PythonUDF]], appending the result to the end of the input tuple.
- */
-case class EvaluatePython(
- udfs: Seq[PythonUDF],
- child: LogicalPlan,
- resultAttribute: Seq[AttributeReference])
- extends logical.UnaryNode {
-
- def output: Seq[Attribute] = child.output ++ resultAttribute
-
- // References should not include the produced attribute.
- override def references: AttributeSet = AttributeSet(udfs.flatMap(_.references))
-}
-
-
object EvaluatePython {
- def apply(udfs: Seq[PythonUDF], child: LogicalPlan): EvaluatePython = {
- val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
- AttributeReference(s"pythonUDF$i", u.dataType)()
- }
- new EvaluatePython(udfs, child, resultAttrs)
- }
-
def takeAndServe(df: DataFrame, n: Int): Int = {
registerPicklers()
df.withNewExecutionId {
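
The logical EvaluatePython node (and the apply helper that generated its result attributes) is removed: the physical rule below now builds BatchPythonEvaluation operators directly and generates those attributes itself. As a reminder of what the deleted node modelled, here is a minimal stand-in in plain Scala (not the Spark classes):

    // The deleted node appended one generated attribute per Python UDF to its
    // child's output; its references deliberately excluded the produced attributes.
    final case class ToyEvaluatePython(childOutput: Seq[String], udfCount: Int) {
      val resultAttrs: Seq[String] = (0 until udfCount).map(i => s"pythonUDF$i")
      def output: Seq[String] = childOutput ++ resultAttrs
    }

    // ToyEvaluatePython(Seq("key", "value"), 1).output == Seq("key", "value", "pythonUDF0")
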
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
index 0934cd135d..d72b3d347d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/ExtractPythonUDFs.scala
@@ -19,10 +19,10 @@ package org.apache.spark.sql.execution.python
import scala.collection.mutable
-import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression}
-import org.apache.spark.sql.catalyst.plans.logical
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression}
import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.execution
+import org.apache.spark.sql.execution.SparkPlan
/**
* Extracts PythonUDFs from operators, rewriting the query plan so that the UDF can be evaluated
@@ -34,7 +34,7 @@ import org.apache.spark.sql.catalyst.rules.Rule
 * This has the limitation that the input to the Python UDF is not allowed to include attributes from
* multiple child operators.
*/
-private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
+private[spark] object ExtractPythonUDFs extends Rule[SparkPlan] {
private def hasPythonUDF(e: Expression): Boolean = {
e.find(_.isInstanceOf[PythonUDF]).isDefined
@@ -54,49 +54,61 @@ private[spark] object ExtractPythonUDFs extends Rule[LogicalPlan] {
case e => e.children.flatMap(collectEvaluatableUDF)
}
- def apply(plan: LogicalPlan): LogicalPlan = plan resolveOperators {
- // Skip EvaluatePython nodes.
- case plan: EvaluatePython => plan
+ def apply(plan: SparkPlan): SparkPlan = plan transformUp {
+ case plan: SparkPlan => extract(plan)
+ }
- case plan: LogicalPlan if plan.resolved =>
- // Extract any PythonUDFs from the current operator.
- val udfs = plan.expressions.flatMap(collectEvaluatableUDF).filter(_.resolved)
- if (udfs.isEmpty) {
- // If there aren't any, we are done.
- plan
- } else {
- val attributeMap = mutable.HashMap[PythonUDF, Expression]()
- // Rewrite the child that has the input required for the UDF
- val newChildren = plan.children.map { child =>
- // Pick the UDF we are going to evaluate
- val validUdfs = udfs.filter { case udf =>
- // Check to make sure that the UDF can be evaluated with only the input of this child.
- udf.references.subsetOf(child.outputSet)
- }
- if (validUdfs.nonEmpty) {
- val evaluation = EvaluatePython(validUdfs, child)
- attributeMap ++= validUdfs.zip(evaluation.resultAttribute)
- evaluation
- } else {
- child
- }
+ /**
+ * Extract all the PythonUDFs from the current operator.
+ */
+ def extract(plan: SparkPlan): SparkPlan = {
+ val udfs = plan.expressions.flatMap(collectEvaluatableUDF)
+ if (udfs.isEmpty) {
+ // If there aren't any, we are done.
+ plan
+ } else {
+ val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+ // Rewrite the child that has the input required for the UDF
+ val newChildren = plan.children.map { child =>
+ // Pick the UDF we are going to evaluate
+ val validUdfs = udfs.filter { case udf =>
+ // Check to make sure that the UDF can be evaluated with only the input of this child.
+ udf.references.subsetOf(child.outputSet)
}
- // Other cases are disallowed as they are ambiguous or would require a cartesian
- // product.
- udfs.filterNot(attributeMap.contains).foreach { udf =>
- if (udf.references.subsetOf(plan.inputSet)) {
- sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
- } else {
- sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
+ if (validUdfs.nonEmpty) {
+ val resultAttrs = udfs.zipWithIndex.map { case (u, i) =>
+ AttributeReference(s"pythonUDF$i", u.dataType)()
}
+ val evaluation = BatchPythonEvaluation(validUdfs, child.output ++ resultAttrs, child)
+ attributeMap ++= validUdfs.zip(resultAttrs)
+ evaluation
+ } else {
+ child
+ }
+ }
+ // Other cases are disallowed as they are ambiguous or would require a cartesian
+ // product.
+ udfs.filterNot(attributeMap.contains).foreach { udf =>
+ if (udf.references.subsetOf(plan.inputSet)) {
+ sys.error(s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+ } else {
+ sys.error(s"Unable to evaluate PythonUDF $udf. Missing input attributes.")
}
+ }
+
+ val rewritten = plan.transformExpressions {
+ case p: PythonUDF if attributeMap.contains(p) =>
+ attributeMap(p)
+ }.withNewChildren(newChildren)
+ // extract remaining python UDFs recursively
+ val newPlan = extract(rewritten)
+ if (newPlan.output != plan.output) {
// Trim away the new UDF value if it was only used for filtering or something.
- logical.Project(
- plan.output,
- plan.transformExpressions {
- case p: PythonUDF if attributeMap.contains(p) => attributeMap(p)
- }.withNewChildren(newChildren))
+ execution.Project(plan.output, newPlan)
+ } else {
+ newPlan
}
+ }
}
}
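
In its new form the rule walks the physical plan bottom-up. For each operator it collects the evaluable PythonUDFs, pushes each one down to whichever child can supply all of its input attributes (wrapping that child in a BatchPythonEvaluation that appends a pythonUDF<i> column per UDF), replaces the UDF expressions in the operator with references to those columns, and finally adds a Project to trim the appended columns when the output would otherwise change. A simplified sketch of the routing step, in plain Scala rather than Spark's classes and with illustrative names:

    import scala.collection.mutable

    object ExtractSketch {
      final case class ToyUdf(name: String, references: Set[String])

      // Route each UDF to a child whose output covers its references; return the
      // widened child outputs and the UDF -> appended-column substitution map.
      def routeUdfs(
          childOutputs: Seq[Set[String]],
          udfs: Seq[ToyUdf]): (Seq[Seq[String]], Map[ToyUdf, String]) = {
        val substitutions = mutable.HashMap[ToyUdf, String]()
        val newChildOutputs = childOutputs.map { childOutput =>
          // A UDF can only be evaluated against a child that supplies all its inputs.
          val valid = udfs.filter(_.references.subsetOf(childOutput))
          val resultAttrs = valid.zipWithIndex.map { case (_, i) => s"pythonUDF$i" }
          substitutions ++= valid.zip(resultAttrs)
          childOutput.toSeq ++ resultAttrs
        }
        // UDFs left out of `substitutions` need inputs from more than one child (or
        // inputs that are missing entirely); the real rule reports those as errors.
        (newChildOutputs, substitutions.toMap)
      }
    }

The recursive call to extract on the rewritten operator then picks up any Python UDFs that only become evaluable after this substitution, such as a UDF nested inside another UDF.
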
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
index 4f1b837158..59d7e8dd6d 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonUDF.scala
@@ -18,7 +18,6 @@
package org.apache.spark.sql.execution.python
import org.apache.spark.api.python.PythonFunction
-import org.apache.spark.internal.Logging
import org.apache.spark.sql.catalyst.expressions.{Expression, NonSQLExpression, Unevaluable}
import org.apache.spark.sql.types.DataType
@@ -30,7 +29,7 @@ case class PythonUDF(
func: PythonFunction,
dataType: DataType,
children: Seq[Expression])
- extends Expression with Unevaluable with NonSQLExpression with Logging {
+ extends Expression with Unevaluable with NonSQLExpression {
override def toString: String = s"$name(${children.mkString(", ")})"
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
index cd3d254d1e..cd29def3be 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/internal/SessionState.scala
@@ -64,7 +64,6 @@ private[sql] class SessionState(ctx: SQLContext) {
lazy val analyzer: Analyzer = {
new Analyzer(catalog, functionRegistry, conf) {
override val extendedResolutionRules =
- python.ExtractPythonUDFs ::
PreInsertCastAndRename ::
DataSourceAnalysis ::
(if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
index ff40c366c8..829afa8432 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveSessionState.scala
@@ -60,7 +60,6 @@ private[hive] class HiveSessionState(ctx: HiveContext) extends SessionState(ctx)
catalog.OrcConversions ::
catalog.CreateTables ::
catalog.PreInsertionCasts ::
- python.ExtractPythonUDFs ::
PreInsertCastAndRename ::
DataSourceAnalysis ::
(if (conf.runSQLOnFile) new ResolveDataSource(ctx) :: Nil else Nil)