 python/pyspark/sql/tests.py                                                             |  8 ++++++++
 sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala            |  8 ++++----
 sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala                      | 12 ++++++------
 sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala |  4 ++--
 4 files changed, 20 insertions(+), 12 deletions(-)
diff --git a/python/pyspark/sql/tests.py b/python/pyspark/sql/tests.py
index fd8e9cec3e..769e454072 100644
--- a/python/pyspark/sql/tests.py
+++ b/python/pyspark/sql/tests.py
@@ -376,6 +376,14 @@ class SQLTests(ReusedPySparkTestCase):
         row = df.select(explode(f(*df))).groupBy().sum().first()
         self.assertEqual(row[0], 10)
 
+    def test_udf_with_order_by_and_limit(self):
+        from pyspark.sql.functions import udf
+        my_copy = udf(lambda x: x, IntegerType())
+        df = self.spark.range(10).orderBy("id")
+        res = df.select(df.id, my_copy(df.id).alias("copy")).limit(1)
+        res.explain(True)
+        self.assertEqual(res.collect(), [Row(id=0, copy=0)])
+
     def test_basic_functions(self):
         rdd = self.sc.parallelize(['{"foo":"bar"}', '{"foo":"baz"}'])
         df = self.spark.read.json(rdd)
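The new PySpark test pins down the regression being fixed: a UDF projected on top of an ORDER BY ... LIMIT query, which plans as a single TakeOrderedAndProjectExec. A rough Scala equivalent of the same plan shape, as a sketch only: it assumes a local SparkSession, and the computed "copy" column stands in for the Python UDF without exercising Python evaluation.

import org.apache.spark.sql.SparkSession

object OrderByLimitRepro {
  def main(args: Array[String]): Unit = {
    val spark = SparkSession.builder().master("local[*]").appName("repro").getOrCreate()
    import spark.implicits._

    // ORDER BY followed by LIMIT plans a TakeOrderedAndProjectExec; the
    // projected "copy" column takes the place of my_copy(df.id) above.
    val res = spark.range(10).orderBy("id")
      .select($"id", ($"id" + 0).alias("copy"))
      .limit(1)

    res.explain(true) // the physical plan should contain TakeOrderedAndProject
    assert(res.collect().head.getLong(0) == 0L)
    spark.stop()
  }
}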
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index c389593b4f..3441ccf53b 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -66,22 +66,22 @@ abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
     override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
       case logical.ReturnAnswer(rootPlan) => rootPlan match {
         case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-          execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil
+          execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
         case logical.Limit(
             IntegerLiteral(limit),
             logical.Project(projectList, logical.Sort(order, true, child))) =>
           execution.TakeOrderedAndProjectExec(
-            limit, order, Some(projectList), planLater(child)) :: Nil
+            limit, order, projectList, planLater(child)) :: Nil
         case logical.Limit(IntegerLiteral(limit), child) =>
           execution.CollectLimitExec(limit, planLater(child)) :: Nil
         case other => planLater(other) :: Nil
       }
       case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
-        execution.TakeOrderedAndProjectExec(limit, order, None, planLater(child)) :: Nil
+        execution.TakeOrderedAndProjectExec(limit, order, child.output, planLater(child)) :: Nil
       case logical.Limit(
           IntegerLiteral(limit), logical.Project(projectList, logical.Sort(order, true, child))) =>
         execution.TakeOrderedAndProjectExec(
-          limit, order, Some(projectList), planLater(child)) :: Nil
+          limit, order, projectList, planLater(child)) :: Nil
       case _ => Nil
     }
   }
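The strategy change removes the Option: where the planner used to pass None for a bare Limit-over-Sort and Some(projectList) when a Project sat in between, it now always passes a concrete projection, defaulting to child.output. A minimal REPL-style model of the new invariant (plain Scala, not the Spark source; Attr and effectiveProjectList are illustrative names):

case class Attr(name: String)

// Limit(Sort(child))              -> projectList = child.output
// Limit(Project(ps, Sort(child))) -> projectList = ps
def effectiveProjectList(explicitProject: Option[Seq[Attr]], childOutput: Seq[Attr]): Seq[Attr] =
  explicitProject.getOrElse(childOutput)

assert(effectiveProjectList(None, Seq(Attr("id"))) == Seq(Attr("id")))
assert(effectiveProjectList(Some(Seq(Attr("copy"))), Seq(Attr("id"))) == Seq(Attr("copy")))

One payoff, visible in limit.scala below, is that the node's output no longer branches: it is always projectList.map(_.toAttribute).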
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
index 781c016095..01fbe5b7c2 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/limit.scala
@@ -114,11 +114,11 @@ case class GlobalLimitExec(limit: Int, child: SparkPlan) extends BaseLimitExec {
 case class TakeOrderedAndProjectExec(
     limit: Int,
     sortOrder: Seq[SortOrder],
-    projectList: Option[Seq[NamedExpression]],
+    projectList: Seq[NamedExpression],
     child: SparkPlan) extends UnaryExecNode {
 
   override def output: Seq[Attribute] = {
-    projectList.map(_.map(_.toAttribute)).getOrElse(child.output)
+    projectList.map(_.toAttribute)
   }
 
   override def outputPartitioning: Partitioning = SinglePartition
@@ -126,8 +126,8 @@ case class TakeOrderedAndProjectExec(
   override def executeCollect(): Array[InternalRow] = {
     val ord = new LazilyGeneratedOrdering(sortOrder, child.output)
     val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
-    if (projectList.isDefined) {
-      val proj = UnsafeProjection.create(projectList.get, child.output)
+    if (projectList != child.output) {
+      val proj = UnsafeProjection.create(projectList, child.output)
       data.map(r => proj(r).copy())
     } else {
       data
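With projectList always present, the old projectList.isDefined test becomes a value comparison: the UnsafeProjection is built only when the requested columns differ from what the child already produces. Seq equality is element-wise, and in the no-projection case the planner passes child.output itself, so the two sides compare equal and the extra row copy is skipped. A sketch of the guard with plain values standing in for InternalRow and UnsafeProjection (names illustrative):

def maybeProject[R](rows: Seq[R], projectList: Seq[String], childOutput: Seq[String])
                   (project: R => R): Seq[R] =
  if (projectList != childOutput) rows.map(project) // only pay for a real projection
  else rows                                         // identity projection: rows pass through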
@@ -148,8 +148,8 @@ case class TakeOrderedAndProjectExec(
         localTopK, child.output, SinglePartition, serializer))
     shuffled.mapPartitions { iter =>
       val topK = org.apache.spark.util.collection.Utils.takeOrdered(iter.map(_.copy()), limit)(ord)
-      if (projectList.isDefined) {
-        val proj = UnsafeProjection.create(projectList.get, child.output)
+      if (projectList != child.output) {
+        val proj = UnsafeProjection.create(projectList, child.output)
         topK.map(r => proj(r))
       } else {
         topK
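The distributed path applies the same guard after a two-phase top-K: each partition keeps only its local top limit rows, the survivors are shuffled into a single partition, and Utils.takeOrdered merges them there while keeping at most limit elements. A sketch of that shape with Scala collections standing in for RDDs (the full sorts here are a simplification; the real merge is bounded):

def distributedTopK[T](partitions: Seq[Seq[T]], limit: Int)(implicit ord: Ordering[T]): Seq[T] = {
  val localTopK = partitions.map(_.sorted(ord).take(limit)) // per-partition pre-filter
  localTopK.flatten.sorted(ord).take(limit)                 // merge on the single partition
}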
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
index 3217e34bd8..7e317a4d80 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/TakeOrderedAndProjectSuite.scala
@@ -59,7 +59,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
       checkThatPlansAgree(
         generateRandomInputData(),
         input =>
-          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, None, input)),
+          noOpFilter(TakeOrderedAndProjectExec(limit, sortOrder, input.output, input)),
         input =>
           GlobalLimitExec(limit,
             LocalLimitExec(limit,
@@ -74,7 +74,7 @@ class TakeOrderedAndProjectSuite extends SparkPlanTest with SharedSQLContext {
         generateRandomInputData(),
         input =>
           noOpFilter(
-            TakeOrderedAndProjectExec(limit, sortOrder, Some(Seq(input.output.last)), input)),
+            TakeOrderedAndProjectExec(limit, sortOrder, Seq(input.output.last), input)),
         input =>
           GlobalLimitExec(limit,
             LocalLimitExec(limit,
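With the Option gone, the suite still exercises both constructor forms: input.output gives the identity projection that None used to encode, and Seq(input.output.last) is a genuine projection. Each form is checked against the equivalent Sort, Project, and Limit composition.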