author    Liang-Chi Hsieh <simonh@tw.ibm.com>  2016-04-01 14:02:32 -0700
committer Davies Liu <davies.liu@gmail.com>    2016-04-01 14:02:32 -0700
commit    3e991dbc310a4a33eec7f3909adce50bf8268d04
tree      fde4fe5795f815d06f5be22020618d52411ec0ae
parent    1b829ce13990b40fd8d7c9efcc2ae55c4dbc861c
[SPARK-13674] [SQL] Add wholestage codegen support to Sample
JIRA: https://issues.apache.org/jira/browse/SPARK-13674

## What changes were proposed in this pull request?

The Sample operator does not support whole-stage codegen yet. This PR adds that support.

## How was this patch tested?

A test is added to `BenchmarkWholeStageCodegen`; besides that, all existing tests should still pass.

Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Author: Liang-Chi Hsieh <viirya@gmail.com>

Closes #11517 from viirya/add-wholestage-sample.
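For context, here is a minimal sketch of how the new code path is exercised from the DataFrame API (Spark 2.0-era SQLContext assumed; the queries mirror the benchmark added below). sample(withReplacement, fraction) plans the Sample operator, which with this patch participates in whole-stage codegen:

import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.sql.SQLContext

object SampleCodegenDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("sample-codegen"))
    val sqlContext = new SQLContext(sc)

    // With replacement: backed by a PoissonSampler in the generated code.
    sqlContext.range(1000000).sample(true, 0.01).groupBy().sum().collect()

    // Without replacement: backed by a BernoulliCellSampler.
    sqlContext.range(1000000).sample(false, 0.01).groupBy().sum().collect()

    sc.stop()
  }
}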
-rw-r--r--  core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala                    |  2
-rw-r--r--  project/MimaExcludes.scala                                                               |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java          |  4
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala           | 12
-rw-r--r--  sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala              | 72
-rw-r--r--  sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala  | 25
6 files changed, 104 insertions(+), 15 deletions(-)
diff --git a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
index 2921b939bc..d397cca4b4 100644
--- a/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
+++ b/core/src/main/scala/org/apache/spark/util/random/RandomSampler.scala
@@ -186,7 +186,7 @@ class BernoulliSampler[T: ClassTag](fraction: Double) extends RandomSampler[T, T
* @tparam T item type
*/
@DeveloperApi
-class PoissonSampler[T: ClassTag](
+class PoissonSampler[T](
fraction: Double,
useGapSamplingIfPossible: Boolean) extends RandomSampler[T, T] {
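Dropping the ClassTag bound matters because generated Java code cannot supply one; after this change the codegen'd operator can instantiate the sampler directly. A minimal sketch of driving the two samplers the way the generated code does, using the count-returning no-arg sample() that the generated code below calls (illustrative values; constructor arguments as in this diff):

import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}

// With replacement: sample() returns how many copies of the current row to emit.
val poisson = new PoissonSampler[String](0.01, false)
poisson.setSeed(42L)
val copies = poisson.sample() // 0, 1, 2, ... occurrences of the current item

// Without replacement: sample() returns 0 (drop the row) or 1 (keep it).
val bernoulli = new BernoulliCellSampler[String](0.0, 0.01, false)
bernoulli.setSeed(42L)
val keep = bernoulli.sample() == 1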
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index ff11775412..2be490b942 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -597,6 +597,10 @@ object MimaExcludes {
// for multilayer perceptron.
// This class is marked as `private`.
ProblemFilters.exclude[MissingClassProblem]("org.apache.spark.ml.ann.SoftmaxFunction")
+ ) ++ Seq(
+ // [SPARK-13674][SQL] Add wholestage codegen support to Sample
+ ProblemFilters.exclude[IncompatibleMethTypeProblem]("org.apache.spark.util.random.PoissonSampler.this"),
+ ProblemFilters.exclude[DirectMissingMethodProblem]("org.apache.spark.util.random.PoissonSampler.this")
)
case v if v.startsWith("1.6") =>
Seq(
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index dbea8521be..c2633a9f8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -36,6 +36,8 @@ public abstract class BufferedRowIterator {
protected UnsafeRow unsafeRow = new UnsafeRow(0);
private long startTimeNs = System.nanoTime();
+ protected int partitionIndex = -1;
+
public boolean hasNext() throws IOException {
if (currentRows.isEmpty()) {
processNext();
@@ -58,7 +60,7 @@ public abstract class BufferedRowIterator {
/**
* Initializes from array of iterators of InternalRow.
*/
- public abstract void init(Iterator<InternalRow> iters[]);
+ public abstract void init(int index, Iterator<InternalRow> iters[]);
/**
* Append a row to currentRows.
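A minimal sketch of the updated contract from the subclass side (names other than init/processNext are illustrative; the real subclasses are generated by WholeStageCodegen, shown next): init now records which partition the iterator serves before wiring up its inputs.

import java.io.IOException
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.BufferedRowIterator

class EchoIterator extends BufferedRowIterator {
  private var input: Iterator[InternalRow] = _

  override def init(index: Int, iters: Array[Iterator[InternalRow]]): Unit = {
    partitionIndex = index // new: remember the partition this iterator serves
    input = iters(0)
  }

  @throws[IOException]
  override def processNext(): Unit = {
    // Pass rows through unchanged, appending to currentRows.
    if (input.hasNext) append(input.next())
  }
}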
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 6a779abd40..9bdf611f6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,7 +17,7 @@
package org.apache.spark.sql.execution
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, TaskContext}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
@@ -323,7 +323,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
this.references = references;
}
- public void init(scala.collection.Iterator inputs[]) {
+ public void init(int index, scala.collection.Iterator inputs[]) {
+ partitionIndex = index;
${ctx.initMutableStates()}
}
@@ -351,10 +352,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
val rdds = child.asInstanceOf[CodegenSupport].upstreams()
assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
if (rdds.length == 1) {
- rdds.head.mapPartitions { iter =>
+ rdds.head.mapPartitionsWithIndex { (index, iter) =>
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
- buffer.init(Array(iter))
+ buffer.init(index, Array(iter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
@@ -367,9 +368,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
} else {
// Right now, we support up to two upstreams.
rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
+ val partitionIndex = TaskContext.getPartitionId()
val clazz = CodeGenerator.compile(cleanedSource)
val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
- buffer.init(Array(leftIter, rightIter))
+ buffer.init(partitionIndex, Array(leftIter, rightIter))
new Iterator[InternalRow] {
override def hasNext: Boolean = {
val v = buffer.hasNext
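The two branches above obtain the partition index differently: mapPartitionsWithIndex hands it to the closure directly, while zipPartitions has no index-aware variant, so the index is recovered inside the task via TaskContext.getPartitionId(). A self-contained sketch of both idioms:

import org.apache.spark.{SparkConf, SparkContext, TaskContext}

object PartitionIndexDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(
      new SparkConf().setMaster("local[2]").setAppName("partition-index"))
    val left = sc.parallelize(1 to 8, numSlices = 2)
    val right = sc.parallelize(9 to 16, numSlices = 2)

    // Single upstream: the index is an argument of the closure.
    val viaIndex = left.mapPartitionsWithIndex { (index, iter) =>
      iter.map(v => s"partition $index: $v")
    }

    // Two upstreams: recover the index from the running task instead.
    val viaTaskContext = left.zipPartitions(right) { (l, r) =>
      val index = TaskContext.getPartitionId()
      (l ++ r).map(v => s"partition $index: $v")
    }

    viaIndex.collect().foreach(println)
    viaTaskContext.collect().foreach(println)
    sc.stop()
  }
}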
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index fca662760d..a6a14df6a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.execution
import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer, GenerateUnsafeProjection}
import org.apache.spark.sql.catalyst.plans.physical._
import org.apache.spark.sql.execution.metric.SQLMetrics
import org.apache.spark.sql.types.LongType
-import org.apache.spark.util.random.PoissonSampler
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
extends UnaryNode with CodegenSupport {
@@ -223,9 +223,12 @@ case class Sample(
upperBound: Double,
withReplacement: Boolean,
seed: Long,
- child: SparkPlan) extends UnaryNode {
+ child: SparkPlan) extends UnaryNode with CodegenSupport {
override def output: Seq[Attribute] = child.output
+ private[sql] override lazy val metrics = Map(
+ "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
protected override def doExecute(): RDD[InternalRow] = {
if (withReplacement) {
// Disable gap sampling since the gap sampling method buffers two rows internally,
@@ -239,6 +242,63 @@ case class Sample(
child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
}
}
+
+ override def upstreams(): Seq[RDD[InternalRow]] = {
+ child.asInstanceOf[CodegenSupport].upstreams()
+ }
+
+ protected override def doProduce(ctx: CodegenContext): String = {
+ child.asInstanceOf[CodegenSupport].produce(ctx, this)
+ }
+
+ override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+ val numOutput = metricTerm(ctx, "numOutputRows")
+ val sampler = ctx.freshName("sampler")
+
+ if (withReplacement) {
+ val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
+ val initSampler = ctx.freshName("initSampler")
+ ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
+ s"$initSampler();")
+
+ ctx.addNewFunction(initSampler,
+ s"""
+ | private void $initSampler() {
+ | $sampler = new $samplerClass<UnsafeRow>($upperBound - $lowerBound, false);
+ | java.util.Random random = new java.util.Random(${seed}L);
+ | long randomSeed = random.nextLong();
+ | int loopCount = 0;
+ | while (loopCount < partitionIndex) {
+ | randomSeed = random.nextLong();
+ | loopCount += 1;
+ | }
+ | $sampler.setSeed(randomSeed);
+ | }
+ """.stripMargin.trim)
+
+ val samplingCount = ctx.freshName("samplingCount")
+ s"""
+ | int $samplingCount = $sampler.sample();
+ | while ($samplingCount-- > 0) {
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ | }
+ """.stripMargin.trim
+ } else {
+ val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName
+ ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
+ s"""
+ | $sampler = new $samplerClass<UnsafeRow>($lowerBound, $upperBound, false);
+ | $sampler.setSeed(${seed}L + partitionIndex);
+ """.stripMargin.trim)
+
+ s"""
+ | if ($sampler.sample() == 0) continue;
+ | $numOutput.add(1);
+ | ${consume(ctx, input)}
+ """.stripMargin.trim
+ }
+ }
}
case class Range(
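The seeding scheme in the generated code above is worth spelling out. With replacement, each partition needs an independent seed derived from the user-supplied seed, so a master java.util.Random is advanced once per partition index, matching how PartitionwiseSampledRDD derives per-partition seeds on the non-codegen path; without replacement, seed + partitionIndex suffices. A plain-Scala sketch of both derivations (helper names are ours):

// Per-partition seed derivations mirroring the generated code above.
def withReplacementSeed(seed: Long, partitionIndex: Int): Long = {
  val random = new java.util.Random(seed)
  var derived = random.nextLong() // partition 0 takes the first draw
  var i = 0
  while (i < partitionIndex) {    // partition i takes the (i + 1)-th draw
    derived = random.nextLong()
    i += 1
  }
  derived
}

def withoutReplacementSeed(seed: Long, partitionIndex: Int): Long =
  seed + partitionIndex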
@@ -320,11 +380,7 @@ case class Range(
| // initialize Range
| if (!$initTerm) {
| $initTerm = true;
- | if ($input.hasNext()) {
- | initRange(((InternalRow) $input.next()).getInt(0));
- | } else {
- | return;
- | }
+ | initRange(partitionIndex);
| }
|
| while (!$overflow && $checkEnd) {
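Before this patch, Range learned its partition number by consuming a single-row upstream RDD; with partitionIndex now available in the generated class, initRange can be called with it directly. As an illustration only (the formula and names below are ours, not the real initRange, which additionally handles step and numeric overflow), each partition derives its slice of the range from its index:

// Illustrative only: deriving a partition's slice of [start, end) from its index.
def slice(start: Long, end: Long, numSlices: Int, partitionIndex: Int): (Long, Long) = {
  val total = end - start
  val lo = start + partitionIndex * total / numSlices
  val hi = start + (partitionIndex + 1) * total / numSlices
  (lo, hi) // partition `partitionIndex` produces values in [lo, hi)
}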
diff --git a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
index 003d3e062e..55906793c0 100644
--- a/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
+++ b/sql/core/src/test/scala/org/apache/spark/sql/execution/BenchmarkWholeStageCodegen.scala
@@ -85,6 +85,31 @@ class BenchmarkWholeStageCodegen extends SparkFunSuite {
*/
}
+ ignore("range/sample/sum") {
+ val N = 500 << 20
+ runBenchmark("range/sample/sum", N) {
+ sqlContext.range(N).sample(true, 0.01).groupBy().sum().collect()
+ }
+ /*
+ Westmere E56xx/L56xx/X56xx (Nehalem-C)
+ range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ range/sample/sum codegen=false 53888 / 56592 9.7 102.8 1.0X
+ range/sample/sum codegen=true 41614 / 42607 12.6 79.4 1.3X
+ */
+
+ runBenchmark("range/sample/sum", N) {
+ sqlContext.range(N).sample(false, 0.01).groupBy().sum().collect()
+ }
+ /*
+ Westmere E56xx/L56xx/X56xx (Nehalem-C)
+ range/sample/sum: Best/Avg Time(ms) Rate(M/s) Per Row(ns) Relative
+ -------------------------------------------------------------------------------------------
+ range/sample/sum codegen=false 12982 / 13384 40.4 24.8 1.0X
+ range/sample/sum codegen=true 7074 / 7383 74.1 13.5 1.8X
+ */
+ }
+
ignore("stat functions") {
val N = 100L << 20
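The codegen=false/true rows above come from running the same query with whole-stage codegen disabled and then enabled. A sketch of doing that by hand, reusing the sqlContext from the first sketch and assuming the Spark 2.0-era conf key spark.sql.codegen.wholeStage (which is what the runBenchmark helper toggles internally):

val N = 500 << 20

// Baseline: interpreted (non-codegen) Sample path.
sqlContext.setConf("spark.sql.codegen.wholeStage", "false")
sqlContext.range(N).sample(true, 0.01).groupBy().sum().collect()

// Whole-stage codegen'd Sample path.
sqlContext.setConf("spark.sql.codegen.wholeStage", "true")
sqlContext.range(N).sample(true, 0.01).groupBy().sum().collect()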