author     Wenchen Fan <cloud0fan@outlook.com>        2015-06-24 13:28:50 -0700
committer  Michael Armbrust <michael@databricks.com>  2015-06-24 13:28:50 -0700
commit     f04b5672c5a5562f8494df3b0df23235285c9e9e (patch)
tree       efcff611b304a79e7c3e3f38f8455d8212f664fb
parent     b84d4b4dfe8ced1b96a0c74ef968a20a1bba8231 (diff)
[SPARK-7289] handle project -> limit -> sort efficiently
make the `TakeOrdered` strategy and operator more general, such that it can
optionally handle a projection when necessary

Author: Wenchen Fan <cloud0fan@outlook.com>

Closes #6780 from cloud-fan/limit and squashes the following commits:

34aa07b [Wenchen Fan] revert
07d5456 [Wenchen Fan] clean closure
20821ec [Wenchen Fan] fix
3676a82 [Wenchen Fan] address comments
b558549 [Wenchen Fan] address comments
214842b [Wenchen Fan] fix style
2d8be83 [Wenchen Fan] add LimitPushDown
948f740 [Wenchen Fan] fix existing
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala           52
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala   4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala                                  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala                         1
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala                   8
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala                   27
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala                      6
-rw-r--r--  sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala                            2
8 files changed, 62 insertions(+), 40 deletions(-)
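The query shape this patch targets is a projection sitting between a limit and a sort. A hedged DataFrame-API sketch (`df` is a placeholder for any DataFrame with `key` and `value` columns):

val top2Values = df.sort("key").select("value").limit(2)

Because the projection drops the sort column, ColumnPruning cannot push it below the Sort, so the optimized logical plan keeps the Limit -> Project -> Sort shape. Before this patch the planner only recognized Limit -> Sort and fell back to a global sort plus a separate limit; with it, the whole pattern is planned as a single top-K `TakeOrderedAndProject` operator.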
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 98b4476076..bfd24287c9 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -39,19 +39,22 @@ object DefaultOptimizer extends Optimizer {
Batch("Distinct", FixedPoint(100),
ReplaceDistinctWithAggregate) ::
Batch("Operator Optimizations", FixedPoint(100),
- UnionPushdown,
- CombineFilters,
+ // Operator push down
+ UnionPushDown,
+ PushPredicateThroughJoin,
PushPredicateThroughProject,
PushPredicateThroughGenerate,
ColumnPruning,
+ // Operator combine
ProjectCollapsing,
+ CombineFilters,
CombineLimits,
+ // Constant folding
NullPropagation,
OptimizeIn,
ConstantFolding,
LikeSimplification,
BooleanSimplification,
- PushPredicateThroughJoin,
RemovePositive,
SimplifyFilters,
SimplifyCasts,
@@ -63,25 +66,25 @@ object DefaultOptimizer extends Optimizer {
}
/**
- * Pushes operations to either side of a Union.
- */
-object UnionPushdown extends Rule[LogicalPlan] {
+ * Pushes operations to either side of a Union.
+ */
+object UnionPushDown extends Rule[LogicalPlan] {
/**
- * Maps Attributes from the left side to the corresponding Attribute on the right side.
- */
- def buildRewrites(union: Union): AttributeMap[Attribute] = {
+ * Maps Attributes from the left side to the corresponding Attribute on the right side.
+ */
+ private def buildRewrites(union: Union): AttributeMap[Attribute] = {
assert(union.left.output.size == union.right.output.size)
AttributeMap(union.left.output.zip(union.right.output))
}
/**
- * Rewrites an expression so that it can be pushed to the right side of a Union operator.
- * This method relies on the fact that the output attributes of a union are always equal
- * to the left child's output.
- */
- def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]): A = {
+ * Rewrites an expression so that it can be pushed to the right side of a Union operator.
+ * This method relies on the fact that the output attributes of a union are always equal
+ * to the left child's output.
+ */
+ private def pushToRight[A <: Expression](e: A, rewrites: AttributeMap[Attribute]) = {
val result = e transform {
case a: Attribute => rewrites(a)
}
@@ -108,7 +111,6 @@ object UnionPushdown extends Rule[LogicalPlan] {
}
}
-
/**
* Attempts to eliminate the reading of unneeded columns from the query plan using the following
* transformations:
@@ -117,7 +119,6 @@ object UnionPushdown extends Rule[LogicalPlan] {
* - Aggregate
* - Project <- Join
* - LeftSemiJoin
- * - Performing alias substitution.
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
@@ -159,10 +160,11 @@ object ColumnPruning extends Rule[LogicalPlan] {
Join(left, prunedChild(right, allReferences), LeftSemi, condition)
+ // Push down project through limit, so that we may have chance to push it further.
case Project(projectList, Limit(exp, child)) =>
Limit(exp, Project(projectList, child))
- // push down project if possible when the child is sort
+ // Push down project if possible when the child is sort
case p @ Project(projectList, s @ Sort(_, _, grandChild))
if s.references.subsetOf(p.outputSet) =>
s.copy(child = Project(projectList, grandChild))
@@ -181,8 +183,8 @@ object ColumnPruning extends Rule[LogicalPlan] {
}
/**
- * Combines two adjacent [[Project]] operators into one, merging the
- * expressions into one single expression.
+ * Combines two adjacent [[Project]] operators into one and perform alias substitution,
+ * merging the expressions into one single expression.
*/
object ProjectCollapsing extends Rule[LogicalPlan] {
@@ -222,10 +224,10 @@ object ProjectCollapsing extends Rule[LogicalPlan] {
object LikeSimplification extends Rule[LogicalPlan] {
// if guards below protect from escapes on trailing %.
// Cases like "something\%" are not optimized, but this does not affect correctness.
- val startsWith = "([^_%]+)%".r
- val endsWith = "%([^_%]+)".r
- val contains = "%([^_%]+)%".r
- val equalTo = "([^_%]*)".r
+ private val startsWith = "([^_%]+)%".r
+ private val endsWith = "%([^_%]+)".r
+ private val contains = "%([^_%]+)%".r
+ private val equalTo = "([^_%]*)".r
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Like(l, Literal(utf, StringType)) =>
@@ -497,7 +499,7 @@ object PushPredicateThroughProject extends Rule[LogicalPlan] {
grandChild))
}
- def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]): Expression = {
+ private def replaceAlias(condition: Expression, sourceAliases: Map[Attribute, Expression]) = {
condition transform {
case a: AttributeReference => sourceAliases.getOrElse(a, a)
}
@@ -682,7 +684,7 @@ object DecimalAggregates extends Rule[LogicalPlan] {
import Decimal.MAX_LONG_DIGITS
/** Maximum number of decimal digits representable precisely in a Double */
- val MAX_DOUBLE_DIGITS = 15
+ private val MAX_DOUBLE_DIGITS = 15
def apply(plan: LogicalPlan): LogicalPlan = plan transformAllExpressions {
case Sum(e @ DecimalType.Expression(prec, scale)) if prec + 10 <= MAX_LONG_DIGITS =>
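The push-down cases above all follow the standard `Rule[LogicalPlan]` pattern: a partial function over plan shapes applied via `transform`. A minimal, self-contained sketch of the Project-through-Limit swap shown in the ColumnPruning hunk (the object name here is hypothetical; in Spark this case lives inside ColumnPruning itself):

import org.apache.spark.sql.catalyst.plans.logical.{Limit, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.Rule

// Swapping Project and Limit is safe because Limit only truncates rows and
// never changes columns; moving the Project down may let later rules push it
// through the Sort, or let the planner match Limit -> Project -> Sort directly.
object PushProjectThroughLimit extends Rule[LogicalPlan] {
  def apply(plan: LogicalPlan): LogicalPlan = plan transform {
    case Project(projectList, Limit(exp, child)) =>
      Limit(exp, Project(projectList, child))
  }
}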
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
index 35f50be46b..ec379489a6 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/UnionPushdownSuite.scala
@@ -24,13 +24,13 @@ import org.apache.spark.sql.catalyst.rules._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.dsl.expressions._
-class UnionPushdownSuite extends PlanTest {
+class UnionPushDownSuite extends PlanTest {
object Optimize extends RuleExecutor[LogicalPlan] {
val batches =
Batch("Subqueries", Once,
EliminateSubQueries) ::
Batch("Union Pushdown", Once,
- UnionPushdown) :: Nil
+ UnionPushDown) :: Nil
}
val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
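For reference, a hedged usage sketch of what the renamed rule does to a filtered union, written against the Catalyst DSL imported in this suite (helpers such as `unionAll` and `analyze` are assumed from the DSL of this era and may differ across versions):

// Filtering a union should be rewritten into a union of filtered children,
// with the condition's attributes remapped for the right side via
// buildRewrites/pushToRight.
val query = testRelation.unionAll(testRelation).where('a === 1).analyze
val optimized = Optimize.execute(query)
// expected shape:
//   testRelation.where('a === 1).unionAll(testRelation.where('a === 1))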
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
index 04fc798bf3..5708df82de 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/SQLContext.scala
@@ -858,7 +858,7 @@ class SQLContext(@transient val sparkContext: SparkContext)
experimental.extraStrategies ++ (
DataSourceStrategy ::
DDLStrategy ::
- TakeOrdered ::
+ TakeOrderedAndProject ::
HashAggregation ::
LeftSemiJoin ::
HashJoin ::
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
index 2b8d302942..47f56b2b7e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala
@@ -169,7 +169,6 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ
log.debug(
s"Creating MutableProj: $expressions, inputSchema: $inputSchema, codegen:$codegenEnabled")
if(codegenEnabled && expressions.forall(_.isThreadSafe)) {
-
GenerateMutableProjection.generate(expressions, inputSchema)
} else {
() => new InterpretedMutableProjection(expressions, inputSchema)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
index 1ff1cc224d..21912cf249 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkStrategies.scala
@@ -213,10 +213,14 @@ private[sql] abstract class SparkStrategies extends QueryPlanner[SparkPlan] {
protected lazy val singleRowRdd =
sparkContext.parallelize(Seq(new GenericRow(Array[Any]()): InternalRow), 1)
- object TakeOrdered extends Strategy {
+ object TakeOrderedAndProject extends Strategy {
def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
case logical.Limit(IntegerLiteral(limit), logical.Sort(order, true, child)) =>
- execution.TakeOrdered(limit, order, planLater(child)) :: Nil
+ execution.TakeOrderedAndProject(limit, order, None, planLater(child)) :: Nil
+ case logical.Limit(
+ IntegerLiteral(limit),
+ logical.Project(projectList, logical.Sort(order, true, child))) =>
+ execution.TakeOrderedAndProject(limit, order, Some(projectList), planLater(child)) :: Nil
case _ => Nil
}
}
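The strategy now matches two logical shapes, `Limit(Sort(child))` and `Limit(Project(Sort(child)))`, and fuses each into one physical node. Its semantics in plain Scala collections (a self-contained illustration; `Rec` is a stand-in type, not Spark's Row):

case class Rec(key: Int, value: String)

val rows = Seq(Rec(3, "c"), Rec(1, "a"), Rec(2, "b"))
// Sort by key, keep the first 2 rows, then project the value column --
// exactly what one TakeOrderedAndProject node computes in a single pass.
val result = rows.sortBy(_.key).take(2).map(_.value)
// result: Seq("a", "b")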
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index 7aedd630e3..647c4ab5cb 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -39,8 +39,8 @@ case class Project(projectList: Seq[NamedExpression], child: SparkPlan) extends
@transient lazy val buildProjection = newMutableProjection(projectList, child.output)
protected override def doExecute(): RDD[InternalRow] = child.execute().mapPartitions { iter =>
- val resuableProjection = buildProjection()
- iter.map(resuableProjection)
+ val reusableProjection = buildProjection()
+ iter.map(reusableProjection)
}
override def outputOrdering: Seq[SortOrder] = child.outputOrdering
@@ -147,12 +147,18 @@ case class Limit(limit: Int, child: SparkPlan)
/**
* :: DeveloperApi ::
- * Take the first limit elements as defined by the sortOrder. This is logically equivalent to
- * having a [[Limit]] operator after a [[Sort]] operator. This could have been named TopK, but
- * Spark's top operator does the opposite in ordering so we name it TakeOrdered to avoid confusion.
+ * Take the first limit elements as defined by the sortOrder, and do projection if needed.
+ * This is logically equivalent to having a [[Limit]] operator after a [[Sort]] operator,
+ * or having a [[Project]] operator between them.
+ * This could have been named TopK, but Spark's top operator does the opposite in ordering
+ * so we name it TakeOrdered to avoid confusion.
*/
@DeveloperApi
-case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan) extends UnaryNode {
+case class TakeOrderedAndProject(
+ limit: Int,
+ sortOrder: Seq[SortOrder],
+ projectList: Option[Seq[NamedExpression]],
+ child: SparkPlan) extends UnaryNode {
override def output: Seq[Attribute] = child.output
@@ -160,8 +166,13 @@ case class TakeOrdered(limit: Int, sortOrder: Seq[SortOrder], child: SparkPlan)
private val ord: RowOrdering = new RowOrdering(sortOrder, child.output)
- private def collectData(): Array[InternalRow] =
- child.execute().map(_.copy()).takeOrdered(limit)(ord)
+ // TODO: remove @transient after figure out how to clean closure at InsertIntoHiveTable.
+ @transient private val projection = projectList.map(new InterpretedProjection(_, child.output))
+
+ private def collectData(): Array[InternalRow] = {
+ val data = child.execute().map(_.copy()).takeOrdered(limit)(ord)
+ projection.map(data.map(_)).getOrElse(data)
+ }
override def executeCollect(): Array[Row] = {
val converter = CatalystTypeConverters.createToScalaConverter(schema)
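`collectData` relies on `RDD.takeOrdered`, which computes a top-K per partition and merges those on the driver rather than sorting the whole dataset; the optional projection then runs only over the K collected rows. A hedged RDD-level sketch of the same logic (assuming an existing `SparkContext` named `sc`):

// takeOrdered returns the 2 smallest pairs under the given Ordering without
// a global sort; projecting to the second column afterwards mirrors
// projection.map(data.map(_)).getOrElse(data).
val data: Array[(Int, String)] =
  sc.parallelize(Seq(3 -> "c", 1 -> "a", 2 -> "b"))
    .takeOrdered(2)(Ordering.by[(Int, String), Int](_._1))
val projected: Array[String] = data.map(_._2)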
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
index 5854ab48db..3dd24130af 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/PlannerSuite.scala
@@ -141,4 +141,10 @@ class PlannerSuite extends SparkFunSuite {
setConf(SQLConf.AUTO_BROADCASTJOIN_THRESHOLD, origThreshold)
}
+
+ test("efficient limit -> project -> sort") {
+ val query = testData.sort('key).select('value).limit(2).logicalPlan
+ val planned = planner.TakeOrderedAndProject(query)
+ assert(planned.head.isInstanceOf[execution.TakeOrderedAndProject])
+ }
}
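Beyond the planner unit test, the effect is visible in a physical plan dump. A hedged sketch, assuming a `SQLContext` with `testData` registered as a temporary table:

// The printed plan should show a single TakeOrderedAndProject node rather
// than separate Project, Limit and Sort operators.
sqlContext.sql("SELECT value FROM testData ORDER BY key LIMIT 2").explain()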
diff --git a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
index cf05c6c989..8021f915bb 100644
--- a/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
+++ b/sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveContext.scala
@@ -442,7 +442,7 @@ class HiveContext(sc: SparkContext) extends SQLContext(sc) {
HiveCommandStrategy(self),
HiveDDLStrategy,
DDLStrategy,
- TakeOrdered,
+ TakeOrderedAndProject,
ParquetOperations,
InMemoryScans,
ParquetConversion, // Must be before HiveTableScans