diff options
author | Wenchen Fan <wenchen@databricks.com> | 2016-04-20 00:44:02 -0700 |
---|---|---|
committer | Reynold Xin <rxin@databricks.com> | 2016-04-20 00:44:02 -0700 |
commit | 7abe9a6578a99f4df50040d5cfe083c389c7b97f (patch) | |
tree | 59f4ae404f2a0de4fae55da0849103999f87446b /sql | |
parent | 14869ae64eb27830179d4954a5dc3e0a1e1330b4 (diff) | |
download | spark-7abe9a6578a99f4df50040d5cfe083c389c7b97f.tar.gz spark-7abe9a6578a99f4df50040d5cfe083c389c7b97f.tar.bz2 spark-7abe9a6578a99f4df50040d5cfe083c389c7b97f.zip |
[SPARK-9013][SQL] generate MutableProjection directly instead of returning a function
`MutableProjection` is not thread-safe, and we never use it from multiple threads. I think the reason we return `() => MutableProjection` is not thread safety, but to save the cost of code generation when we need several identical but independent mutable projections.
However, I found only one place that uses this [feature](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala#L122-L123), and compared to the trouble it brings, I think we should generate `MutableProjection` directly instead of returning a function.
Author: Wenchen Fan <wenchen@databricks.com>
Closes #7373 from cloud-fan/project.
Diffstat (limited to 'sql')
14 files changed, 35 insertions, 39 deletions
diff --git a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala index 7f840890f8..f143b40443 100644 --- a/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala +++ b/sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/codegen/GenerateMutableProjection.scala @@ -29,7 +29,7 @@ abstract class BaseMutableProjection extends MutableProjection * It exposes a `target` method, which is used to set the row that will be updated. * The internal [[MutableRow]] object created internally is used only when `target` is not used. */ -object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => MutableProjection] { +object GenerateMutableProjection extends CodeGenerator[Seq[Expression], MutableProjection] { protected def canonicalize(in: Seq[Expression]): Seq[Expression] = in.map(ExpressionCanonicalizer.execute) @@ -40,17 +40,17 @@ object GenerateMutableProjection extends CodeGenerator[Seq[Expression], () => Mu def generate( expressions: Seq[Expression], inputSchema: Seq[Attribute], - useSubexprElimination: Boolean): (() => MutableProjection) = { + useSubexprElimination: Boolean): MutableProjection = { create(canonicalize(bind(expressions, inputSchema)), useSubexprElimination) } - protected def create(expressions: Seq[Expression]): (() => MutableProjection) = { + protected def create(expressions: Seq[Expression]): MutableProjection = { create(expressions, false) } private def create( expressions: Seq[Expression], - useSubexprElimination: Boolean): (() => MutableProjection) = { + useSubexprElimination: Boolean): MutableProjection = { val ctx = newCodeGenContext() val (validExpr, index) = expressions.zipWithIndex.filter { case (NoOp, _) => false @@ -136,8 +136,6 @@ object GenerateMutableProjection extends 
CodeGenerator[Seq[Expression], () => Mu logDebug(s"code for ${expressions.mkString(",")}:\n${CodeFormatter.format(code)}") val c = CodeGenerator.compile(code) - () => { - c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] - } + c.generate(ctx.references.toArray).asInstanceOf[MutableProjection] } } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala index 94e676ded6..b682e7d2b1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/CodeGenerationSuite.scala @@ -50,7 +50,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { test("SPARK-8443: split wide projections into blocks due to JVM code size limit") { val length = 5000 val expressions = List.fill(length)(EqualTo(Literal(1), Literal(1))) - val plan = GenerateMutableProjection.generate(expressions)() + val plan = GenerateMutableProjection.generate(expressions) val actual = plan(new GenericMutableRow(length)).toSeq(expressions.map(_.dataType)) val expected = Seq.fill(length)(true) @@ -73,7 +73,7 @@ class CodeGenerationSuite extends SparkFunSuite with ExpressionEvalHelper { val expression = CaseWhen((1 to cases).map(generateCase(_))) - val plan = GenerateMutableProjection.generate(Seq(expression))() + val plan = GenerateMutableProjection.generate(Seq(expression)) val input = new GenericMutableRow(Array[Any](UTF8String.fromString(s"${clauses}:${cases}"))) val actual = plan(input).toSeq(Seq(expression.dataType)) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala index faa90fb1c5..8a9617cfbf 100644 --- 
a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/ExpressionEvalHelper.scala @@ -110,7 +110,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { inputRow: InternalRow = EmptyRow): Unit = { val plan = generateProject( - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) val actual = plan(inputRow).get(0, expression.dataType) @@ -166,7 +166,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { checkEvaluationWithOptimization(expression, expected) var plan = generateProject( - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) var actual = plan(inputRow).get(0, expression.dataType) assert(checkResult(actual, expected)) @@ -259,7 +259,7 @@ trait ExpressionEvalHelper extends GeneratorDrivenPropertyChecks { } val plan = generateProject( - GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil)(), + GenerateMutableProjection.generate(Alias(expr, s"Optimized($expr)")() :: Nil), expr) val codegen = plan(inputRow).get(0, expr.dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala index 1e5b657f1f..f88c9e8df1 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/MathFunctionsSuite.scala @@ -138,7 +138,7 @@ class MathFunctionsSuite extends SparkFunSuite with ExpressionEvalHelper { inputRow: InternalRow = EmptyRow): Unit = { 
val plan = generateProject( - GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil)(), + GenerateMutableProjection.generate(Alias(expression, s"Optimized($expression)")() :: Nil), expression) val actual = plan(inputRow).get(0, expression.dataType) diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala index c9616cdb26..06dc3bd33b 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/CodegenExpressionCachingSuite.scala @@ -36,7 +36,7 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { test("GenerateMutableProjection should initialize expressions") { val expr = And(NondeterministicExpression(), NondeterministicExpression()) - val instance = GenerateMutableProjection.generate(Seq(expr))() + val instance = GenerateMutableProjection.generate(Seq(expr)) assert(instance.apply(null).getBoolean(0) === false) } @@ -60,12 +60,12 @@ class CodegenExpressionCachingSuite extends SparkFunSuite { test("GenerateMutableProjection should not share expression instances") { val expr1 = MutableExpression() - val instance1 = GenerateMutableProjection.generate(Seq(expr1))() + val instance1 = GenerateMutableProjection.generate(Seq(expr1)) assert(instance1.apply(null).getBoolean(0) === false) val expr2 = MutableExpression() expr2.mutableState = true - val instance2 = GenerateMutableProjection.generate(Seq(expr2))() + val instance2 = GenerateMutableProjection.generate(Seq(expr2)) assert(instance1.apply(null).getBoolean(0) === false) assert(instance2.apply(null).getBoolean(0) === true) } diff --git a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala 
b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala index e2a8eb8ee1..b69b74b424 100644 --- a/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala +++ b/sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/expressions/codegen/GeneratedProjectionSuite.scala @@ -76,7 +76,7 @@ class GeneratedProjectionSuite extends SparkFunSuite { val exprs = nestedSchema.fields.zipWithIndex.map { case (f, i) => BoundReference(i, f.dataType, true) } - val mutableProj = GenerateMutableProjection.generate(exprs)() + val mutableProj = GenerateMutableProjection.generate(exprs) val row1 = mutableProj(result) assert(result === row1) val row2 = mutableProj(result) diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala index 415cd4d84a..b64352a9e0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/SparkPlan.scala @@ -352,12 +352,10 @@ abstract class SparkPlan extends QueryPlan[SparkPlan] with Logging with Serializ } } - private[this] def isTesting: Boolean = sys.props.contains("spark.testing") - protected def newMutableProjection( expressions: Seq[Expression], inputSchema: Seq[Attribute], - useSubexprElimination: Boolean = false): () => MutableProjection = { + useSubexprElimination: Boolean = false): MutableProjection = { log.debug(s"Creating MutableProj: $expressions, inputSchema: $inputSchema") GenerateMutableProjection.generate(expressions, inputSchema, useSubexprElimination) } diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala index 8e9214fa25..85ce388de0 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala +++ 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/Window.scala @@ -120,14 +120,14 @@ case class Window( val (exprs, current, bound) = if (offset == 0) { // Use the entire order expression when the offset is 0. val exprs = orderSpec.map(_.child) - val projection = newMutableProjection(exprs, child.output) - (orderSpec, projection(), projection()) + val buildProjection = () => newMutableProjection(exprs, child.output) + (orderSpec, buildProjection(), buildProjection()) } else if (orderSpec.size == 1) { // Use only the first order expression when the offset is non-null. val sortExpr = orderSpec.head val expr = sortExpr.child // Create the projection which returns the current 'value'. - val current = newMutableProjection(expr :: Nil, child.output)() + val current = newMutableProjection(expr :: Nil, child.output) // Flip the sign of the offset when processing the order is descending val boundOffset = sortExpr.direction match { case Descending => -offset @@ -135,7 +135,7 @@ case class Window( } // Create the projection which returns the current 'value' modified by adding the offset. val boundExpr = Add(expr, Cast(Literal.create(boundOffset, IntegerType), expr.dataType)) - val bound = newMutableProjection(boundExpr :: Nil, child.output)() + val bound = newMutableProjection(boundExpr :: Nil, child.output) (sortExpr :: Nil, current, bound) } else { sys.error("Non-Zero range offsets are not supported for windows " + @@ -564,7 +564,7 @@ private[execution] final class OffsetWindowFunctionFrame( ordinal: Int, expressions: Array[Expression], inputSchema: Seq[Attribute], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection, + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, offset: Int) extends WindowFunctionFrame { /** Rows of the partition currently being processed. */ @@ -604,7 +604,7 @@ private[execution] final class OffsetWindowFunctionFrame( } // Create the projection. 
- newMutableProjection(boundExpressions, Nil)().target(target) + newMutableProjection(boundExpressions, Nil).target(target) } override def prepare(rows: RowBuffer): Unit = { @@ -886,7 +886,7 @@ private[execution] object AggregateProcessor { functions: Array[Expression], ordinal: Int, inputAttributes: Seq[Attribute], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => () => MutableProjection): + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection): AggregateProcessor = { val aggBufferAttributes = mutable.Buffer.empty[AttributeReference] val initialValues = mutable.Buffer.empty[Expression] @@ -938,13 +938,13 @@ private[execution] object AggregateProcessor { // Create the projections. val initialProjection = newMutableProjection( initialValues, - partitionSize.toSeq)() + partitionSize.toSeq) val updateProjection = newMutableProjection( updateExpressions, - aggBufferAttributes ++ inputAttributes)() + aggBufferAttributes ++ inputAttributes) val evaluateProjection = newMutableProjection( evaluateExpressions, - aggBufferAttributes)() + aggBufferAttributes) // Create the processor new AggregateProcessor( diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala index 042c731901..81aacb437b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/AggregationIterator.scala @@ -39,7 +39,7 @@ abstract class AggregationIterator( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection)) + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection) extends Iterator[UnsafeRow] with Logging { 
/////////////////////////////////////////////////////////////////////////// @@ -139,7 +139,7 @@ abstract class AggregationIterator( // no-op expressions which are ignored during projection code-generation. case i: ImperativeAggregate => Seq.fill(i.aggBufferAttributes.length)(NoOp) } - newMutableProjection(initExpressions, Nil)() + newMutableProjection(initExpressions, Nil) } // All imperative AggregateFunctions. @@ -175,7 +175,7 @@ abstract class AggregationIterator( // This projection is used to merge buffer values for all expression-based aggregates. val aggregationBufferSchema = functions.flatMap(_.aggBufferAttributes) val updateProjection = - newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes)() + newMutableProjection(mergeExpressions, aggregationBufferSchema ++ inputAttributes) (currentBuffer: MutableRow, row: InternalRow) => { // Process all expression-based aggregate functions. @@ -211,7 +211,7 @@ abstract class AggregationIterator( case agg: AggregateFunction => NoOp } val aggregateResult = new SpecificMutableRow(aggregateAttributes.map(_.dataType)) - val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes)() + val expressionAggEvalProjection = newMutableProjection(evalExpressions, bufferAttributes) expressionAggEvalProjection.target(aggregateResult) val resultProjection = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala index de1491d357..c35d781d3e 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/SortBasedAggregationIterator.scala @@ -34,7 +34,7 @@ class SortBasedAggregationIterator( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - 
newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, numOutputRows: LongSQLMetric) extends AggregationIterator( groupingExpressions, diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala index 89977f9e08..d4cef8f310 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregate.scala @@ -335,7 +335,7 @@ case class TungstenAggregate( val mergeProjection = newMutableProjection( mergeExpr, aggregateBufferAttributes ++ declFunctions.flatMap(_.inputAggBufferAttributes), - subexpressionEliminationEnabled)() + subexpressionEliminationEnabled) val joinedRow = new JoinedRow() var currentKey: UnsafeRow = null diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala index 09384a482d..c368726610 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/TungstenAggregationIterator.scala @@ -82,7 +82,7 @@ class TungstenAggregationIterator( aggregateAttributes: Seq[Attribute], initialInputBufferOffset: Int, resultExpressions: Seq[NamedExpression], - newMutableProjection: (Seq[Expression], Seq[Attribute]) => (() => MutableProjection), + newMutableProjection: (Seq[Expression], Seq[Attribute]) => MutableProjection, originalInputAttributes: Seq[Attribute], inputIter: Iterator[InternalRow], testFallbackStartsAt: Option[(Int, Int)], diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala 
b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala index f5776e7b8d..4ceb710f4b 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/aggregate/udaf.scala @@ -361,7 +361,7 @@ private[sql] case class ScalaUDAF( val inputAttributes = childrenSchema.toAttributes log.debug( s"Creating MutableProj: $children, inputSchema: $inputAttributes.") - GenerateMutableProjection.generate(children, inputAttributes)() + GenerateMutableProjection.generate(children, inputAttributes) } private[this] lazy val inputToScalaConverters: Any => Any = diff --git a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala index c9ab40a0a9..c49f173ad6 100644 --- a/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala +++ b/sql/core/src/main/scala/org/apache/spark/sql/execution/python/BatchPythonEvaluation.scala @@ -86,7 +86,7 @@ case class BatchPythonEvaluation(udfs: Seq[PythonUDF], output: Seq[Attribute], c } }.toArray }.toArray - val projection = newMutableProjection(allInputs, child.output)() + val projection = newMutableProjection(allInputs, child.output) val schema = StructType(dataTypes.map(dt => StructField("", dt))) val needConversion = dataTypes.exists(EvaluatePython.needConversionInPython) |