5 files changed, 122 insertions, 35 deletions
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 27ae62f121..0ad0f4976c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -36,7 +36,7 @@ import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
 import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
 
-final class UnsafeExternalRowSorter {
+public final class UnsafeExternalRowSorter {
 
   /**
    * If positive, forces records to be spilled to disk at the given frequency (measured in numbers
@@ -84,8 +84,7 @@ final class UnsafeExternalRowSorter {
     testSpillFrequency = frequency;
   }
 
-  @VisibleForTesting
-  void insertRow(UnsafeRow row) throws IOException {
+  public void insertRow(UnsafeRow row) throws IOException {
     final long prefix = prefixComputer.computePrefix(row);
     sorter.insertRecord(
       row.getBaseObject(),
@@ -110,8 +109,7 @@ final class UnsafeExternalRowSorter {
     sorter.cleanupResources();
   }
 
-  @VisibleForTesting
-  Iterator<UnsafeRow> sort() throws IOException {
+  public Iterator<UnsafeRow> sort() throws IOException {
     try {
       final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
       if (!sortedIterator.hasNext()) {
@@ -160,7 +158,6 @@ final class UnsafeExternalRowSorter {
     }
   }
 
-
   public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
     while (inputIterator.hasNext()) {
       insertRow(inputIterator.next());
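
Note: the class and its insertRow()/sort() members become public so that generated Java code, which is compiled outside this class's package, can drive the sorter directly. A minimal Scala sketch of the call pattern, assuming `sorter` was built by Sort.createSorter() (added in the next file) and `inputRows` is some Iterator[UnsafeRow]:

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.UnsafeExternalRowSorter

def sortRows(sorter: UnsafeExternalRowSorter,
             inputRows: Iterator[UnsafeRow]): Iterator[UnsafeRow] = {
  // Two-step protocol that generated code follows; the one-shot helper
  // sorter.sort(inputRows) shown above wraps exactly this sequence.
  inputRows.foreach(sorter.insertRow)  // buffer (and possibly spill) each row
  sorter.sort()                        // then obtain the sorted iterator
}
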
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 75cb6d1137..2ea889ea72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -17,10 +17,12 @@ package org.apache.spark.sql.execution
 
-import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.executor.TaskMetrics
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
 import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -37,7 +39,7 @@ case class Sort(
     global: Boolean,
     child: SparkPlan,
     testSpillFrequency: Int = 0)
-  extends UnaryNode {
+  extends UnaryNode with CodegenSupport {
 
   override def output: Seq[Attribute] = child.output
@@ -50,34 +52,36 @@ case class Sort(
     "dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
     "spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
 
-  protected override def doExecute(): RDD[InternalRow] = {
-    val schema = child.schema
-    val childOutput = child.output
+  def createSorter(): UnsafeExternalRowSorter = {
+    val ordering = newOrdering(sortOrder, output)
+
+    // The comparator for comparing prefix
+    val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+    val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+    // The generator for prefix
+    val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+    val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+      override def computePrefix(row: InternalRow): Long = {
+        prefixProjection.apply(row).getLong(0)
+      }
+    }
+    val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+    val sorter = new UnsafeExternalRowSorter(
+      schema, ordering, prefixComparator, prefixComputer, pageSize)
+    if (testSpillFrequency > 0) {
+      sorter.setTestSpillFrequency(testSpillFrequency)
+    }
+    sorter
+  }
+
+  protected override def doExecute(): RDD[InternalRow] = {
     val dataSize = longMetric("dataSize")
     val spillSize = longMetric("spillSize")
 
     child.execute().mapPartitionsInternal { iter =>
-      val ordering = newOrdering(sortOrder, childOutput)
-
-      // The comparator for comparing prefix
-      val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
-      val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
-
-      // The generator for prefix
-      val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
-      val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
-        override def computePrefix(row: InternalRow): Long = {
-          prefixProjection.apply(row).getLong(0)
-        }
-      }
-
-      val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
-      val sorter = new UnsafeExternalRowSorter(
-        schema, ordering, prefixComparator, prefixComputer, pageSize)
-      if (testSpillFrequency > 0) {
-        sorter.setTestSpillFrequency(testSpillFrequency)
-      }
+      val sorter = createSorter()
 
       val metrics = TaskContext.get().taskMetrics()
       // Remember spill data size of this task before execute this operator so that we can
@@ -93,4 +97,74 @@ case class Sort(
       sortedIterator
     }
   }
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  // Name of sorter variable used in codegen.
+  private var sorterVariable: String = _
+
+  override protected def doProduce(ctx: CodegenContext): String = {
+    val needToSort = ctx.freshName("needToSort")
+    ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
+
+    // Initialize the class member variables. This includes the instance of the Sorter and
+    // the iterator to return sorted rows.
+    val thisPlan = ctx.addReferenceObj("plan", this)
+    sorterVariable = ctx.freshName("sorter")
+    ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
+      s"$sorterVariable = $thisPlan.createSorter();")
+    val metrics = ctx.freshName("metrics")
+    ctx.addMutableState(classOf[TaskMetrics].getName, metrics,
+      s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();")
+    val sortedIterator = ctx.freshName("sortedIter")
+    ctx.addMutableState("scala.collection.Iterator<UnsafeRow>", sortedIterator, "")
+
+    val addToSorter = ctx.freshName("addToSorter")
+    ctx.addNewFunction(addToSorter,
+      s"""
+        | private void $addToSorter() throws java.io.IOException {
+        |   ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+        | }
+      """.stripMargin.trim)
+
+    val outputRow = ctx.freshName("outputRow")
+    val dataSize = metricTerm(ctx, "dataSize")
+    val spillSize = metricTerm(ctx, "spillSize")
+    val spillSizeBefore = ctx.freshName("spillSizeBefore")
+    s"""
+       | if ($needToSort) {
+       |   $addToSorter();
+       |   Long $spillSizeBefore = $metrics.memoryBytesSpilled();
+       |   $sortedIterator = $sorterVariable.sort();
+       |   $dataSize.add($sorterVariable.getPeakMemoryUsage());
+       |   $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
+       |   $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
+       |   $needToSort = false;
+       | }
+       |
+       | while ($sortedIterator.hasNext()) {
+       |   UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
+       |   ${consume(ctx, null, outputRow)}
+       |   if (shouldStop()) return;
+       | }
+     """.stripMargin.trim
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+    val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+      BoundReference(i, attr.dataType, attr.nullable)
+    }
+
+    ctx.currentVars = input
+    val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+
+    s"""
+       | // Convert the input attributes to an UnsafeRow and add it to the sorter
+       | ${code.code}
+       | $sorterVariable.insertRow(${code.value});
+     """.stripMargin.trim
+  }
 }
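
Taken together, doProduce and doConsume make Sort a blocking operator inside the generated stage: every child row is pushed into the sorter before the first sorted row is emitted. An illustrative hand-written Scala equivalent of that control flow (the real stage is generated Java; the names mirror the fresh variables above):

import org.apache.spark.sql.catalyst.expressions.UnsafeRow
import org.apache.spark.sql.execution.Sort

// Sketch only: roughly what the generated stage does at runtime.
class SortStageSketch(plan: Sort, childRows: Iterator[UnsafeRow]) {
  private var needToSort = true              // addMutableState("boolean", needToSort, ...)
  private val sorter = plan.createSorter()   // $sorterVariable = $thisPlan.createSorter();
  private var sortedIter: Iterator[UnsafeRow] = _

  // addToSorter(): the child's produce() loop lands here; for each row the child
  // emits, Sort.doConsume projects it to an UnsafeRow and inserts it.
  private def addToSorter(): Unit = childRows.foreach(sorter.insertRow)

  def rows(): Iterator[UnsafeRow] = {
    if (needToSort) {            // sort exactly once, after the child is exhausted
      addToSorter()
      sortedIter = sorter.sort()
      needToSort = false
    }
    sortedIter                   // doProduce's while loop drains this iterator
  }
}
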
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index afaddcf357..cb68ca6ada 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -287,7 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
           ${code.trim}
         }
       }
-      """
+      """.trim
 
     // try to compile, helpful for debug
     val cleanedSource = CodeFormatter.stripExtraNewLines(source)
@@ -338,7 +338,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
       // There is an UnsafeRow already
       s"""
          |append($row.copy());
-       """.stripMargin
+       """.stripMargin.trim
     } else {
       assert(input != null)
       if (input.nonEmpty) {
@@ -351,12 +351,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
         s"""
            |${code.code.trim}
            |append(${code.value}.copy());
-         """.stripMargin
+         """.stripMargin.trim
       } else {
         // There is no columns
         s"""
            |append(unsafeRow);
-         """.stripMargin
+         """.stripMargin.trim
       }
     }
   }
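
The added .trim calls only tidy the generated source: stripMargin strips each line up to its margin pipe but keeps the newlines and indentation that surround the interpolated block, and those blanks accumulate when snippets are spliced into one another. A minimal standalone illustration:

val snippet =
  """
    |append(row.copy());
  """.stripMargin
// snippet == "\nappend(row.copy());\n  "  (the surrounding whitespace survives)
println(snippet.trim)  // prints exactly: append(row.copy());
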
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 9350205d79..de371d85d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -69,4 +69,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
       p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
     assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
   }
+
+  test("Sort should be included in WholeStageCodegen") {
+    val df = sqlContext.range(3, 0, -1).sort(col("id"))
+    val plan = df.queryExecution.executedPlan
+    assert(plan.find(p =>
+      p.isInstanceOf[WholeStageCodegen] &&
+        p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
+    assert(df.collect() === Array(Row(1), Row(2), Row(3)))
+  }
 }
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index c49f2439fc..5b4f6f1d24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -154,6 +154,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
     )
   }
 
+  test("Sort metrics") {
+    // Assume the execution plan is
+    // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1))
+    val df = sqlContext.range(10).sort('id)
+    testSparkPlanMetrics(df, 2, Map.empty)
+  }
+
   test("SortMergeJoin metrics") {
     // Because SortMergeJoin may skip different rows if the number of partitions is different, this
     // test should use the deterministic number of partitions.
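
The Sort metrics test passes Map.empty because dataSize and spillSize depend on memory layout and spill timing, so only the job count and plan shape are checked. For a metric with a deterministic value, the expectation would be keyed by node id following this suite's existing convention; a purely hypothetical shape, reusing the df above (Sort exposes no such row-count metric in this patch):

// Hypothetical expectation per testSparkPlanMetrics' convention:
// nodeId -> (node name, expected metric name/value pairs)
testSparkPlanMetrics(df, 2, Map(
  1L -> ("Sort", Map("hypothetical row count" -> 10L))))
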