| field | value | date |
|---|---|---|
| author | Liang-Chi Hsieh <simonh@tw.ibm.com> | 2016-04-01 14:02:32 -0700 |
| committer | Davies Liu <davies.liu@gmail.com> | 2016-04-01 14:02:32 -0700 |
| commit | 3e991dbc310a4a33eec7f3909adce50bf8268d04 | |
| tree | fde4fe5795f815d06f5be22020618d52411ec0ae /sql/core/src/main | |
| parent | 1b829ce13990b40fd8d7c9efcc2ae55c4dbc861c | |
# [SPARK-13674] [SQL] Add wholestage codegen support to Sample
JIRA: https://issues.apache.org/jira/browse/SPARK-13674
## What changes were proposed in this pull request?
The `Sample` operator does not currently support whole-stage codegen. This PR adds that support.
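As a purely illustrative example (the names and constants below are assumptions, not code from the patch), a plan containing `Sample` can now be fused into the surrounding generated stage:

```scala
// Illustrative sketch: a Range -> Sample plan. With this patch, Sample
// implements CodegenSupport, so the WholeStageCodegen node can cover it too.
val df = sqlContext.range(1L << 20)
  .sample(withReplacement = false, fraction = 0.5, seed = 42L)
df.explain()  // Sample should now appear inside the WholeStageCodegen subtree
```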
## How was this patch tested?
A benchmark case is added to `BenchmarkWholeStageCodegen`. In addition, all existing tests should pass.
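For orientation, here is a hedged sketch of what a case in that benchmark suite looks like, using Spark's `org.apache.spark.util.Benchmark` utility; the row count, case names, and in-scope `sqlContext` are assumptions, and the benchmark actually added by this PR may differ:

```scala
import org.apache.spark.util.Benchmark

// Hypothetical benchmark shape (not the PR's actual code): measure sampling
// with and without replacement now that Sample participates in codegen.
val N = 500L << 20  // assumed row count
val benchmark = new Benchmark("sample", N)
benchmark.addCase("sample with replacement") { _ =>
  sqlContext.range(N).sample(withReplacement = true, fraction = 0.01, seed = 7L).count()
}
benchmark.addCase("sample without replacement") { _ =>
  sqlContext.range(N).sample(withReplacement = false, fraction = 0.01, seed = 7L).count()
}
benchmark.run()
```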
Author: Liang-Chi Hsieh <simonh@tw.ibm.com>
Author: Liang-Chi Hsieh <viirya@gmail.com>
Closes #11517 from viirya/add-wholestage-sample.
Diffstat (limited to 'sql/core/src/main')
3 files changed, 74 insertions, 14 deletions
```diff
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
index dbea8521be..c2633a9f8c 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/BufferedRowIterator.java
@@ -36,6 +36,8 @@ public abstract class BufferedRowIterator {
   protected UnsafeRow unsafeRow = new UnsafeRow(0);
   private long startTimeNs = System.nanoTime();
 
+  protected int partitionIndex = -1;
+
   public boolean hasNext() throws IOException {
     if (currentRows.isEmpty()) {
       processNext();
@@ -58,7 +60,7 @@ public abstract class BufferedRowIterator {
   /**
    * Initializes from array of iterators of InternalRow.
    */
-  public abstract void init(Iterator<InternalRow> iters[]);
+  public abstract void init(int index, Iterator<InternalRow> iters[]);
 
   /**
    * Append a row to currentRows.
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
index 6a779abd40..9bdf611f6e 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/WholeStageCodegen.scala
@@ -17,7 +17,7 @@
 
 package org.apache.spark.sql.execution
 
-import org.apache.spark.broadcast
+import org.apache.spark.{broadcast, TaskContext}
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
@@ -323,7 +323,8 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
         this.references = references;
       }
 
-      public void init(scala.collection.Iterator inputs[]) {
+      public void init(int index, scala.collection.Iterator inputs[]) {
+        partitionIndex = index;
         ${ctx.initMutableStates()}
       }
 
@@ -351,10 +352,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
     val rdds = child.asInstanceOf[CodegenSupport].upstreams()
     assert(rdds.size <= 2, "Up to two upstream RDDs can be supported")
     if (rdds.length == 1) {
-      rdds.head.mapPartitions { iter =>
+      rdds.head.mapPartitionsWithIndex { (index, iter) =>
         val clazz = CodeGenerator.compile(cleanedSource)
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-        buffer.init(Array(iter))
+        buffer.init(index, Array(iter))
         new Iterator[InternalRow] {
           override def hasNext: Boolean = {
             val v = buffer.hasNext
@@ -367,9 +368,10 @@ case class WholeStageCodegen(child: SparkPlan) extends UnaryNode with CodegenSup
     } else {
       // Right now, we support up to two upstreams.
       rdds.head.zipPartitions(rdds(1)) { (leftIter, rightIter) =>
+        val partitionIndex = TaskContext.getPartitionId()
         val clazz = CodeGenerator.compile(cleanedSource)
         val buffer = clazz.generate(references).asInstanceOf[BufferedRowIterator]
-        buffer.init(Array(leftIter, rightIter))
+        buffer.init(partitionIndex, Array(leftIter, rightIter))
         new Iterator[InternalRow] {
           override def hasNext: Boolean = {
             val v = buffer.hasNext
diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
index fca662760d..a6a14df6a3 100644
--- a/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
+++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/basicOperators.scala
@@ -20,11 +20,11 @@ package org.apache.spark.sql.execution
 import org.apache.spark.rdd.{PartitionwiseSampledRDD, RDD}
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions._
-import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer}
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodegenContext, ExprCode, ExpressionCanonicalizer, GenerateUnsafeProjection}
 import org.apache.spark.sql.catalyst.plans.physical._
 import org.apache.spark.sql.execution.metric.SQLMetrics
 import org.apache.spark.sql.types.LongType
-import org.apache.spark.util.random.PoissonSampler
+import org.apache.spark.util.random.{BernoulliCellSampler, PoissonSampler}
 
 case class Project(projectList: Seq[NamedExpression], child: SparkPlan)
   extends UnaryNode with CodegenSupport {
@@ -223,9 +223,12 @@ case class Sample(
     upperBound: Double,
     withReplacement: Boolean,
     seed: Long,
-    child: SparkPlan) extends UnaryNode {
+    child: SparkPlan) extends UnaryNode with CodegenSupport {
   override def output: Seq[Attribute] = child.output
 
+  private[sql] override lazy val metrics = Map(
+    "numOutputRows" -> SQLMetrics.createLongMetric(sparkContext, "number of output rows"))
+
   protected override def doExecute(): RDD[InternalRow] = {
     if (withReplacement) {
       // Disable gap sampling since the gap sampling method buffers two rows internally,
@@ -239,6 +242,63 @@ case class Sample(
       child.execute().randomSampleWithRange(lowerBound, upperBound, seed)
     }
   }
+
+  override def upstreams(): Seq[RDD[InternalRow]] = {
+    child.asInstanceOf[CodegenSupport].upstreams()
+  }
+
+  protected override def doProduce(ctx: CodegenContext): String = {
+    child.asInstanceOf[CodegenSupport].produce(ctx, this)
+  }
+
+  override def doConsume(ctx: CodegenContext, input: Seq[ExprCode], row: ExprCode): String = {
+    val numOutput = metricTerm(ctx, "numOutputRows")
+    val sampler = ctx.freshName("sampler")
+
+    if (withReplacement) {
+      val samplerClass = classOf[PoissonSampler[UnsafeRow]].getName
+      val initSampler = ctx.freshName("initSampler")
+      ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
+        s"$initSampler();")
+
+      ctx.addNewFunction(initSampler,
+        s"""
+          | private void $initSampler() {
+          |   $sampler = new $samplerClass<UnsafeRow>($upperBound - $lowerBound, false);
+          |   java.util.Random random = new java.util.Random(${seed}L);
+          |   long randomSeed = random.nextLong();
+          |   int loopCount = 0;
+          |   while (loopCount < partitionIndex) {
+          |     randomSeed = random.nextLong();
+          |     loopCount += 1;
+          |   }
+          |   $sampler.setSeed(randomSeed);
+          | }
+         """.stripMargin.trim)
+
+      val samplingCount = ctx.freshName("samplingCount")
+      s"""
+         | int $samplingCount = $sampler.sample();
+         | while ($samplingCount-- > 0) {
+         |   $numOutput.add(1);
+         |   ${consume(ctx, input)}
+         | }
+       """.stripMargin.trim
+    } else {
+      val samplerClass = classOf[BernoulliCellSampler[UnsafeRow]].getName
+      ctx.addMutableState(s"$samplerClass<UnsafeRow>", sampler,
+        s"""
+          | $sampler = new $samplerClass<UnsafeRow>($lowerBound, $upperBound, false);
+          | $sampler.setSeed(${seed}L + partitionIndex);
+         """.stripMargin.trim)
+
+      s"""
+         | if ($sampler.sample() == 0) continue;
+         | $numOutput.add(1);
+         | ${consume(ctx, input)}
+       """.stripMargin.trim
+    }
+  }
 }
 
 case class Range(
@@ -320,11 +380,7 @@ case class Range(
       |   // initialize Range
      |   if (!$initTerm) {
       |     $initTerm = true;
-      |     if ($input.hasNext()) {
-      |       initRange(((InternalRow) $input.next()).getInt(0));
-      |     } else {
-      |       return;
-      |     }
+      |     initRange(partitionIndex);
       |   }
       |
       |   while (!$overflow && $checkEnd) {
```
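The subtle part of the generated code above is per-partition seeding: sampling must be deterministic within a partition yet differ across partitions. The following standalone Scala sketch (hypothetical helper names, not Spark code) restates the two schemes the patch generates:

```scala
import java.util.Random

// Hypothetical helpers mirroring the seeding logic in the generated code above.
object SampleSeeds {
  // Poisson path (withReplacement = true): seed a master Random with the
  // operator's seed and keep the draw at position partitionIndex, exactly
  // as the generated initSampler() loop does.
  def poissonSeed(seed: Long, partitionIndex: Int): Long = {
    val random = new Random(seed)
    var randomSeed = random.nextLong()
    var loopCount = 0
    while (loopCount < partitionIndex) {
      randomSeed = random.nextLong()
      loopCount += 1
    }
    randomSeed
  }

  // Bernoulli path (withReplacement = false): a plain offset, matching
  // `$sampler.setSeed(${seed}L + partitionIndex)` in the generated init code.
  def bernoulliSeed(seed: Long, partitionIndex: Int): Long =
    seed + partitionIndex
}
```

Either way, each partition samples with a deterministic seed derived from the operator's `seed` and its partition index. That is why `partitionIndex` is plumbed through `BufferedRowIterator.init` (or recovered via `TaskContext.getPartitionId()` in the two-upstream `zipPartitions` case), and why `Range` can now call `initRange(partitionIndex)` directly instead of reading the index from an input row.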