aboutsummaryrefslogtreecommitdiff
path: root/sql/core/src
diff options
context:
space:
mode:
authorDavies Liu <davies@databricks.com>2016-09-12 16:35:42 -0700
committerDavies Liu <davies.liu@gmail.com>2016-09-12 16:35:42 -0700
commita91ab705e8c124aa116c3e5b1f3ba88ce832dcde (patch)
treeec4c202b2f8ab704a2ce1058ade1aeb2d91af0f4 /sql/core/src
parentf9c580f11098d95f098936a0b90fa21d71021205 (diff)
downloadspark-a91ab705e8c124aa116c3e5b1f3ba88ce832dcde.tar.gz
spark-a91ab705e8c124aa116c3e5b1f3ba88ce832dcde.tar.bz2
spark-a91ab705e8c124aa116c3e5b1f3ba88ce832dcde.zip
[SPARK-17474] [SQL] fix python udf in TakeOrderedAndProjectExec
## What changes were proposed in this pull request? When there is any Python UDF in the Project between Sort and Limit, it will be collected into TakeOrderedAndProjectExec, and ExtractPythonUDFs fails to pull the Python UDFs out because QueryPlan.expressions does not include the expressions inside Option[Seq[Expression]]. Ideally, we should fix `QueryPlan.expressions`, but I tried that with no luck (it always ran into an infinite loop). In this PR, I changed TakeOrderedAndProjectExec to not use Option[Seq[Expression]] to work around it. cc JoshRosen ## How was this patch tested? Added regression test. Author: Davies Liu <davies@databricks.com> Closes #15030 from davies/all_expr.
Diffstat (limited to 'sql/core/src')
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala8
-rw-r--r--sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala12
-rw-r--r--sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala4
3 files changed, 12 insertions, 12 deletions
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c389593b4f..3441ccf53b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -66,22 +66,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.ReturnAnswer(rootPlan) => rootPlan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
- execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil
+ execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
case logical.Limit(
IntegerLiteral(limit),
logical.Project(projectList, logical.Sort(order, true, child))) =>
execution.TakeOrderedAndProjectExec(
- limit, order, Some(projectList), planLater(child)) :: Nil
+ limit, order, projectList, planLater(child)) :: Nil
case logical.Limit(IntegerLiteral(limit), child) =>
execution.CollectLimitExec(limit, planLater(child)) :: Nil
case other => planLater(other) :: Nil
}
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
- execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil
+ execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
case logical.Limit(
IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) =>
execution.TakeOrderedAndProjectExec(
- limit, order, Some(projectList), planLater(child)) :: Nil
+ limit, order, projectList, planLater(child)) :: Nil
case _ => Nil
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 781c016095..01fbe5b7c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -114,11 +114,11 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
case class TakeOrderedAndProjectExec(
limit: Int,
sortOrder: Seq[SortOrder],
- projectList: Option[Seq[NamedExpression]],
+ projectList: Seq[NamedExpression],
child: SparkPlan) extends UnaryExecNode {
override def output: Seq[Attribute] = {
- projectList.map(_.map(_.toAttribute)).getOrElse(child.output)
+ projectList.map(_.toAttribute)
}
override def outputPartitioning: Partitioning = SinglePartition
@@ -126,8 +126,8 @@ case class TakeOrderedAndProjectExec(
override def executeCollect(): Array[InternalRow] = {
val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
- if (projectList.isDefined) {
- val proj = UnsafeProjection.create(projectList.get, child.output)
+ if (projectList != child.output) {
+ val proj = UnsafeProjection.create(projectList, child.output)
data.map(r => proj(r).copy())
} else {
data
@@ -148,8 +148,8 @@ case class TakeOrderedAndProjectExec(
localTopK, child.output, SinglePartition, serializer))
shuffled.mapPartitions { iter =>
val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
- if (projectList.isDefined) {
- val proj = UnsafeProjection.create(projectList.get, child.output)
+ if (projectList != child.output) {
+ val proj = UnsafeProjection.create(projectList, child.output)
topK.map(r => proj(r))
} else {
topK
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
index 3217e34bd8..7e317a4d80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -59,7 +59,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
checkThatPlansAgree(
generateRandomInputData(),
input =>
- noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, None, input)),
+ noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
input =>
GlobalLimitExec(limit,
LocalLimitExec(limit,
@@ -74,7 +74,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
generateRandomInputData(),
input =>
noOpFilter(
- TakeOrderedAndProjectExec(limit, sortOrder, Some(Seq(input.output.last)), input)),
+ TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
input =>
GlobalLimitExec(limit,
LocalLimitExec(limit,