Diffstat (limited to 'sql/core')
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                |  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala       |  1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala |  8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala  | 27
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala    |  6
5 files changed, 32 insertions(+), 12 deletions(-)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 04fc798bf3..5708df82de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -858,7 +858,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
experimental.extraStrategies ++ (
DataSourceStrategy ::
DDLStrategy ::
- TakeOrdered ::
+ TakeOrderedAndProject ::
HashAggregation ::
LeftSemiJoin ::
HashJoin ::
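
The planner walks this strategy list in order, and the experimental hook lets callers prepend their own strategies ahead of the built-ins shown above. A minimal sketch of that extension point, assuming a live `sqlContext` on the 1.4-era API:

    import org.apache.spark.sql.Strategy
    import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
    import org.apache.spark.sql.execution.SparkPlan

    // Hypothetical no-op strategy, shown only to illustrate the hook;
    // returning Nil means "this strategy does not apply to this plan".
    object MyStrategy extends Strategy {
      def apply(plan: LogicalPlan): Seq[SparkPlan] = Nil
    }

    // Consulted before DataSourceStrategy, TakeOrderedAndProject, etc.
    sqlContext.experimental.extraStrategies = Seq(MyStrategy)
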
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 2b8d302942..47f56b2b7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -169,7 +169,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if(codegenEnabled && expressions.forall(_.isThreadSafe)) {
-
GenerateMutableProjection.generate(expressions, inputSchema)
} else {
() => new InterpretedMutableProjection(expressions, inputSchema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 1ff1cc224d..21912cf249 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -213,10 +213,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
protected lazy val singleRowRdd =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1)
- object TakeOrdered extends Strategy {
+ object TakeOrderedAndProject extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
- execution.TakeOrdered(limit, order, planLater(child)) :: Nil
+ execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+ case logical.Limit(
+ IntegerLiteral(limit),
+ logical.Project(projectList, logical.Sort(order, true, child))) =>
+ execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
case _ => Nil
}
}
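
The two match arms correspond to two logical-plan shapes. A hedged sketch of the DataFrame queries that produce them, assuming a frame `testData` with `key` and `value` columns (as in the planner test below) and the symbol-to-column implicits in scope:

    // Limit(Sort(...)): planned as TakeOrderedAndProject(limit, order, None, child)
    testData.sort('key).limit(10)

    // Limit(Project(Sort(...))): planned as
    // TakeOrderedAndProject(limit, order, Some(projectList), child)
    testData.sort('key).select('value).limit(10)
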
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 7aedd630e3..647c4ab5cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -39,8 +39,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
@transient lazy val buildProjection = newMutableProjection(projectList, child.output)
protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
- val resuableProjection = buildProjection()
- iter.map(resuableProjection)
+ val reusableProjection = buildProjection()
+ iter.map(reusableProjection)
}
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -147,12 +147,18 @@ case class Limit(limit: Int, child: SparkPlan)
/**
* :: DeveloperApi ::
- * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
- * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
- * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ * Take the first `limit` elements as defined by the sortOrder, projecting them if a
+ * projectList is given. This is logically equivalent to having a [[Limit]] operator after
+ * a [[Sort]] operator, optionally with a [[Project]] operator in between.
+ * This could have been named TopK, but Spark's top operator sorts in the opposite order,
+ * so we name it TakeOrderedAndProject to avoid confusion.
*/
@DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
+case class TakeOrderedAndProject(
+ limit: Int,
+ sortOrder: Seq[SortOrder],
+ projectList: Option[Seq[NamedExpression]],
+ child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
@@ -160,8 +166,13 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
- private def collectData(): Array[InternalRow] =
- child.execute().map(_.copy()).takeOrdered(limit)(ord)
+ // TODO: remove @transient once we figure out how to clean the closure in InsertIntoHiveTable.
+ @transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
+
+ private def collectData(): Array[InternalRow] = {
+ val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+ projection.map(data.map(_)).getOrElse(data)
+ }
override def executeCollect(): Array[Row] = {
val converter = CatalystTypeConverters.createToScalaConverter(schema)
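
collectData() builds on RDD.takeOrdered, which keeps a bounded number of elements per partition and merges them on the driver, so the `limit` smallest rows under `ord` are found without a global sort; the optional projection is then applied driver-side. A minimal sketch of the same two steps over plain Ints, assuming a live SparkContext `sc`:

    // Step 1: the two smallest elements under the natural ordering,
    // computed without sorting the whole RDD.
    val smallest = sc.parallelize(Seq(5, 1, 4, 2, 3)).takeOrdered(2)(Ordering.Int)
    // smallest == Array(1, 2)

    // Step 2: driver-side projection, mirroring
    // projection.map(data.map(_)).getOrElse(data) above.
    val projection: Option[Int => Int] = Some(_ * 10)
    val projected = projection.map(p => smallest.map(p)).getOrElse(smallest)
    // projected == Array(10, 20)
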
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 5854ab48db..3dd24130af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -141,4 +141,10 @@ class PlannerSuite extends SparkFunSuite {
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
+
+ test("efficient limit -> project -> sort") {
+ val query = testData.sort('key).select('value).limit(2).logicalPlan
+ val planned = planner.TakeOrderedAndProject(query)
+ assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+ }
}
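
As a usage note, the effect is visible from the physical plan: a sort followed by a narrow select and a small limit should now collapse into a single TakeOrderedAndProject node rather than separate Sort, Project, and Limit operators. A hypothetical check from a spark-shell session, with the test implicits in scope:

    // Prints the physical plan; expect one TakeOrderedAndProject operator.
    testData.sort('key).select('value).limit(2).explain()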