author    Sameer Agarwal <sameer@databricks.com>  2016-02-29 12:59:46 -0800
committer Yin Huai <yhuai@databricks.com>         2016-02-29 12:59:46 -0800
commit    4bd697da03079c26fd4409dc128dbff28c737701 (patch)
tree      25ed22ebe5f0d54fd29fa5644a53ddc6fa203a84
parent    644dbb641afd337ca39733da5153239cf39cdd81 (diff)
[SPARK-13123][SQL] Implement whole stage codegen for sort
## What changes were proposed in this pull request?

This PR adds support for whole stage codegen for sort. It builds heavily on nongli's PR https://github.com/apache/spark/pull/11008 (which implements the feature itself) and adds the following changes on top:

- [x] Generated code updates peak execution memory metrics
- [x] Unit tests in `WholeStageCodegenSuite` and `SQLMetricsSuite`

## How was this patch tested?

New unit tests in `WholeStageCodegenSuite` and `SQLMetricsSuite`. Further, all existing sort tests should pass.

Author: Sameer Agarwal <sameer@databricks.com>
Author: Nong Li <nong@databricks.com>

Closes #11359 from sameeragarwal/sort-codegen.
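For context, a minimal sketch of how the new code path can be exercised from user code, mirroring the `WholeStageCodegenSuite` test added below. It assumes a running `SQLContext` named `sqlContext`; with this patch, `Sort` shows up inside the `WholeStageCodegen` operator in the physical plan.

```scala
// Hypothetical usage sketch (assumes an existing sqlContext, as in the test suites).
import org.apache.spark.sql.functions.col

val df = sqlContext.range(3, 0, -1).sort(col("id"))

// Prints the physical plan; after this patch, Sort appears wrapped in a
// WholeStageCodegen node rather than as a standalone operator.
df.explain()

// Collect still returns rows in sorted order.
assert(df.collect().map(_.getLong(0)).toSeq == Seq(1L, 2L, 3L))
```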
-rw-r--r--  sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java |   9
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala                      | 124
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala         |   8
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala    |   9
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala    |   7
5 files changed, 122 insertions(+), 35 deletions(-)
diff --git a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
index 27ae62f121..0ad0f4976c 100644
--- a/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
+++ b/sql/catalyst/src/main/java/org/apache/spark/sql/execution/UnsafeExternalRowSorter.java
@@ -36,7 +36,7 @@ import org.apache.spark.util.collection.unsafe.sort.RecordComparator;
import org.apache.spark.util.collection.unsafe.sort.UnsafeExternalSorter;
import org.apache.spark.util.collection.unsafe.sort.UnsafeSorterIterator;
-final class UnsafeExternalRowSorter {
+public final class UnsafeExternalRowSorter {
/**
* If positive, forces records to be spilled to disk at the given frequency (measured in numbers
@@ -84,8 +84,7 @@ final class UnsafeExternalRowSorter {
testSpillFrequency = frequency;
}
- @VisibleForTesting
- void insertRow(UnsafeRow row) throws IOException {
+ public void insertRow(UnsafeRow row) throws IOException {
final long prefix = prefixComputer.computePrefix(row);
sorter.insertRecord(
row.getBaseObject(),
@@ -110,8 +109,7 @@ final class UnsafeExternalRowSorter {
sorter.cleanupResources();
}
- @VisibleForTesting
- Iterator<UnsafeRow> sort() throws IOException {
+ public Iterator<UnsafeRow> sort() throws IOException {
try {
final UnsafeSorterIterator sortedIterator = sorter.getSortedIterator();
if (!sortedIterator.hasNext()) {
@@ -160,7 +158,6 @@ final class UnsafeExternalRowSorter {
}
}
-
public Iterator<UnsafeRow> sort(Iterator<UnsafeRow> inputIterator) throws IOException {
while (inputIterator.hasNext()) {
insertRow(inputIterator.next());
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
index 75cb6d1137..2ea889ea72 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/Sort.scala
@@ -17,10 +17,12 @@
package org.apache.spark.sql.execution
-import org.apache.spark.{InternalAccumulator, SparkEnv, TaskContext}
+import org.apache.spark.{SparkEnv, TaskContext}
+import org.apache.spark.executor.TaskMetrics
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical.{Distribution, OrderedDistribution, UnspecifiedDistribution}
import org.apache.spark.sql.execution.metric.SQLMetrics
@@ -37,7 +39,7 @@ case class Sort(
global: Boolean,
child: SparkPlan,
testSpillFrequency: Int = 0)
- extends UnaryNode {
+ extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
@@ -50,34 +52,36 @@ case class Sort(
"dataSize" -> SQLMetrics.createSizeMetric(sparkContext, "data size"),
"spillSize" -> SQLMetrics.createSizeMetric(sparkContext, "spill size"))
- protected override def doExecute(): RDD[InternalRow] = {
- val schema = child.schema
- val childOutput = child.output
+ def createSorter(): UnsafeExternalRowSorter = {
+ val ordering = newOrdering(sortOrder, output)
+
+ // The comparator for comparing prefix
+ val boundSortExpression = BindReferences.bindReference(sortOrder.head, output)
+ val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
+
+ // The generator for prefix
+ val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
+ val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
+ override def computePrefix(row: InternalRow): Long = {
+ prefixProjection.apply(row).getLong(0)
+ }
+ }
+ val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
+ val sorter = new UnsafeExternalRowSorter(
+ schema, ordering, prefixComparator, prefixComputer, pageSize)
+ if (testSpillFrequency > 0) {
+ sorter.setTestSpillFrequency(testSpillFrequency)
+ }
+ sorter
+ }
+
+ protected override def doExecute(): RDD[InternalRow] = {
val dataSize = longMetric("dataSize")
val spillSize = longMetric("spillSize")
child.execute().mapPartitionsInternal { iter =>
- val ordering = newOrdering(sortOrder, childOutput)
-
- // The comparator for comparing prefix
- val boundSortExpression = BindReferences.bindReference(sortOrder.head, childOutput)
- val prefixComparator = SortPrefixUtils.getPrefixComparator(boundSortExpression)
-
- // The generator for prefix
- val prefixProjection = UnsafeProjection.create(Seq(SortPrefix(boundSortExpression)))
- val prefixComputer = new UnsafeExternalRowSorter.PrefixComputer {
- override def computePrefix(row: InternalRow): Long = {
- prefixProjection.apply(row).getLong(0)
- }
- }
-
- val pageSize = SparkEnv.get.memoryManager.pageSizeBytes
- val sorter = new UnsafeExternalRowSorter(
- schema, ordering, prefixComparator, prefixComputer, pageSize)
- if (testSpillFrequency > 0) {
- sorter.setTestSpillFrequency(testSpillFrequency)
- }
+ val sorter = createSorter()
val metrics = TaskContext.get().taskMetrics()
// Remember spill data size of this task before execute this operator so that we can
@@ -93,4 +97,74 @@ case class Sort(
sortedIterator
}
}
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ // Name of sorter variable used in codegen.
+ private var sorterVariable: String = _
+
+ override protected def doProduce(ctx: CodegenContext): String = {
+ val needToSort = ctx.freshName("needToSort")
+ ctx.addMutableState("boolean", needToSort, s"$needToSort = true;")
+
+
+ // Initialize the class member variables. This includes the instance of the Sorter and
+ // the iterator to return sorted rows.
+ val thisPlan = ctx.addReferenceObj("plan", this)
+ sorterVariable = ctx.freshName("sorter")
+ ctx.addMutableState(classOf[UnsafeExternalRowSorter].getName, sorterVariable,
+ s"$sorterVariable = $thisPlan.createSorter();")
+ val metrics = ctx.freshName("metrics")
+ ctx.addMutableState(classOf[TaskMetrics].getName, metrics,
+ s"$metrics = org.apache.spark.TaskContext.get().taskMetrics();")
+ val sortedIterator = ctx.freshName("sortedIter")
+ ctx.addMutableState("scala.collection.Iterator<UnsafeRow>", sortedIterator, "")
+
+ val addToSorter = ctx.freshName("addToSorter")
+ ctx.addNewFunction(addToSorter,
+ s"""
+ | private void $addToSorter() throws java.io.IOException {
+ | ${child.asInstanceOf[CodegenSupport].produce(ctx, this)}
+ | }
+ """.stripMargin.trim)
+
+ val outputRow = ctx.freshName("outputRow")
+ val dataSize = metricTerm(ctx, "dataSize")
+ val spillSize = metricTerm(ctx, "spillSize")
+ val spillSizeBefore = ctx.freshName("spillSizeBefore")
+ s"""
+ | if ($needToSort) {
+ | $addToSorter();
+ | Long $spillSizeBefore = $metrics.memoryBytesSpilled();
+ | $sortedIterator = $sorterVariable.sort();
+ | $dataSize.add($sorterVariable.getPeakMemoryUsage());
+ | $spillSize.add($metrics.memoryBytesSpilled() - $spillSizeBefore);
+ | $metrics.incPeakExecutionMemory($sorterVariable.getPeakMemoryUsage());
+ | $needToSort = false;
+ | }
+ |
+ | while ($sortedIterator.hasNext()) {
+ | UnsafeRow $outputRow = (UnsafeRow)$sortedIterator.next();
+ | ${consume(ctx, null, outputRow)}
+ | if (shouldStop()) return;
+ | }
+ """.stripMargin.trim
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode]): String = {
+ val colExprs = child.output.zipWithIndex.map { case (attr, i) =>
+ BoundReference(i, attr.dataType, attr.nullable)
+ }
+
+ ctx.currentVars = input
+ val code = GenerateUnsafeProjection.createCode(ctx, colExprs)
+
+ s"""
+ | // Convert the input attributes to an UnsafeRow and add it to the sorter
+ | ${code.code}
+ | $sorterVariable.insertRow(${code.value});
+ """.stripMargin.trim
+ }
}
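For orientation, the `doProduce` template above follows the usual blocking-operator shape in whole stage codegen: on the first call every input row is pushed into the sorter (the `$needToSort` guard), and afterwards rows are streamed out of the sorted iterator to the downstream consumer. A standalone Scala analogue of that control flow is sketched below; it uses plain collections only, and the class and method names are illustrative, not part of this patch.

```scala
// Standalone analogue of the generated sort loop in doProduce: buffer all
// input on the first call (the "needToSort" guard), then stream elements
// from the sorted iterator to the downstream consume callback.
class BufferedSorter[T](input: Iterator[T])(implicit ord: Ordering[T]) {
  private var needToSort = true
  private var sortedIter: Iterator[T] = Iterator.empty

  // Rough counterpart of the generated produce loop: sort once, then
  // hand each element to the consumer.
  def produce(consume: T => Unit): Unit = {
    if (needToSort) {
      sortedIter = input.toVector.sorted.iterator // addToSorter() + sort()
      needToSort = false
    }
    while (sortedIter.hasNext) {
      consume(sortedIter.next())
    }
  }
}

// Example (e.g. in the Scala REPL): prints 1, 2, 3.
new BufferedSorter(Iterator(3, 1, 2)).produce(println)
```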
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index afaddcf357..cb68ca6ada 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -287,7 +287,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
${code.trim}
}
}
- """
+ """.trim
// try to compile, helpful for debug
val cleanedSource = CodeFormatter.stripExtraNewLines(source)
@@ -338,7 +338,7 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
// There is an UnsafeRow already
s"""
|append($row.copy());
- """.stripMargin
+ """.stripMargin.trim
} else {
assert(input != null)
if (input.nonEmpty) {
@@ -351,12 +351,12 @@ case class WholeStageCodegen(plan: CodegenSupport, children: Seq[SparkPlan])
s"""
|${code.code.trim}
|append(${code.value}.copy());
- """.stripMargin
+ """.stripMargin.trim
} else {
// There is no columns
s"""
|append(unsafeRow);
- """.stripMargin
+ """.stripMargin.trim
}
}
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
index 9350205d79..de371d85d9 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/WholeStageCodegenSuite.scala
@@ -69,4 +69,13 @@ class WholeStageCodegenSuite extends SparkPlanTest with SharedSQLContext {
p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[BroadcastHashJoin]).isDefined)
assert(df.collect() === Array(Row(1, 1, "1"), Row(1, 1, "1"), Row(2, 2, "2")))
}
+
+ test("Sort should be included in WholeStageCodegen") {
+ val df = sqlContext.range(3, 0, -1).sort(col("id"))
+ val plan = df.queryExecution.executedPlan
+ assert(plan.find(p =>
+ p.isInstanceOf[WholeStageCodegen] &&
+ p.asInstanceOf[WholeStageCodegen].plan.isInstanceOf[Sort]).isDefined)
+ assert(df.collect() === Array(Row(1), Row(2), Row(3)))
+ }
}
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
index c49f2439fc..5b4f6f1d24 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/metric/SQLMetricsSuite.scala
@@ -154,6 +154,13 @@ class SQLMetricsSuite extends SparkFunSuite with SharedSQLContext {
)
}
+ test("Sort metrics") {
+ // Assume the execution plan is
+ // WholeStageCodegen(nodeId = 0, Range(nodeId = 2) -> Sort(nodeId = 1))
+ val df = sqlContext.range(10).sort('id)
+ testSparkPlanMetrics(df, 2, Map.empty)
+ }
+
test("SortMergeJoin metrics") {
// Because SortMergeJoin may skip different rows if the number of partitions is different, this
// test should use the deterministic number of partitions.