author     Davies Liu <davies@databricks.com>  2016-02-25 00:13:07 -0800
committer  Davies Liu <davies.liu@gmail.com>   2016-02-25 00:13:07 -0800
commit     07f92ef1fa090821bef9c60689bf41909d781ee7 (patch)
tree       10d8f563352dffea1a2b3363bb2659174187562a /sql
parent     264533b553be806b6c45457201952e83c028ec78 (diff)
[SPARK-13376] [SPARK-13476] [SQL] improve column pruning
## What changes were proposed in this pull request?

This PR mostly rewrites the ColumnPruning rule to support most of the SQL logical plans (except those for Dataset). It also fixes a bug in Generate: it should always output UnsafeRow, and a regression test is added for that.

## How was this patch tested?

Tested by unit tests, and also manually with TPCDS Q78, which could prune all unused columns successfully, improving the performance by 78% (from 22s to 12s).

Author: Davies Liu <davies@databricks.com>

Closes #11354 from davies/fix_column_pruning.
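To make the effect of the rewritten rule concrete, here is a minimal sketch in the style of the new ColumnPruningSuite tests below. It assumes a Catalyst test classpath; the `Optimize` batch is illustrative (its rules match those exercised by the suites in this patch), not the production Optimizer:

```scala
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
import org.apache.spark.sql.catalyst.optimizer.{CollapseProject, ColumnPruning}
import org.apache.spark.sql.catalyst.plans.logical.{LocalRelation, LogicalPlan, Project}
import org.apache.spark.sql.catalyst.rules.RuleExecutor

// Illustrative batch: ColumnPruning narrows project lists and child outputs,
// CollapseProject merges the adjacent Projects left behind.
object Optimize extends RuleExecutor[LogicalPlan] {
  val batches = Batch("Column pruning", FixedPoint(100),
    ColumnPruning,
    CollapseProject) :: Nil
}

// Only 'a is needed upstream, but the inner Project also carries 'b.
val input = LocalRelation('a.int, 'b.string, 'c.double)
val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze

// The unused column 'b is pruned; the optimized plan is equivalent to
// Project(Seq('a), input).analyze.
val optimized = Optimize.execute(query)
```

The same pattern generalizes to Aggregate, Window, Expand, Generate, Union and joins, as covered by the new cases in the rule and the tests added in ColumnPruningSuite.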
Diffstat (limited to 'sql')
-rw-r--r--  sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala  128
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala  128
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala  80
-rw-r--r--  sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala  2
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala  28
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala  7
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala  8
7 files changed, 215 insertions, 166 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
index 1f05f2065c..2b804976f3 100644
--- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
+++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/Optimizer.scala
@@ -313,97 +313,85 @@ object SetOperationPushDown extends Rule[LogicalPlan] with PredicateHelper {
*/
object ColumnPruning extends Rule[LogicalPlan] {
def apply(plan: LogicalPlan): LogicalPlan = plan transform {
- case a @ Aggregate(_, _, e @ Expand(projects, output, child))
- if (e.outputSet -- a.references).nonEmpty =>
- val newOutput = output.filter(a.references.contains(_))
- val newProjects = projects.map { proj =>
- proj.zip(output).filter { case (e, a) =>
+ // Prunes the unused columns from project list of Project/Aggregate/Window/Expand
+ case p @ Project(_, p2: Project) if (p2.outputSet -- p.references).nonEmpty =>
+ p.copy(child = p2.copy(projectList = p2.projectList.filter(p.references.contains)))
+ case p @ Project(_, a: Aggregate) if (a.outputSet -- p.references).nonEmpty =>
+ p.copy(
+ child = a.copy(aggregateExpressions = a.aggregateExpressions.filter(p.references.contains)))
+ case p @ Project(_, w: Window) if (w.outputSet -- p.references).nonEmpty =>
+ p.copy(child = w.copy(
+ projectList = w.projectList.filter(p.references.contains),
+ windowExpressions = w.windowExpressions.filter(p.references.contains)))
+ case a @ Project(_, e @ Expand(_, _, grandChild)) if (e.outputSet -- a.references).nonEmpty =>
+ val newOutput = e.output.filter(a.references.contains(_))
+ val newProjects = e.projections.map { proj =>
+ proj.zip(e.output).filter { case (e, a) =>
newOutput.contains(a)
}.unzip._1
}
- a.copy(child = Expand(newProjects, newOutput, child))
+ a.copy(child = Expand(newProjects, newOutput, grandChild))
+ // TODO: support some logical plans for Dataset
- case a @ Aggregate(_, _, e @ Expand(_, _, child))
- if (child.outputSet -- e.references -- a.references).nonEmpty =>
- a.copy(child = e.copy(child = prunedChild(child, e.references ++ a.references)))
-
- // Eliminate attributes that are not needed to calculate the specified aggregates.
+ // Prunes the unused columns from child of Aggregate/Window/Expand/Generate
case a @ Aggregate(_, _, child) if (child.outputSet -- a.references).nonEmpty =>
- a.copy(child = Project(a.references.toSeq, child))
-
- // Eliminate attributes that are not needed to calculate the Generate.
+ a.copy(child = prunedChild(child, a.references))
+ case w @ Window(_, _, _, _, child) if (child.outputSet -- w.references).nonEmpty =>
+ w.copy(child = prunedChild(child, w.references))
+ case e @ Expand(_, _, child) if (child.outputSet -- e.references).nonEmpty =>
+ e.copy(child = prunedChild(child, e.references))
case g: Generate if !g.join && (g.child.outputSet -- g.references).nonEmpty =>
- g.copy(child = Project(g.references.toSeq, g.child))
+ g.copy(child = prunedChild(g.child, g.references))
+ // Turn off `join` for Generate if no column from its child is used
case p @ Project(_, g: Generate) if g.join && p.references.subsetOf(g.generatedSet) =>
p.copy(child = g.copy(join = false))
- case p @ Project(projectList, g: Generate) if g.join =>
- val neededChildOutput = p.references -- g.generatorOutput ++ g.references
- if (neededChildOutput == g.child.outputSet) {
- p
+ // Eliminate unneeded attributes from right side of a LeftSemiJoin.
+ case j @ Join(left, right, LeftSemi, condition) =>
+ j.copy(right = prunedChild(right, j.references))
+
+ // all the columns will be used to compare, so we can't prune them
+ case p @ Project(_, _: SetOperation) => p
+ case p @ Project(_, _: Distinct) => p
+ // Eliminate unneeded attributes from children of Union.
+ case p @ Project(_, u: Union) =>
+ if ((u.outputSet -- p.references).nonEmpty) {
+ val firstChild = u.children.head
+ val newOutput = prunedChild(firstChild, p.references).output
+ // pruning the columns of all children based on the pruned first child.
+ val newChildren = u.children.map { p =>
+ val selected = p.output.zipWithIndex.filter { case (a, i) =>
+ newOutput.contains(firstChild.output(i))
+ }.map(_._1)
+ Project(selected, p)
+ }
+ p.copy(child = u.withNewChildren(newChildren))
} else {
- Project(projectList, g.copy(child = Project(neededChildOutput.toSeq, g.child)))
+ p
}
- case p @ Project(projectList, a @ Aggregate(groupingExpressions, aggregateExpressions, child))
- if (a.outputSet -- p.references).nonEmpty =>
- Project(
- projectList,
- Aggregate(
- groupingExpressions,
- aggregateExpressions.filter(e => p.references.contains(e)),
- child))
-
- // Eliminate unneeded attributes from either side of a Join.
- case Project(projectList, Join(left, right, joinType, condition)) =>
- // Collect the list of all references required either above or to evaluate the condition.
- val allReferences: AttributeSet =
- AttributeSet(
- projectList.flatMap(_.references.iterator)) ++
- condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-
- /** Applies a projection only when the child is producing unnecessary attributes */
- def pruneJoinChild(c: LogicalPlan): LogicalPlan = prunedChild(c, allReferences)
+ // Can't prune the columns on LeafNode
+ case p @ Project(_, l: LeafNode) => p
- Project(projectList, Join(pruneJoinChild(left), pruneJoinChild(right), joinType, condition))
-
- // Eliminate unneeded attributes from right side of a LeftSemiJoin.
- case Join(left, right, LeftSemi, condition) =>
- // Collect the list of all references required to evaluate the condition.
- val allReferences: AttributeSet =
- condition.map(_.references).getOrElse(AttributeSet(Seq.empty))
-
- Join(left, prunedChild(right, allReferences), LeftSemi, condition)
-
- // Push down project through limit, so that we may have chance to push it further.
- case Project(projectList, Limit(exp, child)) =>
- Limit(exp, Project(projectList, child))
-
- // Push down project if possible when the child is sort.
- case p @ Project(projectList, s @ Sort(_, _, grandChild)) =>
- if (s.references.subsetOf(p.outputSet)) {
- s.copy(child = Project(projectList, grandChild))
+ // Eliminate no-op Projects
+ case p @ Project(projectList, child) if child.output == p.output => child
+
+ // for all other logical plans that inherit the output from their children
+ case p @ Project(_, child) =>
+ val required = child.references ++ p.references
+ if ((child.inputSet -- required).nonEmpty) {
+ val newChildren = child.children.map(c => prunedChild(c, required))
+ p.copy(child = child.withNewChildren(newChildren))
} else {
- val neededReferences = s.references ++ p.references
- if (neededReferences == grandChild.outputSet) {
- // No column we can prune, return the original plan.
- p
- } else {
- // Do not use neededReferences.toSeq directly, should respect grandChild's output order.
- val newProjectList = grandChild.output.filter(neededReferences.contains)
- p.copy(child = s.copy(child = Project(newProjectList, grandChild)))
- }
+ p
}
-
- // Eliminate no-op Projects
- case Project(projectList, child) if child.output == projectList => child
}
/** Applies a projection only when the child is producing unnecessary attributes */
private def prunedChild(c: LogicalPlan, allReferences: AttributeSet) =
if ((c.outputSet -- allReferences.filter(c.outputSet.contains)).nonEmpty) {
- Project(allReferences.filter(c.outputSet.contains).toSeq, c)
+ Project(c.output.filter(allReferences.contains), c)
} else {
c
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
index c890fffc40..715d01a3cd 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/ColumnPruningSuite.scala
@@ -17,9 +17,10 @@
package org.apache.spark.sql.catalyst.optimizer
+import org.apache.spark.sql.catalyst.analysis
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.dsl.plans._
-import org.apache.spark.sql.catalyst.expressions.{Explode, Literal}
+import org.apache.spark.sql.catalyst.expressions.{Ascending, Explode, Literal, SortOrder}
import org.apache.spark.sql.catalyst.plans.PlanTest
import org.apache.spark.sql.catalyst.plans.logical._
import org.apache.spark.sql.catalyst.rules.RuleExecutor
@@ -119,11 +120,134 @@ class ColumnPruningSuite extends PlanTest {
Seq('c, Literal.create(null, StringType), 1),
Seq('c, 'a, 2)),
Seq('c, 'aa.int, 'gid.int),
- Project(Seq('c, 'a),
+ Project(Seq('a, 'c),
input))).analyze
comparePlans(optimized, expected)
}
+ test("Column pruning on Filter") {
+ val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val query = Project('a :: Nil, Filter('c > Literal(0.0), input)).analyze
+ val expected =
+ Project('a :: Nil,
+ Filter('c > Literal(0.0),
+ Project(Seq('a, 'c), input))).analyze
+ comparePlans(Optimize.execute(query), expected)
+ }
+
+ test("Column pruning on except/intersect/distinct") {
+ val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val query = Project('a :: Nil, Except(input, input)).analyze
+ comparePlans(Optimize.execute(query), query)
+
+ val query2 = Project('a :: Nil, Intersect(input, input)).analyze
+ comparePlans(Optimize.execute(query2), query2)
+ val query3 = Project('a :: Nil, Distinct(input)).analyze
+ comparePlans(Optimize.execute(query3), query3)
+ }
+
+ test("Column pruning on Project") {
+ val input = LocalRelation('a.int, 'b.string, 'c.double)
+ val query = Project('a :: Nil, Project(Seq('a, 'b), input)).analyze
+ val expected = Project(Seq('a), input).analyze
+ comparePlans(Optimize.execute(query), expected)
+ }
+
+ test("column pruning for group") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ val originalQuery =
+ testRelation
+ .groupBy('a)('a, count('b))
+ .select('a)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .groupBy('a)('a).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("column pruning for group with alias") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ val originalQuery =
+ testRelation
+ .groupBy('a)('a as 'c, count('b))
+ .select('c)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .groupBy('a)('a as 'c).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("column pruning for Project(ne, Limit)") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+
+ val originalQuery =
+ testRelation
+ .select('a, 'b)
+ .limit(2)
+ .select('a)
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ testRelation
+ .select('a)
+ .limit(2).analyze
+
+ comparePlans(optimized, correctAnswer)
+ }
+
+ test("push down project past sort") {
+ val testRelation = LocalRelation('a.int, 'b.int, 'c.int)
+ val x = testRelation.subquery('x)
+
+ // push down valid
+ val originalQuery = {
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('a)
+ }
+
+ val optimized = Optimize.execute(originalQuery.analyze)
+ val correctAnswer =
+ x.select('a)
+ .sortBy(SortOrder('a, Ascending)).analyze
+
+ comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
+
+ // push down invalid
+ val originalQuery1 = {
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('b)
+ }
+
+ val optimized1 = Optimize.execute(originalQuery1.analyze)
+ val correctAnswer1 =
+ x.select('a, 'b)
+ .sortBy(SortOrder('a, Ascending))
+ .select('b).analyze
+
+ comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
+ }
+
+ test("Column pruning on Union") {
+ val input1 = LocalRelation('a.int, 'b.string, 'c.double)
+ val input2 = LocalRelation('c.int, 'd.string, 'e.double)
+ val query = Project('b :: Nil,
+ Union(input1 :: input2 :: Nil)).analyze
+ val expected = Project('b :: Nil,
+ Union(Project('b :: Nil, input1) :: Project('d :: Nil, input2) :: Nil)).analyze
+ comparePlans(Optimize.execute(query), expected)
+ }
+
// todo: add more tests for column pruning
}
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
index 70b34cbb24..7d60862f5a 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/FilterPushdownSuite.scala
@@ -41,7 +41,6 @@ class FilterPushdownSuite extends PlanTest {
PushPredicateThroughJoin,
PushPredicateThroughGenerate,
PushPredicateThroughAggregate,
- ColumnPruning,
CollapseProject) :: Nil
}
@@ -65,52 +64,6 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, correctAnswer)
}
- test("column pruning for group") {
- val originalQuery =
- testRelation
- .groupBy('a)('a, count('b))
- .select('a)
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- testRelation
- .select('a)
- .groupBy('a)('a).analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
- test("column pruning for group with alias") {
- val originalQuery =
- testRelation
- .groupBy('a)('a as 'c, count('b))
- .select('c)
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- testRelation
- .select('a)
- .groupBy('a)('a as 'c).analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
- test("column pruning for Project(ne, Limit)") {
- val originalQuery =
- testRelation
- .select('a, 'b)
- .limit(2)
- .select('a)
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- testRelation
- .select('a)
- .limit(2).analyze
-
- comparePlans(optimized, correctAnswer)
- }
-
// After this line is unimplemented.
test("simple push down") {
val originalQuery =
@@ -604,39 +557,6 @@ class FilterPushdownSuite extends PlanTest {
comparePlans(optimized, originalQuery)
}
- test("push down project past sort") {
- val x = testRelation.subquery('x)
-
- // push down valid
- val originalQuery = {
- x.select('a, 'b)
- .sortBy(SortOrder('a, Ascending))
- .select('a)
- }
-
- val optimized = Optimize.execute(originalQuery.analyze)
- val correctAnswer =
- x.select('a)
- .sortBy(SortOrder('a, Ascending)).analyze
-
- comparePlans(optimized, analysis.EliminateSubqueryAliases(correctAnswer))
-
- // push down invalid
- val originalQuery1 = {
- x.select('a, 'b)
- .sortBy(SortOrder('a, Ascending))
- .select('b)
- }
-
- val optimized1 = Optimize.execute(originalQuery1.analyze)
- val correctAnswer1 =
- x.select('a, 'b)
- .sortBy(SortOrder('a, Ascending))
- .select('b).analyze
-
- comparePlans(optimized1, analysis.EliminateSubqueryAliases(correctAnswer1))
- }
-
test("push project and filter down into sample") {
val x = testRelation.subquery('x)
val originalQuery =
diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
index 1ab53a1257..2f382bbda0 100644
--- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
+++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/JoinOptimizationSuite.scala
@@ -108,7 +108,7 @@ class JoinOptimizationSuite extends PlanTest {
Project(Seq($"x.key", $"y.key"),
Join(
Project(Seq($"x.key"), SubqueryAlias("x", input)),
- Project(Seq($"y.key"), BroadcastHint(SubqueryAlias("y", input))),
+ BroadcastHint(Project(Seq($"y.key"), SubqueryAlias("y", input))),
Inner, None)).analyze
comparePlans(optimized, expected)
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
index 4db88a09d8..6bc4649d43 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Generate.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.execution.metric.SQLMetrics
/**
* For lazy computing, be sure the generator.terminate() called in the very last
@@ -54,17 +55,19 @@ case class Generate(
child: SparkPlan)
extends UnaryNode {
+ private[sql] override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
override def expressions: Seq[Expression] = generator :: Nil
val boundGenerator = BindReferences.bindReference(generator, child.output)
protected override def doExecute(): RDD[InternalRow] = {
// boundGenerator.terminate() should be triggered after all of the rows in the partition
- if (join) {
+ val rows = if (join) {
child.execute().mapPartitionsInternal { iter =>
- val generatorNullRow = InternalRow.fromSeq(Seq.fill[Any](generator.elementTypes.size)(null))
+ val generatorNullRow = new GenericInternalRow(generator.elementTypes.size)
val joinedRow = new JoinedRow
- val proj = UnsafeProjection.create(output, output)
iter.flatMap { row =>
// we should always set the left (child output)
@@ -73,19 +76,26 @@ case class Generate(
if (outer && outputRows.isEmpty) {
joinedRow.withRight(generatorNullRow) :: Nil
} else {
- outputRows.map(or => joinedRow.withRight(or))
+ outputRows.map(joinedRow.withRight)
}
- } ++ LazyIterator(() => boundGenerator.terminate()).map { row =>
+ } ++ LazyIterator(boundGenerator.terminate).map { row =>
// we leave the left side as the last element of its child output
// keep it the same as Hive does
- proj(joinedRow.withRight(row))
+ joinedRow.withRight(row)
}
}
} else {
child.execute().mapPartitionsInternal { iter =>
- val proj = UnsafeProjection.create(output, output)
- (iter.flatMap(row => boundGenerator.eval(row)) ++
- LazyIterator(() => boundGenerator.terminate())).map(proj)
+ iter.flatMap(boundGenerator.eval) ++ LazyIterator(boundGenerator.terminate)
+ }
+ }
+
+ val numOutputRows = longMetric("numOutputRows")
+ rows.mapPartitionsInternal { iter =>
+ val proj = UnsafeProjection.create(output, output)
+ iter.map { r =>
+ numOutputRows += 1
+ proj(r)
}
}
}
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
index 4858140229..22d4278085 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/columnar/InMemoryColumnarTableScan.scala
@@ -26,7 +26,8 @@ import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.MultiInstanceRelation
import org.apache.spark.sql.catalyst.dsl.expressions._
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Statistics}
+import org.apache.spark.sql.catalyst.plans.logical
+import org.apache.spark.sql.catalyst.plans.logical.Statistics
import org.apache.spark.sql.catalyst.plans.physical.Partitioning
import org.apache.spark.sql.execution.{LeafNode, SparkPlan}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -63,7 +64,7 @@ private[sql] case class InMemoryRelation(
@transient private[sql] var _cachedColumnBuffers: RDD[CachedBatch] = null,
@transient private[sql] var _statistics: Statistics = null,
private[sql] var _batchStats: Accumulable[ArrayBuffer[InternalRow], InternalRow] = null)
- extends LogicalPlan with MultiInstanceRelation {
+ extends logical.LeafNode with MultiInstanceRelation {
override def producedAttributes: AttributeSet = outputSet
@@ -184,8 +185,6 @@ private[sql] case class InMemoryRelation(
_cachedColumnBuffers, statisticsToBePropagated, batchStats)
}
- override def children: Seq[LogicalPlan] = Seq.empty
-
override def newInstance(): this.type = {
new InMemoryRelation(
output.map(_.newInstance()),
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
index 4930c485da..b8d1b5a6ae 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/DataFrameSuite.scala
@@ -194,6 +194,14 @@ class DataFrameSuite extends QueryTest with SharedSQLContext {
Row("a", Seq("a"), 1) :: Nil)
}
+ test("sort after generate with join=true") {
+ val df = Seq((Array("a"), 1)).toDF("a", "b")
+
+ checkAnswer(
+ df.select($"*", explode($"a").as("c")).sortWithinPartitions("b", "c"),
+ Row(Seq("a"), 1, "a") :: Nil)
+ }
+
test("selectExpr") {
checkAnswer(
testData.selectExpr("abs(key)", "value"),